mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
multiaxis contract test
This commit is contained in:
@@ -308,6 +308,24 @@ class TestExpander(unittest.TestCase):
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
|
||||
self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
|
||||
|
||||
def test_contract_axis_2_big(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4])
|
||||
self.assertListEqual([x.arg for x in sink.src[6].src], [10,14])
|
||||
|
||||
@unittest.skip("TODO: add support for this")
|
||||
def test_contract_multi_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (3,2)))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6])
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,3)))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6])
|
||||
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
|
||||
|
||||
Reference in New Issue
Block a user