mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
pcontig double matmul works (#12899)
* pcontig double matmul works * tests * contract * closer * works-ish * add that broadcast * 2 more work * something * disable broken ones * llvm * align 16
This commit is contained in:
@@ -1,10 +1,51 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, nn, Device
|
||||
from tinygrad.helpers import Context, GlobalCounters, CI, getenv, PCONTIG
|
||||
from tinygrad.helpers import Context, GlobalCounters, CI, getenv, PCONTIG, DEBUG
|
||||
from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops
|
||||
from tinygrad.codegen.opt import OptOps, Opt
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, NIRRenderer), "broken in LVP")
class TestDoubleMatmul(unittest.TestCase):
  """Compare `a @ b @ c` realized under PCONTIG=2 with various opts against the result built in setUp."""

  def setUp(self):
    # inputs and the comparison value are created with DEBUG forced to 0
    with Context(DEBUG=0):
      mats = [Tensor.randn(16, 16).contiguous().realize() for _ in range(3)]
      self.a, self.b, self.c = mats
      self.cmp = (self.a @ self.b @ self.c).realize()

  def _test(self, opts):
    """Realize the double matmul with `opts` passed to contiguous() and check the squared error vs self.cmp."""
    # never run below DEBUG=2 here, but keep a higher ambient DEBUG value if one is set
    debug_level = max(2, DEBUG.value)
    with Context(PCONTIG=2, DEBUG=debug_level):
      product = self.a @ self.b @ self.c
      out = product.contiguous(arg=opts).realize()

    with Context(DEBUG=0):
      sq_err = (out - self.cmp).square()
      worst = sq_err.max().item()
      self.assertLess(worst, 1e-4)
      average = sq_err.mean().item()
      self.assertLess(average, 1e-6)

  # no opts
  def test_baseline(self):
    self._test(())

  # single UPCAST
  def test_upcast_0(self):
    opts = (Opt(OptOps.UPCAST, 0, 4),)
    self._test(opts)

  def test_upcast_1(self):
    opts = (Opt(OptOps.UPCAST, 1, 4),)
    self._test(opts)

  def test_upcast_2(self):
    opts = (Opt(OptOps.UPCAST, 2, 4),)
    self._test(opts)

  # pairs of UPCAST
  @unittest.skip("doesn't work")
  def test_upcast_01(self):
    opts = (Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4))
    self._test(opts)

  def test_upcast_02(self):
    opts = (Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 2, 4))
    self._test(opts)

  def test_upcast_12(self):
    opts = (Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 2, 4))
    self._test(opts)

  # UNROLL only
  def test_unroll_0(self):
    opts = (Opt(OptOps.UNROLL, 0, 4),)
    self._test(opts)

  def test_unroll_1(self):
    opts = (Opt(OptOps.UNROLL, 1, 4),)
    self._test(opts)

  def test_unroll_01(self):
    opts = (Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4))
    self._test(opts)

  # one UPCAST combined with UNROLL 0
  def test_upcast_0_unroll_0(self):
    opts = (Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4))
    self._test(opts)

  def test_upcast_1_unroll_0(self):
    opts = (Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4))
    self._test(opts)

  def test_upcast_2_unroll_0(self):
    opts = (Opt(OptOps.UPCAST, 2, 4), Opt(OptOps.UNROLL, 0, 4))
    self._test(opts)

  # two UPCASTs combined with two UNROLLs
  @unittest.skip("doesn't work")
  def test_upcast_01_unroll_01(self):
    opts = (
      Opt(OptOps.UPCAST, 0, 4),
      Opt(OptOps.UPCAST, 1, 4),
      Opt(OptOps.UNROLL, 0, 4),
      Opt(OptOps.UNROLL, 1, 4),
    )
    self._test(opts)

  @unittest.skip("doesn't work")
  def test_upcast_12_unroll_01(self):
    opts = (
      Opt(OptOps.UPCAST, 1, 4),
      Opt(OptOps.UPCAST, 2, 4),
      Opt(OptOps.UNROLL, 0, 4),
      Opt(OptOps.UNROLL, 1, 4),
    )
    self._test(opts)
|
||||
|
||||
class TestRangeifyAssign(unittest.TestCase):
|
||||
def test_assign_permuted(self):
|
||||
A = Tensor.empty(4, 4, dtype='int')
|
||||
|
||||
Reference in New Issue
Block a user