mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
scheduler: sched_cache bugfix for different Tensor.custom_kernel schedules (#14161)
* simplest failing test * min fix * same function reuses the cache * SPEC=2 never worked for custom_kernel
This commit is contained in:
@@ -1,7 +1,12 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Variable
|
||||
import functools
|
||||
from tinygrad import Tensor, Variable, UOp, Context
|
||||
from tinygrad.uop.ops import KernelInfo
|
||||
from tinygrad.engine.schedule import schedule_cache
|
||||
|
||||
def custom_set0_kernel(A:UOp, num:int) -> UOp:
|
||||
return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}"))
|
||||
|
||||
class TestScheduleCache(unittest.TestCase):
|
||||
def test_bound_variable_reuses_cache(self):
|
||||
schedule_cache.clear()
|
||||
@@ -26,6 +31,33 @@ class TestScheduleCache(unittest.TestCase):
|
||||
_, var_vals = t.schedule_with_vars()
|
||||
self.assertEqual(var_vals, {'pos': 42})
|
||||
|
||||
@Context(SPEC=0)
|
||||
def test_custom_kernel(self):
|
||||
for i in range(4):
|
||||
a = Tensor.empty(1)
|
||||
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_set0_kernel, num=i))[0]
|
||||
a.realize()
|
||||
self.assertEqual(a.item(), i)
|
||||
|
||||
@Context(SPEC=0)
|
||||
def test_same_custom_function_reuses_cache(self):
|
||||
schedule_cache.clear()
|
||||
fxn = functools.partial(custom_set0_kernel, num=10)
|
||||
|
||||
# first run
|
||||
a = Tensor.empty(1)
|
||||
a = Tensor.custom_kernel(a, fxn=fxn)[0]
|
||||
a.realize()
|
||||
self.assertEqual(a.item(), 10)
|
||||
cache_size_after_first = len(schedule_cache)
|
||||
|
||||
# second run with same function should reuse cache
|
||||
b = Tensor.empty(1)
|
||||
b = Tensor.custom_kernel(b, fxn=fxn)[0]
|
||||
b.realize()
|
||||
self.assertEqual(b.item(), 10)
|
||||
self.assertEqual(len(schedule_cache), cache_size_after_first)
|
||||
|
||||
def test_simple(self):
|
||||
a = Tensor.ones(10).contiguous()
|
||||
b = Tensor.ones(10).contiguous()
|
||||
|
||||
@@ -839,7 +839,7 @@ class CustomKernel:
|
||||
grad_fxn: Callable|None = None
|
||||
# sadly CustomKernel can't be pickled or reconstructed as a str
|
||||
def __reduce__(self): return (CustomKernel, (panic,))
|
||||
def __repr__(self): return "CustomKernel(panic)"
|
||||
def __repr__(self): return f"CustomKernel({id(self.fxn)})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Kernel:
|
||||
|
||||
Reference in New Issue
Block a user