mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
test_unrealized_const_input_frozen (#14044)
unrealized const is not replaced in jit
This commit is contained in:
@@ -11,6 +11,7 @@ SILENT MISMATCHES (highest priority - wrong results, no error):
|
||||
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
||||
python_constants_frozen HARD inherent to tracing JITs
|
||||
conditional_branches_frozen HARD inherent to tracing JITs
|
||||
unrealized_const_input_frozen HARD unrealized const has no buffer to replace, values baked in
|
||||
|
||||
ERRORS RAISED (lower priority - at least users know):
|
||||
non_tensor_outputs_error EASY raises JitError if return contains non-Tensor values
|
||||
@@ -139,6 +140,17 @@ class TestJitFootguns(unittest.TestCase):
|
||||
self.assertEqual(results[2], 20) # should be 30!
|
||||
self.assertEqual(results[3], 20) # should be 40!
|
||||
|
||||
def test_unrealized_const_input_frozen(self):
|
||||
"""Unrealized const tensors have no buffer to replace, so values are baked in at capture time."""
|
||||
@TinyJit
|
||||
def f(a, b): return (a * b).realize()
|
||||
|
||||
for i in range(1, 5):
|
||||
result = f(Tensor([1, 2, 3]).realize(), Tensor(i)) # Tensor(i) is unrealized const
|
||||
# value is frozen at capture (i=2), so i=3,4 give wrong results
|
||||
expected = [2, 4, 6] if i >= 2 else [i, 2*i, 3*i]
|
||||
np.testing.assert_equal(result.numpy(), expected) # i=3,4 should be [3,6,9], [4,8,12]!
|
||||
|
||||
def test_conditional_branches_frozen(self):
|
||||
"""Only the branch taken during capture runs thereafter."""
|
||||
@TinyJit
|
||||
|
||||
Reference in New Issue
Block a user