diff --git a/test/unit/test_schedule_cache.py b/test/unit/test_schedule_cache.py index 368a517d73..0d56dafa86 100644 --- a/test/unit/test_schedule_cache.py +++ b/test/unit/test_schedule_cache.py @@ -1,7 +1,12 @@ import unittest -from tinygrad import Tensor, Variable +import functools +from tinygrad import Tensor, Variable, UOp, Context +from tinygrad.uop.ops import KernelInfo from tinygrad.engine.schedule import schedule_cache +def custom_set0_kernel(A:UOp, num:int) -> UOp: + return A[0].set(num).sink(arg=KernelInfo(f"custom_set0_{num}")) + class TestScheduleCache(unittest.TestCase): def test_bound_variable_reuses_cache(self): schedule_cache.clear() @@ -26,6 +31,33 @@ class TestScheduleCache(unittest.TestCase): _, var_vals = t.schedule_with_vars() self.assertEqual(var_vals, {'pos': 42}) + @Context(SPEC=0) + def test_custom_kernel(self): + for i in range(4): + a = Tensor.empty(1) + a = Tensor.custom_kernel(a, fxn=functools.partial(custom_set0_kernel, num=i))[0] + a.realize() + self.assertEqual(a.item(), i) + + @Context(SPEC=0) + def test_same_custom_function_reuses_cache(self): + schedule_cache.clear() + fxn = functools.partial(custom_set0_kernel, num=10) + + # first run + a = Tensor.empty(1) + a = Tensor.custom_kernel(a, fxn=fxn)[0] + a.realize() + self.assertEqual(a.item(), 10) + cache_size_after_first = len(schedule_cache) + + # second run with same function should reuse cache + b = Tensor.empty(1) + b = Tensor.custom_kernel(b, fxn=fxn)[0] + b.realize() + self.assertEqual(b.item(), 10) + self.assertEqual(len(schedule_cache), cache_size_after_first) + def test_simple(self): a = Tensor.ones(10).contiguous() b = Tensor.ones(10).contiguous() diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 748e28038f..93f8db1b96 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -839,7 +839,7 @@ class CustomKernel: grad_fxn: Callable|None = None # sadly CustomKernel can't be pickled or reconstructed as a str def __reduce__(self): return (CustomKernel, (panic,)) - def __repr__(self): return "CustomKernel(panic)" + def __repr__(self): return f"CustomKernel({id(self.fxn)})" @dataclass(frozen=True) class Kernel: