mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
switch contract arg to match expand arg [run_process_replay] (#5667)
* switch contract arg to match expand arg [run_process_replay] * support multiaxis contract too, it's easy * cancel contract/expand
This commit is contained in:
@@ -285,14 +285,14 @@ class TestExpander(unittest.TestCase):
|
||||
|
||||
def test_contract_simple(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
|
||||
|
||||
def test_contract_axis_1(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
|
||||
@@ -301,7 +301,7 @@ class TestExpander(unittest.TestCase):
|
||||
|
||||
def test_contract_axis_2(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
|
||||
@@ -310,25 +310,24 @@ class TestExpander(unittest.TestCase):
|
||||
|
||||
def test_contract_axis_2_big(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4])
|
||||
self.assertListEqual([x.arg for x in sink.src[6].src], [10,14])
|
||||
|
||||
@unittest.skip("TODO: add support for this")
|
||||
def test_contract_multi_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (3,2)))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3,2),(2,2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6])
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,3)))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,2),(3,2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6])
|
||||
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,2),(3,2))
|
||||
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 2
|
||||
|
||||
Reference in New Issue
Block a user