diff --git a/extra/torch_backend/test_kernel_fusion.py b/extra/torch_backend/test_kernel_fusion.py index 0546323366..dffcfe067f 100644 --- a/extra/torch_backend/test_kernel_fusion.py +++ b/extra/torch_backend/test_kernel_fusion.py @@ -1,7 +1,6 @@ # simple tests import unittest import torch -import warnings from tinygrad.helpers import getenv, GlobalCounters if getenv("TINY_BACKEND2"): import extra.torch_backend.backend2 @@ -18,9 +17,7 @@ class TestKernelFusionRegression(unittest.TestCase): torch.manual_seed(42) GlobalCounters.reset() fn().detach().cpu().numpy() - expectation = f"{GlobalCounters.kernel_count} vs {expected_kernels} expected." - if GlobalCounters.kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning) - self.assertLessEqual(GlobalCounters.kernel_count, expected_kernels, f"{expectation}") + self.assertEqual(GlobalCounters.kernel_count, expected_kernels) def test_elementwise_fusion(self): def fn(): @@ -34,7 +31,7 @@ class TestKernelFusionRegression(unittest.TestCase): conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device) with torch.no_grad(): return torch.nn.functional.relu(conv(x)) - self._check_kernel_count(fn, 8) + self._check_kernel_count(fn, 6) def test_batchnorm_fusion(self): def fn(): @@ -44,7 +41,7 @@ class TestKernelFusionRegression(unittest.TestCase): bn.eval() with torch.no_grad(): return torch.nn.functional.relu(bn(conv(x))) - self._check_kernel_count(fn, 16) + self._check_kernel_count(fn, 10) def test_reduce_fusion(self): def fn(): @@ -92,7 +89,7 @@ class TestKernelFusionRegression(unittest.TestCase): out = bn(conv(x)) out += identity return torch.nn.functional.relu(out) - self._check_kernel_count(fn, 17) + self._check_kernel_count(fn, 12) def test_multiple_inplace_ops_fusion(self): def fn(): @@ -117,7 +114,7 @@ class TestKernelFusionRegression(unittest.TestCase): bn.train() with torch.no_grad(): return bn(x) - self._check_kernel_count(fn, 10) + self._check_kernel_count(fn, 8) # this is a minimal extra/other_mnist/beautiful_mnist_torch.py to cover fusion for training with optimizer def test_mnist_training_fusion(self): @@ -138,7 +135,7 @@ class TestKernelFusionRegression(unittest.TestCase): loss.backward() optimizer.step() return loss - self._check_kernel_count(fn, 28) + self._check_kernel_count(fn, 24) if __name__ == "__main__": unittest.main()