#!/usr/bin/env python
"""
JIT Footguns: Documenting unexpected behavior changes when using @TinyJit
Each test shows behavior that works without JIT but changes with JIT.
Comments marked "should be X!" indicate the intuitively expected value.
SILENT MISMATCHES (highest priority - wrong results, no error):
class_method_shared_across_instances EASY could check if first arg is self and warn
output_buffer_reuse MED performance tradeoff, could add option or better docs
python_constants_frozen HARD inherent to tracing JITs
conditional_branches_frozen HARD inherent to tracing JITs
ERRORS RAISED (lower priority - at least users know):
unrealized_const_input_error EASY raises JitError for unrealized const inputs
non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
positional_kwargs_cannot_mix EASY normalize positional args to kwargs using function signature
duplicate_inputs_fail MED would need to handle aliasing in input_replace
nested_jit_fails_on_second_call MED could fail on first call instead of second
"""
import unittest
import numpy as np
from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError
class TestJitFootguns(unittest.TestCase):
def test_output_buffer_reuse(self):
"""Output tensors share buffer after capture - old references get overwritten."""
@TinyJit
def f(x): return x.sum().realize()
r1 = f(Tensor([1, 1])) # warmup
r2 = f(Tensor([2, 2])) # capture
r3 = f(Tensor([3, 3])) # jit exec
self.assertEqual(r1.item(), 2) # warmup result independent
self.assertEqual(r3.item(), 6) # latest is correct
self.assertEqual(r2.item(), 6) # should be 4! (overwritten by r3)
def test_output_buffer_workaround(self):
"""Use .clone().realize() to get independent copies."""
@TinyJit
def f(x): return x.sum().realize()
r1 = f(Tensor([1, 1])).clone().realize()
r2 = f(Tensor([2, 2])).clone().realize()
r3 = f(Tensor([3, 3])).clone().realize()
self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
def test_non_tensor_outputs_error(self):
@TinyJit
def f(x, mult): return (x * 2).realize(), mult * 10
with self.assertRaises(JitError):
for i in range(3): f(Tensor([i]), i)
def test_duplicate_inputs_fail(self):
"""JIT cannot handle the same tensor passed as multiple arguments."""
@TinyJit
def f(a, b): return (a + b).realize()
x = Tensor([1, 2, 3])
with self.assertRaises(JitError):
f(x, x)
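  # Hypothetical workaround sketch (not in the original suite): fold the aliasing into
  # the jitted function so the tensor enters the JIT exactly once.
  def test_duplicate_inputs_workaround(self):
    """Sketch: computing a + a inside the function avoids passing the same tensor twice."""
    @TinyJit
    def f(a): return (a + a).realize()
    for i in range(1, 4):  # warmup, capture, replay
      np.testing.assert_array_equal(f(Tensor([i, i, i])).numpy(), [2*i, 2*i, 2*i])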
def test_tensors_in_containers(self):
@TinyJit
def f(a, arr): return (a + arr[0]).realize()
for i in range(4):
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
def test_nested_jit_fails_on_second_call(self):
"""Nested JIT works on first call but fails on second."""
@TinyJit
def inner(t): return t + 1
@TinyJit
def outer(t): return inner(t) * 3
self.assertEqual(outer(Tensor([1])).realize().item(), 6) # works!
with self.assertRaises(RuntimeError):
outer(Tensor([2])).realize() # fails
def test_implicit_inputs_need_realize(self):
"""Closure tensors must be realized before JIT call."""
x = Tensor([0])
@TinyJit
def f(): return (x * 2).realize()
for i in range(5):
x.assign(Tensor([i])).realize() # must realize!
self.assertEqual(f().item(), i * 2)
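  # Why the realize matters: replay only swaps buffers of explicit Tensor arguments.
  # A closure tensor keeps the buffer captured at call 2, so updates must be written
  # into that same buffer (assign + realize) before the call to be visible.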
def test_views_with_different_offsets_fail(self):
"""JIT requires consistent tensor views across calls."""
@TinyJit
def f(a): return (a + 1).realize()
base = Tensor.randn(10, 10).realize()
with self.assertRaises(JitError):
for i in range(1, 5):
f(base[:, i:i+2]) # different offset each time
def test_shape_change_after_capture_fails(self):
"""Shapes are locked at capture time."""
@TinyJit
def f(a, b): return (a + b).realize()
f(Tensor.randn(10, 10), Tensor.randn(10, 10)) # warmup
f(Tensor.randn(10, 10), Tensor.randn(10, 10)) # capture
with self.assertRaises(JitError):
f(Tensor.randn(20, 20), Tensor.randn(20, 20))
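  # (tinygrad's Variable/symbolic shapes can let one JIT serve multiple sizes, but the
  # rule for plain tensors is strict: a capture only accepts the shapes it was built with.)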
def test_python_constants_frozen(self):
"""Python variables inside JIT use capture-time values."""
mult = 1
@TinyJit
def f(x): return (x * mult).realize()
results = []
for i in range(5):
mult = i + 1
results.append(f(Tensor([10])).item())
self.assertEqual(results[0], 10) # warmup, mult=1
self.assertEqual(results[1], 20) # capture, mult=2
self.assertEqual(results[2], 20) # should be 30!
self.assertEqual(results[3], 20) # should be 40!
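  # Hypothetical workaround sketch (not in the original suite): pass the changing value
  # as a Tensor input so the JIT sees a swappable buffer instead of a frozen constant.
  def test_python_constants_workaround(self):
    """Sketch: a Tensor argument is re-read every call, unlike a Python closure value."""
    @TinyJit
    def f(x, mult): return (x * mult).realize()
    results = [f(Tensor([10]), Tensor([i + 1])).item() for i in range(5)]
    self.assertEqual(results, [10, 20, 30, 40, 50])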
def test_unrealized_const_input_error(self):
"""Const tensors have no buffer to replace, so JIT raises an error. Even explicit .realize() doesn't help."""
@TinyJit
def f(a, b): return (a * b).realize()
# unrealized const fails
with self.assertRaises(JitError):
f(Tensor([1, 2, 3]).realize(), Tensor(2))
# explicit .realize() on const still fails - const cannot be realized to have a buffer
@TinyJit
def g(a, b): return (a * b).realize()
with self.assertRaises(JitError):
g(Tensor([1, 2, 3]).realize(), Tensor(2).realize())
def test_conditional_branches_frozen(self):
"""Only the branch taken during capture runs thereafter."""
@TinyJit
def f(x, use_square):
if use_square:
return (x * x).realize()
return (x * 2).realize()
f(Tensor([3]), True) # warmup
f(Tensor([3]), False) # capture (False branch)
result = f(Tensor([3]), True) # passing True but False branch runs
self.assertEqual(result.item(), 6) # should be 9!
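  # Hypothetical workaround sketch (not in the original suite): express the branch with
  # where() so both paths are captured and the condition becomes a real tensor input.
  def test_conditional_branches_workaround(self):
    """Sketch: where() keeps the condition live instead of freezing one Python branch."""
    @TinyJit
    def f(x, use_square): return use_square.where(x * x, x * 2).realize()
    f(Tensor([3]), Tensor([True]))   # warmup
    f(Tensor([3]), Tensor([False]))  # capture
    self.assertEqual(f(Tensor([3]), Tensor([True])).item(), 9)
    self.assertEqual(f(Tensor([3]), Tensor([False])).item(), 6)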
def test_positional_kwargs_cannot_mix(self):
"""Must use same calling convention after capture."""
@TinyJit
def f(a, b): return (a + b).realize()
f(Tensor([1]), Tensor([2])) # warmup with positional
f(Tensor([1]), Tensor([2])) # capture with positional
with self.assertRaises(JitError):
f(a=Tensor([3]), b=Tensor([4])) # kwargs fail
def test_class_method_shared_across_instances(self):
"""JIT on instance methods is shared at class level."""
class Model:
def __init__(self, scale):
self.scale = Tensor([scale])
@TinyJit
def forward(self, x):
return (x * self.scale).realize()
m1, m2 = Model(2), Model(3)
m1.forward(Tensor([5])) # warmup
m1.forward(Tensor([5])) # capture with m1.scale=2
self.assertEqual(m1.forward(Tensor([5])).item(), 10)
self.assertEqual(m2.forward(Tensor([5])).item(), 10) # should be 15!
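  # Hypothetical workaround sketch (not in the original suite): build the TinyJit per
  # instance in __init__ so each instance owns its own capture.
  def test_class_method_per_instance_workaround(self):
    """Sketch: wrapping the bound method per instance keeps m1 and m2 independent."""
    class Model:
      def __init__(self, scale):
        self.scale = Tensor([scale])
        self.forward = TinyJit(self._forward)  # per-instance JIT cache
      def _forward(self, x): return (x * self.scale).realize()
    m1, m2 = Model(2), Model(3)
    for m, expected in [(m1, 10), (m2, 15)]:
      m.forward(Tensor([5])); m.forward(Tensor([5]))  # warmup + capture
      self.assertEqual(m.forward(Tensor([5])).item(), expected)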
def test_side_effects_only_during_capture(self):
"""Function body not executed during JIT replay."""
call_count = [0]
@TinyJit
def f(x):
call_count[0] += 1
return (x * 2).realize()
f(Tensor([1])) # warmup
f(Tensor([2])) # capture
self.assertEqual(call_count[0], 2)
f(Tensor([3]))
f(Tensor([4]))
f(Tensor([5]))
self.assertEqual(call_count[0], 2) # still 2, not 5!
def test_nothing_realized_fails(self):
"""Must JIT at least one kernel."""
@TinyJit
def f(a, b): return None
with self.assertRaises(JitError):
for _ in range(3):
f(Tensor([1]), Tensor([2]))
def test_item_creates_unrealized_return(self):
""".item() in shape computation creates unrealized return with baked-in shape."""
@TinyJit
def f(x): return Tensor.zeros(x.sum().item())
for _ in range(3): f(Tensor([1, 1, 1])) # captures with sum=3
result = f(Tensor([2, 2, 2])) # sum=6, but shape is baked in
assert result.shape == (3,) # should be (6,)!
def test_item_bakes_in_values(self):
""".item() value is baked in, causing wrong output shapes (silent failure)."""
@TinyJit
def f(x, mask): return x.masked_select(mask)
mask_2 = Tensor([True, False, True, False])
for _ in range(3): f(Tensor([1, 2, 3, 4]), mask_2)
mask_3 = Tensor([True, True, True, False])
result = f(Tensor([1, 2, 3, 4]), mask_3)
assert result.shape == (2,) # should be (3,)!
def test_tolist_bakes_in_values(self):
""".tolist() returns Python values that get baked in (silent failure)."""
@TinyJit
def f(x): return Tensor(x.tolist())
for _ in range(3): f(Tensor([1, 2, 3]))
result = f(Tensor([4, 5, 6]))
np.testing.assert_equal(result.numpy(), [1, 2, 3]) # should be [4,5,6]!
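  # Common thread of the three tests above: any value pulled out of a Tensor into
  # Python during capture (.item(), .tolist(), data-dependent shapes) is frozen into
  # the trace and never recomputed on replay.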
class TestJitCorrectBehavior(unittest.TestCase):
"""Behaviors that work correctly - documented for clarity."""
def test_random_regenerates(self):
"""Random tensors regenerate each call."""
@TinyJit
def f(x):
return (x + Tensor.rand(3)).realize()
f(Tensor([0, 0, 0])) # warmup
f(Tensor([0, 0, 0])) # capture
results = {tuple(f(Tensor([0, 0, 0])).numpy().tolist()) for _ in range(5)}
self.assertEqual(len(results), 5)
def test_unrealized_return_auto_realized(self):
"""Unrealized return tensors are auto-realized."""
@TinyJit
def f(a, b): return a + b # no explicit realize
for _ in range(5):
a, b = Tensor.randn(10), Tensor.randn(10)
np.testing.assert_allclose(f(a, b).numpy(), a.numpy() + b.numpy(), atol=1e-5)
def test_kwargs_order_doesnt_matter(self):
"""Kwargs are sorted by name, so order doesn't matter."""
@TinyJit
def f(first, second): return (first / second).realize()
for _ in range(3):
a, b = Tensor.randn(10), Tensor.randn(10) + 1
np.testing.assert_allclose(f(second=b, first=a).numpy(), a.numpy() / b.numpy(), atol=1e-4)
np.testing.assert_allclose(f(first=a, second=b).numpy(), a.numpy() / b.numpy(), atol=1e-4)
def test_input_mutation_consistent(self):
"""Input mutation via assign works consistently."""
@TinyJit
def f(x):
x += 1
x.realize()
return x
a = Tensor([0]).contiguous().realize()
for _ in range(5):
f(a)
self.assertEqual(a.item(), 5)
if __name__ == '__main__':
unittest.main()