diff --git a/extra/models/rnnt.py b/extra/models/rnnt.py
index 8382aae6ac..e7ad0f54b9 100644
--- a/extra/models/rnnt.py
+++ b/extra/models/rnnt.py
@@ -129,7 +129,7 @@ class LSTM:
       return self.do_step(x_, hc_)
 
     if hc is None:
-      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)
+      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False).contiguous().realize()
 
     output = None
     for t in range(x.shape[0]):
diff --git a/test/test_jit.py b/test/test_jit.py
index 265cee5fdd..7720d36dfc 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -414,12 +414,6 @@ class TestJit(unittest.TestCase):
     assert isinstance(jf.jit_cache[0].prg, graph_t)
     assert isinstance(jf.jit_cache[1].prg, graph_t)
 
-  def test_jit_const_inputs(self):
-    @TinyJit
-    def g(x,y,z): return (x+y+z).realize()
-    for i in range(5):
-      np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
-
   def test_jitted_clone(self):
     def f(a): return a.clone().realize()
     jf = TinyJit(f)
@@ -496,9 +490,10 @@ class TestJit(unittest.TestCase):
     f(Tensor.empty(1))
     f(Tensor.empty(1))
 
-    # TODO: this should fail since input has a different size
-    f(Tensor(2.0)).item()
-    # TODO: this should not fail, and should return 3
+    # scalar const input is not allowed
+    with self.assertRaises(JitError):
+      f(Tensor(2.0)).item()
+    # list input has different view structure than empty(1)
     with self.assertRaises(JitError):
       f(Tensor([2.0])).item()
 
diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py
index 24c6b3a643..f54838e54a 100644
--- a/test/test_jit_footguns.py
+++ b/test/test_jit_footguns.py
@@ -11,9 +11,9 @@
 SILENT MISMATCHES (highest priority - wrong results, no error):
   output_buffer_reuse            MED   performance tradeoff, could add option or better docs
   python_constants_frozen        HARD  inherent to tracing JITs
   conditional_branches_frozen    HARD  inherent to tracing JITs
-  unrealized_const_input_frozen  HARD  unrealized const has no buffer to replace, values baked in
 ERRORS RAISED (lower priority - at least users know):
+  unrealized_const_input_error   EASY  raises JitError for unrealized const inputs
   non_tensor_outputs_error       EASY  raises JitError if return contains non-Tensor values
   positional_kwargs_cannot_mix   EASY  normalize positional args to kwargs using function signature
   duplicate_inputs_fail          MED   would need to handle aliasing in input_replace
@@ -140,16 +140,20 @@ class TestJitFootguns(unittest.TestCase):
 
     self.assertEqual(results[2], 20)  # should be 30!
     self.assertEqual(results[3], 20)  # should be 40!
 
-  def test_unrealized_const_input_frozen(self):
-    """Unrealized const tensors have no buffer to replace, so values are baked in at capture time."""
+  def test_unrealized_const_input_error(self):
+    """Const tensors have no buffer to replace, so JIT raises an error. Even explicit .realize() doesn't help."""
     @TinyJit
     def f(a, b): return (a * b).realize()
-    for i in range(1, 5):
-      result = f(Tensor([1, 2, 3]).realize(), Tensor(i))  # Tensor(i) is unrealized const
-      # value is frozen at capture (i=2), so i=3,4 give wrong results
-      expected = [2, 4, 6] if i >= 2 else [i, 2*i, 3*i]
-      np.testing.assert_equal(result.numpy(), expected)  # i=3,4 should be [3,6,9], [4,8,12]!
+    # unrealized const fails
+    with self.assertRaises(JitError):
+      f(Tensor([1, 2, 3]).realize(), Tensor(2))
+
+    # explicit .realize() on const still fails - const cannot be realized to have a buffer
+    @TinyJit
+    def g(a, b): return (a * b).realize()
+    with self.assertRaises(JitError):
+      g(Tensor([1, 2, 3]).realize(), Tensor(2).realize())
 
   def test_conditional_branches_frozen(self):
     """Only the branch taken during capture runs thereafter."""
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index ddc464661d..85c11ac5b9 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -217,7 +217,7 @@ class TestSymbolicJit(unittest.TestCase):
   def test_ones_sum(self):
     def f(a): return a.sum().realize()
     jf = TinyJit(f)
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = jf(t[:vi]).item()
diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py
index 6b7ea926ce..9efadc7625 100644
--- a/test/test_symbolic_ops.py
+++ b/test/test_symbolic_ops.py
@@ -199,7 +199,7 @@ class TestSymbolicOps(unittest.TestCase):
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
   def test_ones_sum(self):
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = t[:vi].sum().item()
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 8c2363903e..00906c61f5 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -231,6 +231,8 @@ def _prepare_jit_inputs(args, kwargs):
   if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
   # TODO: this multi unpack stuff is not well tested.
   lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
+  if any(lb.base.op is Ops.CONST for lb in lbs):
+    raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
   input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb] for lb in lbs if lb.base.realized is not None])
   if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
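
After this patch, a const JIT input is a hard JitError instead of a silently frozen value. Below is a minimal sketch of the user-facing behavior, not part of the diff; it assumes the patch is applied and that JitError is importable from tinygrad.engine.jit, as the tests above do. The names mul, x, and out are hypothetical.

import numpy as np
from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError

@TinyJit
def mul(a, b): return (a * b).realize()

x = Tensor([1, 2, 3]).realize()

# Tensor(2) stays a CONST UOp with no backing buffer, so there is nothing
# for the JIT to swap on later calls; _prepare_jit_inputs now rejects it.
try:
  mul(x, Tensor(2))
except JitError as e:
  print(e)  # JIT inputs cannot be const, create a buffer with .contiguous()

# .contiguous() materializes the scalar into a real buffer, so the value
# can change between calls instead of being baked in at capture time.
for i in range(1, 5):
  out = mul(x, Tensor(i).contiguous())
  np.testing.assert_equal(out.numpy(), [i, 2 * i, 3 * i])

The same workaround is what the symbolic tests adopt: Tensor.ones(10) is const-backed, so .contiguous() gives it a buffer before it is sliced by a Variable and fed to the JIT.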