Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
don't allow jit input to be const (#14045)
* don't allow jit input to be unbuffered like const
* just const to fix multi
* fix rnnt
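The rule this commit enforces: JIT inputs must be backed by real device buffers, because replaying captured kernels works by swapping input buffers in place, and a const tensor has no buffer to swap, so its value would be silently baked in at capture time. A minimal sketch of the new behavior, assuming the usual top-level imports (JitError's exact import path may differ):

    from tinygrad import Tensor, TinyJit

    @TinyJit
    def add(a, b): return (a + b).realize()

    x = Tensor([1.0, 2.0, 3.0]).realize()  # list data: owns a real buffer
    y = Tensor.ones(3)                     # lazy const: no buffer behind it
    try:
      add(x, y)                            # now raises JitError
    except Exception as e:                 # catching broadly; JitError's module may vary
      print(e)                             # "JIT inputs cannot be const, create a buffer with .contiguous()"
    out = add(x, Tensor.ones(3).contiguous())  # ok: .contiguous() materializes a buffer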
@@ -129,7 +129,7 @@ class LSTM:
       return self.do_step(x_, hc_)
 
     if hc is None:
-      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)
+      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False).contiguous().realize()
 
     output = None
     for t in range(x.shape[0]):
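The rnnt fix reads directly off that rule: `hc` is passed into the jitted step function, and `Tensor.zeros(...)` is a lazy const, so it must be materialized first. A sketch of the distinction, mirroring the `.base.op` check that `_prepare_jit_inputs` gains at the bottom of this diff (hypothetical inspection code, not part of the change):

    from tinygrad import Tensor

    z = Tensor.zeros(4)                          # lazy const: rejected as a JIT input
    zb = Tensor.zeros(4).contiguous().realize()  # buffer-backed: accepted as a JIT input
    print(z.uop.base.op, zb.uop.base.op)         # expect a CONST op vs. a buffer-backed op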
@@ -414,12 +414,6 @@ class TestJit(unittest.TestCase):
     assert isinstance(jf.jit_cache[0].prg, graph_t)
     assert isinstance(jf.jit_cache[1].prg, graph_t)
 
-  def test_jit_const_inputs(self):
-    @TinyJit
-    def g(x,y,z): return (x+y+z).realize()
-    for i in range(5):
-      np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
-
   def test_jitted_clone(self):
     def f(a): return a.clone().realize()
     jf = TinyJit(f)
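The deleted `test_jit_const_inputs` exercised exactly the pattern that is now an error: `Tensor.ones(3)` and `Tensor.zeros(3)` fed straight into a jitted function. A sketch of how the same test would look under the new rule (hypothetical, not part of the diff):

    import numpy as np
    from tinygrad import Tensor, TinyJit

    @TinyJit
    def g(x, y, z): return (x + y + z).realize()

    for i in range(5):
      # consts must be given buffers before they can be JIT inputs
      out = g(Tensor([i]*3), Tensor.ones(3).contiguous(), Tensor.zeros(3).contiguous())
      np.testing.assert_equal(out.numpy(), np.array([i+1]*3))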
@@ -496,9 +490,10 @@ class TestJit(unittest.TestCase):
 
     f(Tensor.empty(1))
     f(Tensor.empty(1))
-    # TODO: this should fail since input has a different size
-    f(Tensor(2.0)).item()
-    # TODO: this should not fail, and should return 3
+    # scalar const input is not allowed
+    with self.assertRaises(JitError):
+      f(Tensor(2.0)).item()
+    # list input has different view structure than empty(1)
     with self.assertRaises(JitError):
       f(Tensor([2.0])).item()
 
@@ -11,9 +11,9 @@ SILENT MISMATCHES (highest priority - wrong results, no error):
 output_buffer_reuse           MED   performance tradeoff, could add option or better docs
 python_constants_frozen       HARD  inherent to tracing JITs
 conditional_branches_frozen   HARD  inherent to tracing JITs
-unrealized_const_input_frozen HARD  unrealized const has no buffer to replace, values baked in
 
 ERRORS RAISED (lower priority - at least users know):
+unrealized_const_input_error  EASY  raises JitError for unrealized const inputs
 non_tensor_outputs_error      EASY  raises JitError if return contains non-Tensor values
 positional_kwargs_cannot_mix  EASY  normalize positional args to kwargs using function signature
 duplicate_inputs_fail         MED   would need to handle aliasing in input_replace
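The `duplicate_inputs_fail` row maps to the last line of the `_prepare_jit_inputs` hunk at the bottom of this diff; a sketch of the failure it names (hypothetical function):

    from tinygrad import Tensor, TinyJit

    @TinyJit
    def mul(a, b): return (a * b).realize()

    t = Tensor([1.0, 2.0]).realize()
    mul(t, t)  # raises JitError("duplicate inputs to JIT"): both args share one buffer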
@@ -140,16 +140,20 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual(results[2], 20)  # should be 30!
     self.assertEqual(results[3], 20)  # should be 40!
 
-  def test_unrealized_const_input_frozen(self):
-    """Unrealized const tensors have no buffer to replace, so values are baked in at capture time."""
+  def test_unrealized_const_input_error(self):
+    """Const tensors have no buffer to replace, so JIT raises an error. Even explicit .realize() doesn't help."""
     @TinyJit
     def f(a, b): return (a * b).realize()
 
-    for i in range(1, 5):
-      result = f(Tensor([1, 2, 3]).realize(), Tensor(i))  # Tensor(i) is unrealized const
-      # value is frozen at capture (i=2), so i=3,4 give wrong results
-      expected = [2, 4, 6] if i >= 2 else [i, 2*i, 3*i]
-      np.testing.assert_equal(result.numpy(), expected)  # i=3,4 should be [3,6,9], [4,8,12]!
+    # unrealized const fails
+    with self.assertRaises(JitError):
+      f(Tensor([1, 2, 3]).realize(), Tensor(2))
+
+    # explicit .realize() on const still fails - const cannot be realized to have a buffer
+    @TinyJit
+    def g(a, b): return (a * b).realize()
+    with self.assertRaises(JitError):
+      g(Tensor([1, 2, 3]).realize(), Tensor(2).realize())
 
   def test_conditional_branches_frozen(self):
     """Only the branch taken during capture runs thereafter."""
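The change in this test is silent-wrong-results to loud error: `Tensor(i)` used to be frozen at capture, so i=3,4 silently returned the i=2 answer; now every const input raises. A sketch of the supported way to feed a changing scalar (hypothetical; assumes `.contiguous()` gives the scalar a real buffer, as the error message suggests):

    import numpy as np
    from tinygrad import Tensor, TinyJit

    @TinyJit
    def f(a, b): return (a * b).realize()

    for i in range(1, 5):
      out = f(Tensor([1, 2, 3]).realize(), Tensor(i).contiguous())
      np.testing.assert_equal(out.numpy(), [i, 2*i, 3*i])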
@@ -217,7 +217,7 @@ class TestSymbolicJit(unittest.TestCase):
   def test_ones_sum(self):
     def f(a): return a.sum().realize()
     jf = TinyJit(f)
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = jf(t[:vi]).item()
@@ -199,7 +199,7 @@ class TestSymbolicOps(unittest.TestCase):
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
   def test_ones_sum(self):
-    t = Tensor.ones(10)
+    t = Tensor.ones(10).contiguous()
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       symbolic = t[:vi].sum().item()
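Both symbolic fixes are the same one-liner: `Tensor.ones(10)` is a lazy const, so the `TestSymbolicJit` version would now feed a forbidden const input into the JIT; the `TestSymbolicOps` version presumably gets the matching `.contiguous()` to keep the two tests in step.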
@@ -231,6 +231,8 @@ def _prepare_jit_inputs(args, kwargs):
   if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
   # TODO: this multi unpack stuff is not well tested.
   lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
+  if any(lb.base.op is Ops.CONST for lb in lbs):
+    raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
   input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
                                          for lb in lbs if lb.base.realized is not None])
   if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
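For reference, the new guard's logic in isolation: the `Ops.MULTI` unpack means every shard of a multi-device tensor is checked, so a const shard on any device is rejected, which is the "just const to fix multi" part of the commit message. A standalone sketch with hypothetical stand-ins for `UOp`/`Ops` (not tinygrad's real classes):

    from dataclasses import dataclass

    @dataclass
    class FakeUOp:                    # hypothetical stand-in for tinygrad's UOp
      op: str                         # "CONST", "BUFFER", or "MULTI"
      src: tuple = ()                 # shard uops when op == "MULTI"
      @property
      def base(self): return self     # real UOp.base strips views; a no-op here

    def check_jit_inputs(uops):
      # unpack multi-device tensors so every shard is validated individually
      lbs = [s for u in uops for s in (u.src if u.op == "MULTI" else (u,))]
      if any(lb.base.op == "CONST" for lb in lbs):
        raise RuntimeError("JIT inputs cannot be const, create a buffer with .contiguous()")

    check_jit_inputs([FakeUOp("BUFFER")])  # ok: buffer-backed input passes
    try:
      check_jit_inputs([FakeUOp("MULTI", (FakeUOp("CONST"),))])
    except RuntimeError as e:
      print(e)  # the const shard inside the MULTI is rejected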