diff --git a/test/test_rangeify.py b/test/test_rangeify.py
index e8de1d6513..7a5059e80d 100644
--- a/test/test_rangeify.py
+++ b/test/test_rangeify.py
@@ -11,14 +11,13 @@ class TestDoubleMatmul(unittest.TestCase):
   def setUp(self):
     with Context(DEBUG=0):
       self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)]
-      self.cmp = (self.a @ self.b @ self.c).realize()
 
   def _test(self, opts):
     with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
       out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
 
     with Context(DEBUG=0):
-      err = (out-self.cmp).square()
+      err = (out-(self.a @ self.b @ self.c)).square()
       self.assertLess(err.max().item(), 1e-4)
       self.assertLess(err.mean().item(), 1e-6)
 
@@ -26,8 +25,8 @@ class TestDoubleMatmul(unittest.TestCase):
   def test_upcast_0(self): self._test((Opt(OptOps.UPCAST, 0, 4),))
   def test_upcast_1(self): self._test((Opt(OptOps.UPCAST, 1, 4),))
   def test_upcast_2(self): self._test((Opt(OptOps.UPCAST, 2, 4),))
-  @unittest.skip("doesn't work")
   def test_upcast_01(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)))
+  def test_upcast_01_mismatch(self): self._test((Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 4)))
   def test_upcast_02(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 2, 4)))
   def test_upcast_12(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4)))
 
@@ -39,6 +38,11 @@ class TestDoubleMatmul(unittest.TestCase):
   def test_upcast_1_unroll_0(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)))
   def test_upcast_2_unroll_0(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4)))
 
+  def test_upcast_0_unroll_1(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
+  @unittest.skip("doesn't work")
+  def test_upcast_1_unroll_1(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 1, 4)))
+  def test_upcast_2_unroll_1(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 1, 4)))
+
   @unittest.skip("doesn't work")
   def test_upcast_01_unroll_01(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
 
diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py
index 5195ee4c64..65e21096c7 100644
--- a/tinygrad/codegen/late/devectorizer.py
+++ b/tinygrad/codegen/late/devectorizer.py
@@ -231,15 +231,27 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
   assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
   return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
 
+def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
+  cnt = cast.dtype.count
+  precnt = len(bcast.src)
+  gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
+  sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
+  return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg))
+
+devectorize_buf_and_index = PatternMatcher([
+  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
+  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
+  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
+   no_vectorized_index_broadcast),
+])
+
 devectorize = PatternMatcher([
   # CAST after AFTER
   (UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
   # no ALU on vectorized dtypes
   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
   (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
-  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
-  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
-])
+])+devectorize_buf_and_index
 
 pm_render = PatternMatcher([
   # for rendering, we use explicit VECTORIZE