mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
raise when jit fxn returns non-Tensor output (#14042)
This commit is contained in:
@@ -230,20 +230,9 @@ class TestJit(unittest.TestCase):
|
||||
def test_jit_output_non_tensor_fail(self):
|
||||
@TinyJit
|
||||
def f(a, b, i): return (a+b).realize(), i
|
||||
output1, output2 = [], []
|
||||
expect1, expect2 = [], []
|
||||
for i in range(5):
|
||||
a = Tensor.randn(10, 10)
|
||||
b = Tensor.randn(10, 10)
|
||||
o1, o2 = f(a, b, i)
|
||||
output1.append(o1.numpy().copy())
|
||||
output2.append(o2)
|
||||
expect1.append(a.numpy().copy()+b.numpy().copy())
|
||||
expect2.append(i)
|
||||
np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
|
||||
# the jit only works with Tensor outputs
|
||||
assert output2 != expect2
|
||||
assert_jit_cache_len(f, 1)
|
||||
with self.assertRaises(JitError):
|
||||
for i in range(3):
|
||||
f(Tensor.randn(10, 10), Tensor.randn(10, 10), i)
|
||||
|
||||
def test_jit_random_regen(self):
|
||||
def f(a, b):
|
||||
|
||||
Reference in New Issue
Block a user