From a85493bdbe057e42ea481420146dd2810829a602 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 23 Jul 2024 15:09:15 -0700 Subject: [PATCH] multiaxis contract test --- test/test_uop_graph.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 433b6585f8..7e831709ed 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -308,6 +308,24 @@ class TestExpander(unittest.TestCase): self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3]) self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15]) + def test_contract_axis_2_big(self): + e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2))) + con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,)) + sink = expander_rewrite(con) + assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2)) + self.assertListEqual([x.arg for x in sink.src[0].src], [0,4]) + self.assertListEqual([x.arg for x in sink.src[6].src], [10,14]) + + @unittest.skip("TODO: add support for this") + def test_contract_multi_axis(self): + e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2))) + sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (3,2))) + assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2)) + self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6]) + sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,3))) + assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2)) + self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6]) + def test_contract_mid(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2))) con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))