mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
jit includes tensor inputs in containers (#14043)
* jit includes tensor inputs in containers
* cleanup
This commit is contained in:
@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
|
|||||||
def test_array_jit(self):
|
def test_array_jit(self):
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def add_array(a, arr): return (a+arr[0]).realize()
|
def add_array(a, arr): return (a+arr[0]).realize()
|
||||||
for i in range(5):
|
for _ in range(5):
|
||||||
a = Tensor.randn(10, 10)
|
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
|
||||||
b = Tensor.randn(10, 10)
|
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||||
a.realize(), b.realize()
|
|
||||||
c = add_array(a, [b])
|
|
||||||
if i >= 2:
|
|
||||||
# should fail once jitted since jit can't handle arrays
|
|
||||||
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
|
|
||||||
else:
|
|
||||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
|
||||||
assert_jit_cache_len(add_array, 1)
|
assert_jit_cache_len(add_array, 1)
|
||||||
|
|
||||||
def test_jit_copyin(self):
|
def test_jit_copyin(self):
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ Each test shows behavior that works without JIT but changes with JIT.
|
|||||||
Comments marked "should be X!" indicate the intuitively expected value.
|
Comments marked "should be X!" indicate the intuitively expected value.
|
||||||
|
|
||||||
SILENT MISMATCHES (highest priority - wrong results, no error):
|
SILENT MISMATCHES (highest priority - wrong results, no error):
|
||||||
tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
|
|
||||||
class_method_shared_across_instances EASY could check if first arg is self and warn
|
class_method_shared_across_instances EASY could check if first arg is self and warn
|
||||||
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
||||||
python_constants_frozen HARD inherent to tracing JITs
|
python_constants_frozen HARD inherent to tracing JITs
|
||||||
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
|
|||||||
with self.assertRaises(JitError):
|
with self.assertRaises(JitError):
|
||||||
f(x, x)
|
f(x, x)
|
||||||
|
|
||||||
def test_tensors_in_containers_ignored(self):
|
def test_tensors_in_containers(self):
|
||||||
"""Tensors inside lists/dicts are not tracked as inputs."""
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def f(a, arr): return (a + arr[0]).realize()
|
def f(a, arr): return (a + arr[0]).realize()
|
||||||
|
|
||||||
results = []
|
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
|
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
|
||||||
results.append(f(a, [b]).numpy().copy())
|
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
|
||||||
|
|
||||||
np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup
|
|
||||||
np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture
|
|
||||||
np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]!
|
|
||||||
np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]!
|
|
||||||
|
|
||||||
def test_nested_jit_fails_on_second_call(self):
|
def test_nested_jit_fails_on_second_call(self):
|
||||||
"""Nested JIT works on first call but fails on second."""
|
"""Nested JIT works on first call but fails on second."""
|
||||||
|
|||||||
@@ -228,6 +228,10 @@ class CapturedJit(Generic[ReturnType]):
|
|||||||
def _prepare_jit_inputs(args, kwargs):
|
def _prepare_jit_inputs(args, kwargs):
|
||||||
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
|
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
|
||||||
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
|
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
|
||||||
|
# extract tensors from containers (shallow, not recursive to avoid grabbing model weights)
|
||||||
|
for x in args + tuple(kwargs.values()):
|
||||||
|
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
|
||||||
|
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
|
||||||
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
|
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
|
||||||
# TODO: this multi unpack stuff is not well tested.
|
# TODO: this multi unpack stuff is not well tested.
|
||||||
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
||||||
|
|||||||
Reference in New Issue
Block a user