test exact kernel count in torch_backend/test_kernel_fusion (#15091)

This commit is contained in:
chenyu
2026-03-02 14:26:32 -05:00
committed by GitHub
parent f80b1033c5
commit 71f228f80f

View File

@@ -1,7 +1,6 @@
# simple tests
import unittest
import torch
import warnings
from tinygrad.helpers import getenv, GlobalCounters
if getenv("TINY_BACKEND2"):
import extra.torch_backend.backend2
@@ -18,9 +17,7 @@ class TestKernelFusionRegression(unittest.TestCase):
torch.manual_seed(42)
GlobalCounters.reset()
fn().detach().cpu().numpy()
expectation = f"{GlobalCounters.kernel_count} vs {expected_kernels} expected."
if GlobalCounters.kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
self.assertLessEqual(GlobalCounters.kernel_count, expected_kernels, f"{expectation}")
self.assertEqual(GlobalCounters.kernel_count, expected_kernels)
def test_elementwise_fusion(self):
def fn():
@@ -34,7 +31,7 @@ class TestKernelFusionRegression(unittest.TestCase):
conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)
with torch.no_grad():
return torch.nn.functional.relu(conv(x))
self._check_kernel_count(fn, 8)
self._check_kernel_count(fn, 6)
def test_batchnorm_fusion(self):
def fn():
@@ -44,7 +41,7 @@ class TestKernelFusionRegression(unittest.TestCase):
bn.eval()
with torch.no_grad():
return torch.nn.functional.relu(bn(conv(x)))
self._check_kernel_count(fn, 16)
self._check_kernel_count(fn, 10)
def test_reduce_fusion(self):
def fn():
@@ -92,7 +89,7 @@ class TestKernelFusionRegression(unittest.TestCase):
out = bn(conv(x))
out += identity
return torch.nn.functional.relu(out)
self._check_kernel_count(fn, 17)
self._check_kernel_count(fn, 12)
def test_multiple_inplace_ops_fusion(self):
def fn():
@@ -117,7 +114,7 @@ class TestKernelFusionRegression(unittest.TestCase):
bn.train()
with torch.no_grad():
return bn(x)
self._check_kernel_count(fn, 10)
self._check_kernel_count(fn, 8)
# this is a minimal version of extra/other_mnist/beautiful_mnist_torch.py, covering fusion for training with an optimizer
def test_mnist_training_fusion(self):
@@ -138,7 +135,7 @@ class TestKernelFusionRegression(unittest.TestCase):
loss.backward()
optimizer.step()
return loss
self._check_kernel_count(fn, 28)
self._check_kernel_count(fn, 24)
if __name__ == "__main__":
unittest.main()