diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py index 6d3ea32e23..0844e3034e 100644 --- a/test/test_jit_footguns.py +++ b/test/test_jit_footguns.py @@ -9,6 +9,7 @@ SILENT MISMATCHES (highest priority - wrong results, no error): class_method_shared_across_instances EASY could check if first arg is self and warn slice_assign_requires_realize MED assign graph not connected to read during JIT replay output_buffer_reuse MED performance tradeoff, could add option or better docs + symbolic_pad_view_frozen MED pad view BIND values baked in at capture time python_constants_frozen HARD inherent to tracing JITs conditional_branches_frozen HARD inherent to tracing JITs @@ -134,6 +135,23 @@ class TestJitFootguns(unittest.TestCase): cache2.assign(Tensor.zeros(4, 4)).realize() self.assertEqual(f_fixed(v_pos.bind(i)).item(), 4.0) + def test_symbolic_pad_view_frozen(self): + """Symbolic pad view has BIND values baked in at capture time. TODO: pad should be captured in jit.""" + from tinygrad import Variable + a = Tensor.rand(3, 10).realize() + + # broken: pad is a view, BIND values frozen at capture (i=2) + @TinyJit + def f_broken(a): return (a+1).pad((None, (0, 10-a.shape[1]))).realize() + for i in range(1, 5): f_broken(a[:, :Variable("i", 1, 10).bind(i)]) + self.assertEqual(int((f_broken(a[:, :Variable("i", 1, 10).bind(4)])[0] != 0).sum().item()), 2) # should be 4! + + # workaround: contiguous fuses pad into kernel + @TinyJit + def f_fixed(a): return (a+1).pad((None, (0, 10-a.shape[1]))).contiguous().realize() + for i in range(1, 5): f_fixed(a[:, :Variable("i", 1, 10).bind(i)]) + self.assertEqual(int((f_fixed(a[:, :Variable("i", 1, 10).bind(4)])[0] != 0).sum().item()), 4) + def test_non_tensor_outputs_error(self): @TinyJit def f(x, mult): return (x * 2).realize(), mult * 10