don't allow jit input to be const (#14045)

* don't allow jit input to be unbuffered, like const

* check just const to fix multi

* fix rnnt
chenyu
2026-01-06 18:15:22 -05:00
committed by GitHub
parent a8896f28e1
commit c714881832
6 changed files with 21 additions and 20 deletions
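In short: after this change, a const-backed Tensor (one with no device buffer, like Tensor(2) or Tensor.zeros(...)) is rejected as a JIT input with a JitError instead of being silently baked into the captured graph. A minimal sketch of the new behavior, assuming JitError is importable from tinygrad.engine.jit as the _prepare_jit_inputs hunk at the bottom suggests:

from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError  # import path assumed from the jit hunk below

@TinyJit
def f(a, b): return (a * b).realize()

try:
  f(Tensor([1, 2, 3]).realize(), Tensor(2))  # Tensor(2) is const: no buffer for the JIT to swap per call
except JitError as e:
  print(e)  # JIT inputs cannot be const, create a buffer with .contiguous()

@TinyJit
def g(a, b): return (a * b).realize()

y = Tensor(2).contiguous().realize()  # .contiguous() materializes a real buffer
for _ in range(3):
  print(g(Tensor([1, 2, 3]).realize(), y).numpy())  # [2 4 6] on every call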

View File

@@ -129,7 +129,7 @@ class LSTM:
       return self.do_step(x_, hc_)

     if hc is None:
-      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)
+      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False).contiguous().realize()

     output = None
     for t in range(x.shape[0]):
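This is the rnnt fix from the commit message: Tensor.zeros is a lazy const with no backing buffer, and hc is fed into the jitted step function, so it must be materialized first. A minimal sketch of the distinction, with made-up shapes:

from tinygrad import Tensor

hc = Tensor.zeros(2, 8, 16, requires_grad=False)  # lazy const: no device buffer exists yet
hc = hc.contiguous().realize()                    # writes real zeros into a buffer the JIT can replace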

View File

@@ -414,12 +414,6 @@ class TestJit(unittest.TestCase):
     assert isinstance(jf.jit_cache[0].prg, graph_t)
     assert isinstance(jf.jit_cache[1].prg, graph_t)

-  def test_jit_const_inputs(self):
-    @TinyJit
-    def g(x,y,z): return (x+y+z).realize()
-    for i in range(5):
-      np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
-
   def test_jitted_clone(self):
     def f(a): return a.clone().realize()
     jf = TinyJit(f)
@@ -496,9 +490,10 @@ class TestJit(unittest.TestCase):
     f(Tensor.empty(1))
     f(Tensor.empty(1))
-    # TODO: this should fail since input has a different size
-    f(Tensor(2.0)).item()
-    # TODO: this should not fail, and should return 3
+    # scalar const input is not allowed
+    with self.assertRaises(JitError):
+      f(Tensor(2.0)).item()
+    # list input has different view structure than empty(1)
     with self.assertRaises(JitError):
       f(Tensor([2.0])).item()

View File

@@ -11,9 +11,9 @@ SILENT MISMATCHES (highest priority - wrong results, no error):
 output_buffer_reuse           MED   performance tradeoff, could add option or better docs
 python_constants_frozen       HARD  inherent to tracing JITs
 conditional_branches_frozen   HARD  inherent to tracing JITs
-unrealized_const_input_frozen HARD  unrealized const has no buffer to replace, values baked in

 ERRORS RAISED (lower priority - at least users know):
+unrealized_const_input_error  EASY  raises JitError for unrealized const inputs
 non_tensor_outputs_error      EASY  raises JitError if return contains non-Tensor values
 positional_kwargs_cannot_mix EASY  normalize positional args to kwargs using function signature
 duplicate_inputs_fail         MED   would need to handle aliasing in input_replace
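For contrast, python_constants_frozen stays in the SILENT MISMATCHES bucket because it is inherent to tracing JITs. Roughly, that footgun looks like this (a sketch based on the table's description, not code from the test file):

from tinygrad import Tensor, TinyJit

scale = 2

@TinyJit
def f(x): return (x * scale).realize()  # `scale` is a Python int, traced as a constant

# rebinding `scale = 3` later has no effect: the captured kernels still multiply by 2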
@@ -140,16 +140,20 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual(results[2], 20)  # should be 30!
     self.assertEqual(results[3], 20)  # should be 40!

-  def test_unrealized_const_input_frozen(self):
-    """Unrealized const tensors have no buffer to replace, so values are baked in at capture time."""
+  def test_unrealized_const_input_error(self):
+    """Const tensors have no buffer to replace, so JIT raises an error. Even explicit .realize() doesn't help."""
     @TinyJit
     def f(a, b): return (a * b).realize()
-    for i in range(1, 5):
-      result = f(Tensor([1, 2, 3]).realize(), Tensor(i))  # Tensor(i) is unrealized const
-      # value is frozen at capture (i=2), so i=3,4 give wrong results
-      expected = [2, 4, 6] if i >= 2 else [i, 2*i, 3*i]
-      np.testing.assert_equal(result.numpy(), expected)  # i=3,4 should be [3,6,9], [4,8,12]!
+    # unrealized const fails
+    with self.assertRaises(JitError):
+      f(Tensor([1, 2, 3]).realize(), Tensor(2))
+    # explicit .realize() on const still fails - const cannot be realized to have a buffer
+    @TinyJit
+    def g(a, b): return (a * b).realize()
+    with self.assertRaises(JitError):
+      g(Tensor([1, 2, 3]).realize(), Tensor(2).realize())

   def test_conditional_branches_frozen(self):
     """Only the branch taken during capture runs thereafter."""

View File

@@ -217,7 +217,7 @@ class TestSymbolicJit(unittest.TestCase):
   def test_ones_sum(self):
     def f(a): return a.sum().realize()
     jf = TinyJit(f)
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = jf(t[:vi]).item()
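Same root cause as the rnnt fix: Tensor.ones(10) is const-backed, so the sliced view t[:vi] that reaches the JIT has no buffer behind it, and .contiguous() materializes one. The next file applies the identical change to the eager symbolic test. Roughly:

from tinygrad import Tensor

t = Tensor.ones(10)               # const: the literal 1 is baked into the graph, no device memory
t = Tensor.ones(10).contiguous()  # buffer-backed: ten ones actually written to the device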

View File

@@ -199,7 +199,7 @@ class TestSymbolicOps(unittest.TestCase):
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

   def test_ones_sum(self):
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = t[:vi].sum().item()

View File

@@ -231,6 +231,8 @@ def _prepare_jit_inputs(args, kwargs):
   if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
   # TODO: this multi unpack stuff is not well tested.
   lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
+  if any(lb.base.op is Ops.CONST for lb in lbs):
+    raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
   input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
                                          for lb in lbs if lb.base.realized is not None])
   if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
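Note the new check runs on lbs after MULTI unpacking, which is the "fix multi" part of the commit message: for a sharded tensor, each shard's base is checked individually. A minimal usage sketch of both guards in this function, again assuming JitError's import path:

from tinygrad import Tensor, TinyJit
from tinygrad.engine.jit import JitError  # assumed import path

@TinyJit
def f(a, b): return (a + b).realize()

t = Tensor([1.0, 2.0]).realize()
try:
  f(Tensor.zeros(2), t)  # zeros stays const-backed even after realize -> the new check fires
except JitError as e:
  print(e)

@TinyJit
def g(a, b): return (a + b).realize()

try:
  g(t, t)  # same Buffer passed twice -> "duplicate inputs to JIT"
except JitError as e:
  print(e)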