mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
jit copy aliased output if it's read later (#14210)
This commit is contained in:
@@ -9,7 +9,6 @@ SILENT MISMATCHES (highest priority - wrong results, no error):
|
||||
class_method_shared_across_instances EASY could check if first arg is self and warn
|
||||
slice_assign_requires_realize MED assign graph not connected to read during JIT replay
|
||||
output_buffer_reuse MED performance tradeoff, could add option or better docs
|
||||
multiple_outputs_same_intermediate MED outputs derived from same intermediate get aliased during replay
|
||||
python_constants_frozen HARD inherent to tracing JITs
|
||||
conditional_branches_frozen HARD inherent to tracing JITs
|
||||
|
||||
@@ -52,29 +51,19 @@ class TestJitFootguns(unittest.TestCase):
|
||||
self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])
|
||||
|
||||
def test_multiple_outputs_same_intermediate(self):
|
||||
"""Multiple outputs derived from the same intermediate get aliased during replay."""
|
||||
# TODO: fix this, clone should not be required, related to https://github.com/tinygrad/tinygrad/issues/13364
|
||||
"""Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""
|
||||
@TinyJit
|
||||
def f_broken(buf, frame):
|
||||
def f(buf, frame):
|
||||
new_buf = buf[1:].cat(frame, dim=0)
|
||||
return new_buf.contiguous(), new_buf[:1].contiguous()
|
||||
|
||||
@TinyJit
|
||||
def f_fixed(buf, frame):
|
||||
new_buf = buf[1:].cat(frame, dim=0)
|
||||
return new_buf.clone().contiguous(), new_buf[:1].contiguous()
|
||||
|
||||
for f, should_work in [(f_broken, False), (f_fixed, True)]:
|
||||
buf = Tensor([[0], [1], [2]]).contiguous().realize()
|
||||
for i in range(4):
|
||||
frame = Tensor([[10+i]]).contiguous().realize()
|
||||
expected_first = buf[1:2].numpy().item()
|
||||
new_buf, first = f(buf, frame)
|
||||
if should_work:
|
||||
self.assertEqual(first.numpy().item(), expected_first)
|
||||
else:
|
||||
if i >= 2: self.assertNotEqual(first.numpy().item(), expected_first) # fails on 3rd iteration!
|
||||
buf = new_buf
|
||||
buf = Tensor([[0], [1], [2]]).contiguous().realize()
|
||||
for i in range(4):
|
||||
frame = Tensor([[10+i]]).contiguous().realize()
|
||||
expected_first = buf[1:2].numpy().item()
|
||||
new_buf, first = f(buf, frame)
|
||||
self.assertEqual(first.numpy().item(), expected_first)
|
||||
buf = new_buf
|
||||
|
||||
def test_slice_assign_requires_realize(self):
|
||||
"""Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""
|
||||
|
||||
@@ -178,6 +178,10 @@ class CapturedJit(Generic[ReturnType]):
|
||||
self._jit_cache: list[ExecItem] = self.jit_cache
|
||||
self._input_replace: dict[tuple[int, int], int] = self.input_replace
|
||||
self._first_run = True
|
||||
# precompute read-after-write hazard detection
|
||||
self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
|
||||
self._input_to_max_reader: dict[int, int] = {}
|
||||
for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
|
||||
self._clear_inputs()
|
||||
|
||||
def _clear_inputs(self):
|
||||
@@ -205,6 +209,12 @@ class CapturedJit(Generic[ReturnType]):
|
||||
# assign inputs
|
||||
for idx, offset, device, size, dtype in self.extra_view_inputs:
|
||||
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
|
||||
|
||||
# copy aliased inputs to prevent read-after-write hazard
|
||||
for i, ib in enumerate(input_buffers):
|
||||
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) > writer:
|
||||
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_buffer())
|
||||
|
||||
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
|
||||
|
||||
# Condense the items into a graph executor.
|
||||
|
||||
Reference in New Issue
Block a user