scheduler: sched_cache bugfix for different Tensor.custom_kernel schedules (#14161)

* simplest failing test

* min fix

* same function reuses the cache

* SPEC=2 never worked for custom_kernel
This commit is contained in:
qazal
2026-01-15 00:59:14 -05:00
committed by GitHub
parent b46da603fe
commit 164bc678a6
2 changed files with 34 additions and 2 deletions

View File

@@ -1,7 +1,12 @@
import unittest
from tinygrad import Tensor, Variable
import functools
from tinygrad import Tensor, Variable, UOp, Context
from tinygrad.uop.ops import KernelInfo
from tinygrad.engine.schedule import schedule_cache
def custom_set0_kernel(A:UOp, num:int) -> UOp:
return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}"))
class TestScheduleCache(unittest.TestCase):
def test_bound_variable_reuses_cache(self):
schedule_cache.clear()
@@ -26,6 +31,33 @@ class TestScheduleCache(unittest.TestCase):
_, var_vals = t.schedule_with_vars()
self.assertEqual(var_vals, {'pos': 42})
@Context(SPEC=0)
def test_custom_kernel(self):
for i in range(4):
a = Tensor.empty(1)
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_set0_kernel, num=i))[0]
a.realize()
self.assertEqual(a.item(), i)
@Context(SPEC=0)
def test_same_custom_function_reuses_cache(self):
schedule_cache.clear()
fxn = functools.partial(custom_set0_kernel, num=10)
# first run
a = Tensor.empty(1)
a = Tensor.custom_kernel(a, fxn=fxn)[0]
a.realize()
self.assertEqual(a.item(), 10)
cache_size_after_first = len(schedule_cache)
# second run with same function should reuse cache
b = Tensor.empty(1)
b = Tensor.custom_kernel(b, fxn=fxn)[0]
b.realize()
self.assertEqual(b.item(), 10)
self.assertEqual(len(schedule_cache), cache_size_after_first)
def test_simple(self):
a = Tensor.ones(10).contiguous()
b = Tensor.ones(10).contiguous()

View File

@@ -839,7 +839,7 @@ class CustomKernel:
grad_fxn: Callable|None = None
# sadly CustomKernel can't be pickled or reconstructed as a str
def __reduce__(self): return (CustomKernel, (panic,))
def __repr__(self): return "CustomKernel(panic)"
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
@dataclass(frozen=True)
class Kernel: