fix more double matmuls (#12991)

* fix more double matmuls

* a few more
This commit is contained in:
George Hotz
2025-10-29 16:09:48 +08:00
committed by GitHub
parent e42b4edf8c
commit 1c362736aa
2 changed files with 22 additions and 6 deletions

View File

@@ -11,14 +11,13 @@ class TestDoubleMatmul(unittest.TestCase):
# Test fixture: builds the three operand tensors used by every case in this class.
# The work is done under Context(DEBUG=0).
def setUp(self):
with Context(DEBUG=0):
# three 16x16 tensors from Tensor.randn, each made contiguous and realized up front
self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)]
# reference result of the chained matmul, realized once and stored on the instance
# NOTE(review): this text looks like a rendered diff with +/- markers stripped; this line may be
# one the commit removed (the comparison value is also recomputed inline in _test) - confirm.
self.cmp = (self.a @ self.b @ self.c).realize()
# Shared check for all cases: computes a @ b @ c with `opts` passed as `arg` to .contiguous(),
# then compares the result against a reference product using squared error.
#   opts: tuple of Opt entries supplied by the individual test_* methods.
def _test(self, opts):
# run with PCONTIG=2 and DEBUG raised to at least 2 (max of 2 and the current DEBUG value)
with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
# the comparison itself runs under DEBUG=0
with Context(DEBUG=0):
# NOTE(review): two consecutive assignments to `err`; as written the second overwrites the first.
# They are presumably the removed and added versions of the same line in a diff view - confirm
# against the real file before relying on either.
err = (out-self.cmp).square()
err = (out-(self.a @ self.b @ self.c)).square()
# tolerances: largest squared error below 1e-4, mean squared error below 1e-6
self.assertLess(err.max().item(), 1e-4)
self.assertLess(err.mean().item(), 1e-6)
@@ -26,8 +25,8 @@ class TestDoubleMatmul(unittest.TestCase):
# UPCAST-only cases: each method hands a tuple of Opt(OptOps.UPCAST, <arg>, <arg>) entries to self._test.
# Single-opt cases use a one-element tuple (note the trailing comma).
def test_upcast_0(self): self._test((Opt(OptOps.UPCAST, 0, 4),))
def test_upcast_1(self): self._test((Opt(OptOps.UPCAST, 1, 4),))
def test_upcast_2(self): self._test((Opt(OptOps.UPCAST, 2, 4),))
# NOTE(review): in this marker-stripped diff view it is unclear whether the skip decorator below is a
# retained or a removed line - confirm against the real file.
@unittest.skip("doesn't work")
def test_upcast_01(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)))
# same pair of opts as test_upcast_01 but with different last arguments (2 and 4)
def test_upcast_01_mismatch(self): self._test((Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 4)))
def test_upcast_02(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 2, 4)))
def test_upcast_12(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4)))
@@ -39,6 +38,11 @@ class TestDoubleMatmul(unittest.TestCase):
# Mixed cases: each method hands self._test a tuple combining OptOps.UPCAST and OptOps.UNROLL entries.
def test_upcast_1_unroll_0(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)))
def test_upcast_2_unroll_0(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4)))
def test_upcast_0_unroll_1(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
# skipped with reason "doesn't work"
@unittest.skip("doesn't work")
def test_upcast_1_unroll_1(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 1, 4)))
def test_upcast_2_unroll_1(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 1, 4)))
# skipped with reason "doesn't work"; combines two UPCAST and two UNROLL opts in one tuple
@unittest.skip("doesn't work")
def test_upcast_01_unroll_01(self):
self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)))

View File

@@ -231,15 +231,27 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
  # Rewrite for buf.cast(...).broadcast(...).index(idx): index the uncast buffer with one flat index vector.
  #   buf:   the buffer node matched under the cast
  #   cast:  the cast node; only its dtype.count is read
  #   bcast: the broadcast node; only the number of its sources is read
  #   idx:   the index being applied; its elements are re-selected with gep
  lanes = cast.dtype.count
  copies = len(bcast.src)
  width = lanes * copies
  # selector into idx: 0..copies-1, repeated once per lane
  gep_arg = tuple(range(copies)) * lanes
  # offset added to each element: every lane number repeated `copies` times
  sum_arg = tuple(lane for lane in range(lanes) for _ in range(copies))
  wide_buf = buf.broadcast(width)
  scaled = idx.gep(gep_arg)*lanes
  return wide_buf.index(scaled+UOp.const(dtypes.index.vec(width), sum_arg))
# Pattern table for DEFINE_LOCAL / DEFINE_REG buffers and for indexing through a cast of them.
# This table is appended to `devectorize` via `+`.
devectorize_buf_and_index = PatternMatcher([
# the buffer definition itself
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
# buf -> cast -> index(idx); `or_after` presumably also accepts the buffer wrapped in an AFTER - TODO confirm
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
# buf -> cast -> broadcast -> index(idx): same as above with a broadcast between the cast and the index
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
no_vectorized_index_broadcast),
])
# Main devectorize pattern table.
devectorize = PatternMatcher([
# CAST after AFTER
# (an AFTER whose first source is a CAST is rebuilt as: AFTER on the cast's input with the same remaining sources, then CAST to c.dtype)
(UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
# NOTE(review): the next two entries also appear in devectorize_buf_and_index, and two closing `])` lines follow.
# This looks like a rendered diff with +/- markers stripped (old entries and old terminator shown next to the
# new terminator) - confirm against the real file; as literal text this block is not valid Python.
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
])
# the buffer/index patterns are concatenated onto the table here
])+devectorize_buf_and_index
pm_render = PatternMatcher([
# for rendering, we use explicit VECTORIZE