mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
fix more double matmuls (#12991)
* fix more double matmuls * a few more
This commit is contained in:
@@ -11,14 +11,13 @@ class TestDoubleMatmul(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with Context(DEBUG=0):
|
||||
self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)]
|
||||
self.cmp = (self.a @ self.b @ self.c).realize()
|
||||
|
||||
def _test(self, opts):
|
||||
with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
|
||||
out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
|
||||
|
||||
with Context(DEBUG=0):
|
||||
err = (out-self.cmp).square()
|
||||
err = (out-(self.a @ self.b @ self.c)).square()
|
||||
self.assertLess(err.max().item(), 1e-4)
|
||||
self.assertLess(err.mean().item(), 1e-6)
|
||||
|
||||
@@ -26,8 +25,8 @@ class TestDoubleMatmul(unittest.TestCase):
|
||||
def test_upcast_0(self): self._test((Opt(OptOps.UPCAST, 0, 4),))
|
||||
def test_upcast_1(self): self._test((Opt(OptOps.UPCAST, 1, 4),))
|
||||
def test_upcast_2(self): self._test((Opt(OptOps.UPCAST, 2, 4),))
|
||||
@unittest.skip("doesn't work")
|
||||
def test_upcast_01(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)))
|
||||
def test_upcast_01_mismatch(self): self._test((Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 4)))
|
||||
def test_upcast_02(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 2, 4)))
|
||||
def test_upcast_12(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4)))
|
||||
|
||||
@@ -39,6 +38,11 @@ class TestDoubleMatmul(unittest.TestCase):
|
||||
def test_upcast_1_unroll_0(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)))
|
||||
def test_upcast_2_unroll_0(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4)))
|
||||
|
||||
def test_upcast_0_unroll_1(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||
@unittest.skip("doesn't work")
|
||||
def test_upcast_1_unroll_1(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||
def test_upcast_2_unroll_1(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||
|
||||
@unittest.skip("doesn't work")
|
||||
def test_upcast_01_unroll_01(self):
|
||||
self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||
|
||||
@@ -231,15 +231,27 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
||||
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
|
||||
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
|
||||
|
||||
# Rewrite handler for an INDEX on a vector-cast buffer that was then broadcast.
# Arguments are the nodes bound by the matching pattern: the buffer, its cast, the broadcast node and the index.
# Returns a new index expression on `buf` broadcast to cnt*precnt lanes, so the vector cast is no longer needed.
# NOTE(review): lane ordering below relies on the semantics of UOp.gep/broadcast/index, which are not visible here -- confirm.
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
|
||||
# cnt: vector width of the dtype the buffer was cast to
cnt = cast.dtype.count
|
||||
# precnt: number of sources of the broadcast node
precnt = len(bcast.src)
|
||||
# gep_arg is (0..precnt-1) repeated cnt times: which lane of idx feeds each output lane
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
|
||||
# sum_arg is each i in 0..cnt-1 repeated precnt times: the element offset added per output lane
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
|
||||
# output lane i*precnt+j is indexed at idx[j]*cnt + i (presumably gep selects lanes by position -- verify)
return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg))
|
||||
|
||||
# Rewrite rules for DEFINE_LOCAL / DEFINE_REG buffers and for INDEX ops on casts of them.
# The names bound in each pattern (buf, cast, bcast, idx) are the keyword arguments of the paired handler.
# NOTE(review): this page also shows `])+devectorize_buf_and_index` closing the `devectorize` matcher, so these rules appear to be shared with it -- confirm.
devectorize_buf_and_index = PatternMatcher([
|
||||
# the buffer definition itself
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
# buf.cast(...).index(idx); presumably or_after also matches the buffer wrapped in AFTER -- verify
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
# buf.cast(...).broadcast(...).index(idx): same as above with a broadcast between the cast and the index
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index_broadcast),
|
||||
])
|
||||
|
||||
devectorize = PatternMatcher([
|
||||
# CAST after AFTER
|
||||
(UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
|
||||
# no ALU on vectorized dtypes
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
|
||||
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
])
|
||||
])+devectorize_buf_and_index
|
||||
|
||||
pm_render = PatternMatcher([
|
||||
# for rendering, we use explicit VECTORIZE
|
||||
|
||||
Reference in New Issue
Block a user