use CONTRACT before REDUCE (#5903)

* use CONTRACT before REDUCE [run_process_replay]

* support half expand

* EXPAND GEP
This commit is contained in:
George Hotz
2024-08-04 16:17:33 -07:00
committed by GitHub
parent f33950f454
commit be8958e26b
5 changed files with 33 additions and 8 deletions

View File

@@ -352,6 +352,22 @@ class TestExpander(unittest.TestCase):
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
def test_contract_no_expand(self):
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 2
assert sink.src[0] == sink.src[1]
def test_contract_half_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
assert sink.src[0] == sink.src[1]
assert sink.src[0] != sink.src[2]
assert sink.src[6] == sink.src[7]
def test_expand_same_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))