mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
jit includes tensor inputs in containers (#14043)
* jit includes tensor inputs in containers
* cleanup
This commit is contained in:
@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
|
||||
# NOTE(review): this span is a web-extracted unified-diff hunk ("@@ -184,16 +184,9")
# from tinygrad's TestJit class. The +/- diff markers and Python indentation were
# lost in extraction, and the "|" / "||||" lines are extraction residue. The removed
# loop body (the "i >= 2" mismatch checks, which asserted the JIT *ignored* list
# arguments) and the added body (the plain 5-iteration allclose loop, which asserts
# list arguments now work) appear interleaved below — TODO confirm the exact
# before/after split against the repository history for commit #14043.
def test_array_jit(self):
|
||||
# JIT-compiled helper: adds a plain tensor to the first element of a list argument.
@TinyJit
|
||||
def add_array(a, arr): return (a+arr[0]).realize()
|
||||
for i in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
a.realize(), b.realize()
|
||||
c = add_array(a, [b])
|
||||
# NOTE(review): the branch below is presumably the OLD (removed) assertion set —
# after warmup (i >= 2) the jitted call reused the captured b, so the result was
# expected NOT to equal a+b. Verify against the - lines of the original diff.
if i >= 2:
|
||||
# should fail once jitted since jit can't handle arrays
|
||||
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
|
||||
else:
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
# NOTE(review): presumably the NEW (added) body — with container inputs tracked,
# every call, jitted or not, must produce a+b. TODO confirm this replaces the
# i-indexed loop above rather than following it.
for _ in range(5):
|
||||
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
|
||||
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
# One kernel expected in the JIT cache for the single fused add.
assert_jit_cache_len(add_array, 1)
|
||||
|
||||
def test_jit_copyin(self):
|
||||
|
||||
@@ -6,7 +6,6 @@ Each test shows behavior that works without JIT but changes with JIT.
|
||||
Comments marked "should be X!" indicate the intuitively expected value.
|
||||
|
||||
SILENT MISMATCHES (highest priority - wrong results, no error):
|
||||
tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
|
||||
class_method_shared_across_instances EASY could check if first arg is self and warn
|
||||
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
||||
python_constants_frozen HARD inherent to tracing JITs
|
||||
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
|
||||
with self.assertRaises(JitError):
|
||||
f(x, x)
|
||||
|
||||
# NOTE(review): diff hunk with lost +/- markers. Two method headers appear below:
# "test_tensors_in_containers_ignored" (with its docstring) is presumably the OLD,
# removed name, and "test_tensors_in_containers" the ADDED replacement — consistent
# with the commit title "jit includes tensor inputs in containers". TODO confirm.
def test_tensors_in_containers_ignored(self):
|
||||
"""Tensors inside lists/dicts are not tracked as inputs."""
|
||||
def test_tensors_in_containers(self):
|
||||
@TinyJit
|
||||
def f(a, arr): return (a + arr[0]).realize()
|
||||
|
||||
results = []
|
||||
for i in range(4):
|
||||
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
|
||||
results.append(f(a, [b]).numpy().copy())
|
||||
# NOTE(review): the four results[...] assertions below document the OLD footgun
# behavior (the captured b from iteration 1 was reused, freezing the output at
# [2,2,2]); the final single assertion is presumably the NEW expectation, where
# the list-wrapped tensor is tracked and each call reflects the current b.
np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup
|
||||
np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture
|
||||
np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]!
|
||||
np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]!
|
||||
# i is the loop variable's final value (3) here, left over after the for loop.
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
|
||||
|
||||
def test_nested_jit_fails_on_second_call(self):
|
||||
"""Nested JIT works on first call but fails on second."""
|
||||
|
||||
@@ -228,6 +228,10 @@ class CapturedJit(Generic[ReturnType]):
|
||||
# NOTE(review): diff hunk "@@ -228,6 +228,10" — only the visible head of
# _prepare_jit_inputs is shown; the function continues past this extract.
# Indentation was flattened by extraction; "|" / "||||" lines are residue.
def _prepare_jit_inputs(args, kwargs):
|
||||
# Collect (position-or-keyword-name, tensor) pairs for direct Tensor arguments;
# kwargs are sorted by key so input ordering is deterministic across calls.
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
|
||||
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
|
||||
# extract tensors from containers (shallow, not recursive to avoid grabbing model weights)
|
||||
# NOTE(review): this loop is presumably the change added by this commit — it scans
# one level into tuple/list/dict arguments for Tensors missed by the pass above.
for x in args + tuple(kwargs.values()):
|
||||
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
|
||||
# Identity ("is") dedup, not equality — the same Tensor object passed both
# directly and inside a container must only be tracked once.
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
|
||||
# Realize any not-yet-realized inputs in one batch before capture.
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
|
||||
# TODO: this multi unpack stuff is not well tested.
|
||||
# Flatten MULTI (sharded) tensors into their per-device UOp sources.
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
||||
|
||||
Reference in New Issue
Block a user