expand UOps with rewrite rules (#5501)

* expand UOps with rewrite rules [run_process_replay]

* progress

* much closer

* close, way less bugs

* bunch of expander tests

* fix contract

* ops tests pass

* fix barrier

* mostly passing

* bitcast in expanded ops

* support more expand merges

* all tests pass maybe

* fix empty EXPAND

* fix LIN fuzzing

* add ALL_SAME assert

* all same

* all same work

* raise CompileError

* pass fuzz linearizer

* revert whitespace

* fix nv tensor core test

* fix mypy

* bug fix

* fuzzer passes

* put tests back

* expand arg to idx
This commit is contained in:
George Hotz
2024-07-17 10:17:50 -07:00
committed by GitHub
parent 158221b36b
commit 1242b302fa
2 changed files with 127 additions and 160 deletions

View File

@@ -270,15 +270,14 @@ class TestUOpGraph(TestUOps):
self.assertEqual(endranges[-1].src[0], ranges[0])
def expander_rewrite(sink):
#from tinygrad.codegen.uopgraph import expander, constant_folder
#together = PatternMatcher(expander.patterns + constant_folder.patterns)
#return graph_rewrite(sink, together)
out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
out.linearize()
return out.uops[-1]
from tinygrad.codegen.uopgraph import expander, constant_folder
together = PatternMatcher(expander.patterns + constant_folder.patterns)
return graph_rewrite(sink, together)
#out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
#out.linearize()
#return out.uops[-1]
class TestExpander(unittest.TestCase):
@unittest.skip
def test_expand_add_broadcast(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = expander_rewrite(e1+3)
@@ -292,7 +291,6 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
@unittest.skip
def test_contract_axis_1(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
@@ -302,7 +300,6 @@ class TestExpander(unittest.TestCase):
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
@unittest.skip
def test_contract_axis_2(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
@@ -312,7 +309,6 @@ class TestExpander(unittest.TestCase):
self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
@unittest.skip
def test_contract_mid(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
@@ -324,7 +320,6 @@ class TestExpander(unittest.TestCase):
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
@unittest.skip
def test_expand_same_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
@@ -332,7 +327,6 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.EXPAND and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
@unittest.skip
def test_expand_different_axis(self, flip=False):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
@@ -357,7 +351,6 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.CONST
self.assertEqual(sink.arg, 3*4)
@unittest.skip
def test_double_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
@@ -367,7 +360,6 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 2), (2, 4))
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
@unittest.skip
def test_double_expand_reverse(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
@@ -377,7 +369,6 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 4), (2, 2))
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
@unittest.skip
def test_double_expand_middle(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))