mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
expand UOps with rewrite rules (#5501)
* expand UOps with rewrite rules [run_process_replay] * progress * much closer * close, way less bugs * bunch of expander tests * fix contract * ops tests pass * fix barrier * mostly passing * bitcast in expanded ops * support more expand merges * all tests pass maybe * fix empty EXPAND * fix LIN fuzzing * add ALL_SAME assert * all same * all same work * raise CompileError * pass fuzz linearizer * revert whitespace * fix nv tensor core test * fix mypy * bug fix * fuzzer passes * put tests back * expand arg to idx
This commit is contained in:
@@ -270,15 +270,14 @@ class TestUOpGraph(TestUOps):
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
def expander_rewrite(sink):
|
||||
#from tinygrad.codegen.uopgraph import expander, constant_folder
|
||||
#together = PatternMatcher(expander.patterns + constant_folder.patterns)
|
||||
#return graph_rewrite(sink, together)
|
||||
out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
|
||||
out.linearize()
|
||||
return out.uops[-1]
|
||||
from tinygrad.codegen.uopgraph import expander, constant_folder
|
||||
together = PatternMatcher(expander.patterns + constant_folder.patterns)
|
||||
return graph_rewrite(sink, together)
|
||||
#out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
|
||||
#out.linearize()
|
||||
#return out.uops[-1]
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
@unittest.skip
|
||||
def test_expand_add_broadcast(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
sink = expander_rewrite(e1+3)
|
||||
@@ -292,7 +291,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
|
||||
|
||||
@unittest.skip
|
||||
def test_contract_axis_1(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
|
||||
@@ -302,7 +300,6 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
|
||||
|
||||
@unittest.skip
|
||||
def test_contract_axis_2(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
|
||||
@@ -312,7 +309,6 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
|
||||
|
||||
@unittest.skip
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
|
||||
@@ -324,7 +320,6 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
|
||||
|
||||
@unittest.skip
|
||||
def test_expand_same_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
|
||||
@@ -332,7 +327,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
|
||||
|
||||
@unittest.skip
|
||||
def test_expand_different_axis(self, flip=False):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
@@ -357,7 +351,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.op is UOps.CONST
|
||||
self.assertEqual(sink.arg, 3*4)
|
||||
|
||||
@unittest.skip
|
||||
def test_double_expand(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
|
||||
@@ -367,7 +360,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.arg == ((1, 2), (2, 4))
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
|
||||
|
||||
@unittest.skip
|
||||
def test_double_expand_reverse(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
|
||||
@@ -377,7 +369,6 @@ class TestExpander(unittest.TestCase):
|
||||
assert sink.arg == ((1, 4), (2, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
|
||||
|
||||
@unittest.skip
|
||||
def test_double_expand_middle(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
|
||||
|
||||
Reference in New Issue
Block a user