mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
jit includes tensor inputs in containers (#14043)
* jit includes tensor inputs in containers * cleanup
This commit is contained in:
@@ -184,16 +184,9 @@ class TestJit(unittest.TestCase):
|
||||
def test_array_jit(self):
|
||||
@TinyJit
|
||||
def add_array(a, arr): return (a+arr[0]).realize()
|
||||
for i in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
a.realize(), b.realize()
|
||||
c = add_array(a, [b])
|
||||
if i >= 2:
|
||||
# should fail once jitted since jit can't handle arrays
|
||||
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
|
||||
else:
|
||||
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
for _ in range(5):
|
||||
a, b = Tensor.randn(10, 10).realize(), Tensor.randn(10, 10).realize()
|
||||
np.testing.assert_allclose(add_array(a, [b]).numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
|
||||
assert_jit_cache_len(add_array, 1)
|
||||
|
||||
def test_jit_copyin(self):
|
||||
|
||||
@@ -6,7 +6,6 @@ Each test shows behavior that works without JIT but changes with JIT.
|
||||
Comments marked "should be X!" indicate the intuitively expected value.
|
||||
|
||||
SILENT MISMATCHES (highest priority - wrong results, no error):
|
||||
tensors_in_containers_ignored EASY only checks t.__class__ is Tensor, could scan lists/dicts
|
||||
class_method_shared_across_instances EASY could check if first arg is self and warn
|
||||
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
||||
python_constants_frozen HARD inherent to tracing JITs
|
||||
@@ -65,20 +64,12 @@ class TestJitFootguns(unittest.TestCase):
|
||||
with self.assertRaises(JitError):
|
||||
f(x, x)
|
||||
|
||||
def test_tensors_in_containers_ignored(self):
|
||||
"""Tensors inside lists/dicts are not tracked as inputs."""
|
||||
def test_tensors_in_containers(self):
|
||||
@TinyJit
|
||||
def f(a, arr): return (a + arr[0]).realize()
|
||||
|
||||
results = []
|
||||
for i in range(4):
|
||||
a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
|
||||
results.append(f(a, [b]).numpy().copy())
|
||||
|
||||
np.testing.assert_array_equal(results[0], [1, 1, 1]) # warmup
|
||||
np.testing.assert_array_equal(results[1], [2, 2, 2]) # capture
|
||||
np.testing.assert_array_equal(results[2], [2, 2, 2]) # should be [3,3,3]!
|
||||
np.testing.assert_array_equal(results[3], [2, 2, 2]) # should be [4,4,4]!
|
||||
np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])
|
||||
|
||||
def test_nested_jit_fails_on_second_call(self):
|
||||
"""Nested JIT works on first call but fails on second."""
|
||||
|
||||
Reference in New Issue
Block a user