diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py index 9a7efbb651..f4b848e41a 100644 --- a/test/test_softmax_fusion.py +++ b/test/test_softmax_fusion.py @@ -117,6 +117,15 @@ class TestFuse(unittest.TestCase): c = (a.sum(axis=1) + b.sum(axis=1)).fuse() self.assertListEqual(c.tolist(), [30]*16) + @unittest.skipUnless(Device.DEFAULT == "METAL", "METAL TC") + @unittest.expectedFailure # TODO: fix + def test_fuse_and_tc_opt(self): + A = Tensor.randn(8, 8).realize() + B = Tensor.randn(8, 8).realize() + C = Tensor.ones(1, 8, 8).pad(((1,1), None, None),).sum(0) + out = (C + (A @ B)).fuse() + out.realize() + class TestSoftmaxFusion(unittest.TestCase): @classmethod def setUpClass(cls):