jit copy aliased output if it's read later (#14210)

chenyu
2026-01-18 18:48:59 -05:00
committed by GitHub
parent 97333b1954
commit 67d9712ef6
2 changed files with 19 additions and 20 deletions
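Background on the fix: during replay the cached kernels write into the same output buffers on every call. If the caller feeds a previous output back in as an input (e.g. a sliding-window buffer), that input now aliases a buffer the replay itself is about to overwrite, and any kernel that reads the input after the writer has run sees the new contents instead of the old ones: a silent read-after-write hazard. This commit detects that case at replay time and copies the aliased input into a fresh buffer first. A minimal reproduction, adapted from the updated test below (the step name is mine):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(buf: Tensor, frame: Tensor):
  new_buf = buf[1:].cat(frame, dim=0)  # sliding window: drop the oldest row, append frame
  # two outputs derived from the same intermediate; the second re-reads new_buf
  return new_buf.contiguous(), new_buf[:1].contiguous()

buf = Tensor([[0], [1], [2]]).contiguous().realize()
for i in range(4):
  frame = Tensor([[10+i]]).contiguous().realize()
  expected_first = buf[1:2].numpy().item()       # the value new_buf[:1] must return
  new_buf, first = step(buf, frame)
  assert first.numpy().item() == expected_first  # failed from the 3rd call before this fix
  buf = new_buf                                  # feed last call's output buffer back in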


@@ -9,7 +9,6 @@ SILENT MISMATCHES (highest priority - wrong results, no error):
 class_method_shared_across_instances EASY could check if first arg is self and warn
 slice_assign_requires_realize MED assign graph not connected to read during JIT replay
 output_buffer_reuse MED performance tradeoff, could add option or better docs
-multiple_outputs_same_intermediate MED outputs derived from same intermediate get aliased during replay
 python_constants_frozen HARD inherent to tracing JITs
 conditional_branches_frozen HARD inherent to tracing JITs
@@ -52,29 +51,19 @@ class TestJitFootguns(unittest.TestCase):
     self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
   def test_multiple_outputs_same_intermediate(self):
-    """Multiple outputs derived from the same intermediate get aliased during replay."""
-    # TODO: fix this, clone should not be required, related to https://github.com/tinygrad/tinygrad/issues/13364
+    """Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""
     @TinyJit
-    def f_broken(buf, frame):
+    def f(buf, frame):
       new_buf = buf[1:].cat(frame, dim=0)
       return new_buf.contiguous(), new_buf[:1].contiguous()
-    @TinyJit
-    def f_fixed(buf, frame):
-      new_buf = buf[1:].cat(frame, dim=0)
-      return new_buf.clone().contiguous(), new_buf[:1].contiguous()
-    for f, should_work in [(f_broken, False), (f_fixed, True)]:
-      buf = Tensor([[0], [1], [2]]).contiguous().realize()
-      for i in range(4):
-        frame = Tensor([[10+i]]).contiguous().realize()
-        expected_first = buf[1:2].numpy().item()
-        new_buf, first = f(buf, frame)
-        if should_work:
-          self.assertEqual(first.numpy().item(), expected_first)
-        else:
-          if i >= 2: self.assertNotEqual(first.numpy().item(), expected_first) # fails on 3rd iteration!
-        buf = new_buf
+    buf = Tensor([[0], [1], [2]]).contiguous().realize()
+    for i in range(4):
+      frame = Tensor([[10+i]]).contiguous().realize()
+      expected_first = buf[1:2].numpy().item()
+      new_buf, first = f(buf, frame)
+      self.assertEqual(first.numpy().item(), expected_first)
+      buf = new_buf
   def test_slice_assign_requires_realize(self):
     """Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""


@@ -178,6 +178,10 @@ class CapturedJit(Generic[ReturnType]):
     self._jit_cache: list[ExecItem] = self.jit_cache
     self._input_replace: dict[tuple[int, int], int] = self.input_replace
     self._first_run = True
+    # precompute read-after-write hazard detection
+    self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
+    self._input_to_max_reader: dict[int, int] = {}
+    for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
     self._clear_inputs()
   def _clear_inputs(self):
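
To make the precomputed maps concrete: _output_to_writer maps each cached output buffer to the index of the ExecItem that writes it, and _input_to_max_reader maps each input slot to the last ExecItem index that reads it. A standalone sketch of the same logic with toy stand-ins (this ExecItem dataclass and the string buffer names are illustrative, not tinygrad's real types):

from dataclasses import dataclass

@dataclass
class ExecItem:
  bufs: list[str]     # buffer slots of one cached kernel
  n_outputs: int = 1  # assume only the leading buf is written

# toy cache: kernel 0 writes "out0" from two inputs, kernel 1 reads input 0 again
jit_cache = [ExecItem(["out0", "in_buf", "in_frame"]),
             ExecItem(["out1", "in_buf"])]
# (kernel index, buf slot) -> input index, like input_replace above
input_replace = {(0, 1): 0, (0, 2): 1, (1, 1): 0}

output_to_writer = {b: j for j, ei in enumerate(jit_cache) for b in ei.bufs[:ei.n_outputs]}
input_to_max_reader: dict[int, int] = {}
for (j, _), idx in input_replace.items():
  input_to_max_reader[idx] = max(input_to_max_reader.get(idx, -1), j)

print(output_to_writer)     # {'out0': 0, 'out1': 1}
print(input_to_max_reader)  # {0: 1, 1: 0}
# if input 0 arrives aliasing "out0": writer index 0 < last read index 1 -> copy it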
@@ -205,6 +209,12 @@ class CapturedJit(Generic[ReturnType]):
     # assign inputs
     for idx, offset, device, size, dtype in self.extra_view_inputs:
       input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+    # copy aliased inputs to prevent read-after-write hazard
+    for i, ib in enumerate(input_buffers):
+      if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) > writer:
+        input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_buffer())
     for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
     # Condense the items into a graph executor.
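
The replay-time loop then snapshots only the hazardous inputs, reusing exactly the calls visible in the diff. A hedged sketch of that step in isolation (the snapshot name is mine and the import path is assumed):

from tinygrad.device import Buffer  # import path assumed

def snapshot(ib: Buffer) -> Buffer:
  # allocate a fresh buffer and copy in the input's current contents, so kernels
  # that read this input after its writer runs still see the pre-replay bytes
  return Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_buffer())

The cost, one allocation plus a copy per hazardous input per call, is only paid when an input actually aliases a cached output that is read after being rewritten; non-aliased inputs take the existing fast path unchanged, and callers no longer need the clone() workaround the removed test variant used.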