mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test exact kernel count in torch_backend/test_kernel_fusion (#15091)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# simple tests
|
||||
import unittest
|
||||
import torch
|
||||
import warnings
|
||||
from tinygrad.helpers import getenv, GlobalCounters
|
||||
if getenv("TINY_BACKEND2"):
|
||||
import extra.torch_backend.backend2
|
||||
@@ -18,9 +17,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
torch.manual_seed(42)
|
||||
GlobalCounters.reset()
|
||||
fn().detach().cpu().numpy()
|
||||
expectation = f"{GlobalCounters.kernel_count} vs {expected_kernels} expected."
|
||||
if GlobalCounters.kernel_count < expected_kernels: warnings.warn(f"{expectation} Expectation can be lowered.", UserWarning)
|
||||
self.assertLessEqual(GlobalCounters.kernel_count, expected_kernels, f"{expectation}")
|
||||
self.assertEqual(GlobalCounters.kernel_count, expected_kernels)
|
||||
|
||||
def test_elementwise_fusion(self):
|
||||
def fn():
|
||||
@@ -34,7 +31,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
conv = torch.nn.Conv2d(3, 16, 3, padding=1).to(device)
|
||||
with torch.no_grad():
|
||||
return torch.nn.functional.relu(conv(x))
|
||||
self._check_kernel_count(fn, 8)
|
||||
self._check_kernel_count(fn, 6)
|
||||
|
||||
def test_batchnorm_fusion(self):
|
||||
def fn():
|
||||
@@ -44,7 +41,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
bn.eval()
|
||||
with torch.no_grad():
|
||||
return torch.nn.functional.relu(bn(conv(x)))
|
||||
self._check_kernel_count(fn, 16)
|
||||
self._check_kernel_count(fn, 10)
|
||||
|
||||
def test_reduce_fusion(self):
|
||||
def fn():
|
||||
@@ -92,7 +89,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
out = bn(conv(x))
|
||||
out += identity
|
||||
return torch.nn.functional.relu(out)
|
||||
self._check_kernel_count(fn, 17)
|
||||
self._check_kernel_count(fn, 12)
|
||||
|
||||
def test_multiple_inplace_ops_fusion(self):
|
||||
def fn():
|
||||
@@ -117,7 +114,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
bn.train()
|
||||
with torch.no_grad():
|
||||
return bn(x)
|
||||
self._check_kernel_count(fn, 10)
|
||||
self._check_kernel_count(fn, 8)
|
||||
|
||||
# this is a minimal version of extra/other_mnist/beautiful_mnist_torch.py, covering kernel fusion when training with an optimizer
|
||||
def test_mnist_training_fusion(self):
|
||||
@@ -138,7 +135,7 @@ class TestKernelFusionRegression(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss
|
||||
self._check_kernel_count(fn, 28)
|
||||
self._check_kernel_count(fn, 24)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user