diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py
index f4b848e41a..e9a4594bfb 100644
--- a/test/test_softmax_fusion.py
+++ b/test/test_softmax_fusion.py
@@ -118,7 +118,6 @@ class TestFuse(unittest.TestCase):
     self.assertListEqual(c.tolist(), [30]*16)
 
   @unittest.skipUnless(Device.DEFAULT == "METAL", "METAL TC")
-  @unittest.expectedFailure  # TODO: fix
   def test_fuse_and_tc_opt(self):
     A = Tensor.randn(8, 8).realize()
     B = Tensor.randn(8, 8).realize()
diff --git a/tinygrad/codegen/opt/kernel.py b/tinygrad/codegen/opt/kernel.py
index 83395a591d..3df7100e78 100644
--- a/tinygrad/codegen/opt/kernel.py
+++ b/tinygrad/codegen/opt/kernel.py
@@ -378,9 +378,9 @@ class Kernel:
     tensor_cores = self.opts.tensor_cores if tc_select == -1 else [self.opts.tensor_cores[tc_select]]
     for tc in tensor_cores:
       tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops]
+      if tensor_core_opts[0] is None: continue
       # can only fuse reduces with the same tc options
       assert all_same(tensor_core_opts)
-      if tensor_core_opts[0] is None: continue
       self.tensor_core_opts = tc_opts = tensor_core_opts[0]
 
       # attempt to pad the tensor axes that require it