mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
fix more double matmuls (#12991)
* fix more double matmuls * a few more
This commit is contained in:
@@ -11,14 +11,13 @@ class TestDoubleMatmul(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
with Context(DEBUG=0):
|
with Context(DEBUG=0):
|
||||||
self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)]
|
self.a, self.b, self.c = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)]
|
||||||
self.cmp = (self.a @ self.b @ self.c).realize()
|
|
||||||
|
|
||||||
def _test(self, opts):
|
def _test(self, opts):
|
||||||
with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
|
with Context(PCONTIG=2, DEBUG=max(2, DEBUG.value)):
|
||||||
out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
|
out = (self.a @ self.b @ self.c).contiguous(arg=opts).realize()
|
||||||
|
|
||||||
with Context(DEBUG=0):
|
with Context(DEBUG=0):
|
||||||
err = (out-self.cmp).square()
|
err = (out-(self.a @ self.b @ self.c)).square()
|
||||||
self.assertLess(err.max().item(), 1e-4)
|
self.assertLess(err.max().item(), 1e-4)
|
||||||
self.assertLess(err.mean().item(), 1e-6)
|
self.assertLess(err.mean().item(), 1e-6)
|
||||||
|
|
||||||
@@ -26,8 +25,8 @@ class TestDoubleMatmul(unittest.TestCase):
|
|||||||
def test_upcast_0(self): self._test((Opt(OptOps.UPCAST, 0, 4),))
|
def test_upcast_0(self): self._test((Opt(OptOps.UPCAST, 0, 4),))
|
||||||
def test_upcast_1(self): self._test((Opt(OptOps.UPCAST, 1, 4),))
|
def test_upcast_1(self): self._test((Opt(OptOps.UPCAST, 1, 4),))
|
||||||
def test_upcast_2(self): self._test((Opt(OptOps.UPCAST, 2, 4),))
|
def test_upcast_2(self): self._test((Opt(OptOps.UPCAST, 2, 4),))
|
||||||
@unittest.skip("doesn't work")
|
|
||||||
def test_upcast_01(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)))
|
def test_upcast_01(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)))
|
||||||
|
def test_upcast_01_mismatch(self): self._test((Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 4)))
|
||||||
def test_upcast_02(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 2, 4)))
|
def test_upcast_02(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 2, 4)))
|
||||||
def test_upcast_12(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4)))
|
def test_upcast_12(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4)))
|
||||||
|
|
||||||
@@ -39,6 +38,11 @@ class TestDoubleMatmul(unittest.TestCase):
|
|||||||
def test_upcast_1_unroll_0(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)))
|
def test_upcast_1_unroll_0(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)))
|
||||||
def test_upcast_2_unroll_0(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4)))
|
def test_upcast_2_unroll_0(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4)))
|
||||||
|
|
||||||
|
def test_upcast_0_unroll_1(self): self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||||
|
@unittest.skip("doesn't work")
|
||||||
|
def test_upcast_1_unroll_1(self): self._test((Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||||
|
def test_upcast_2_unroll_1(self): self._test((Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||||
|
|
||||||
@unittest.skip("doesn't work")
|
@unittest.skip("doesn't work")
|
||||||
def test_upcast_01_unroll_01(self):
|
def test_upcast_01_unroll_01(self):
|
||||||
self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
|
self._test((Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)))
|
||||||
|
|||||||
@@ -231,15 +231,27 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
|||||||
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
|
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
|
||||||
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
|
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))))
|
||||||
|
|
||||||
|
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
|
||||||
|
cnt = cast.dtype.count
|
||||||
|
precnt = len(bcast.src)
|
||||||
|
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
|
||||||
|
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
|
||||||
|
return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg))
|
||||||
|
|
||||||
|
devectorize_buf_and_index = PatternMatcher([
|
||||||
|
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||||
|
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||||
|
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||||
|
no_vectorized_index_broadcast),
|
||||||
|
])
|
||||||
|
|
||||||
devectorize = PatternMatcher([
|
devectorize = PatternMatcher([
|
||||||
# CAST after AFTER
|
# CAST after AFTER
|
||||||
(UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
|
(UPat(Ops.CAST, name="c").f(Ops.AFTER, allow_any_len=True, name="a"), lambda c,a: c.src[0].after(*a.src[1:]).cast(c.dtype)),
|
||||||
# no ALU on vectorized dtypes
|
# no ALU on vectorized dtypes
|
||||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
|
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
|
||||||
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
|
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
|
||||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
])+devectorize_buf_and_index
|
||||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
|
||||||
])
|
|
||||||
|
|
||||||
pm_render = PatternMatcher([
|
pm_render = PatternMatcher([
|
||||||
# for rendering, we use explicit VECTORIZE
|
# for rendering, we use explicit VECTORIZE
|
||||||
|
|||||||
Reference in New Issue
Block a user