mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
jit: fix raw in same kernel (#14699)
* jit: fix raw in same kernel * fix * ugh * x * simpler
This commit is contained in:
@@ -23,7 +23,7 @@ ERRORS RAISED (lower priority - at least users know):
|
||||
"""
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit
|
||||
from tinygrad import Tensor, TinyJit, Device
|
||||
from tinygrad.engine.jit import JitError
|
||||
from tinygrad.helpers import JIT
|
||||
|
||||
@@ -66,7 +66,6 @@ class TestJitFootguns(unittest.TestCase):
|
||||
|
||||
This requires multiple kernels to trigger because single-kernel JITs don't get graphed ("only one kernel doesn't graph").
|
||||
"""
|
||||
from tinygrad import Device
|
||||
if Device[Device.DEFAULT].graph is None or JIT != 1:
|
||||
self.skipTest("test requires JIT graph support")
|
||||
|
||||
@@ -111,6 +110,15 @@ class TestJitFootguns(unittest.TestCase):
|
||||
self.assertEqual(first.numpy().item(), expected_first)
|
||||
buf = new_buf
|
||||
|
||||
def test_intra_kernel_output_input_aliasing(self):
  """A JIT output fed straight back in as an input must be copied, or the kernel reads what it is overwriting.

  Every step keeps the upper half of the running buffer and appends a new half filled with `step + 1`.
  The lower half of the result must therefore equal `step` (all zeros on the very first step).
  """
  N = 1 << 20
  half = N // 2

  def shift_and_append(buf, new):
    # drop the lower half, keep the upper half, then append the fresh values
    return buf[half:].cat(new)

  jitted = TinyJit(shift_and_append, prune=True)
  running = Tensor.zeros(N, dtype='int32').contiguous().realize()
  for step in range(10):
    fresh = Tensor(np.ones(half, dtype=np.int32) * (step + 1))
    # the previous output is passed back in as the first input on every call
    running = jitted(running, fresh)
    expected_lower = np.full(half, step, dtype=np.int32)
    np.testing.assert_array_equal(running[:half].numpy(), expected_lower)
|
||||
|
||||
def test_slice_assign_works_without_realize(self):
|
||||
"""Slice assign then read from same buffer - pending assigns are side-realized."""
|
||||
from tinygrad import Variable
|
||||
|
||||
@@ -187,7 +187,10 @@ class CapturedJit(Generic[ReturnType]):
|
||||
# precompute read-after-write hazard detection
|
||||
self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
|
||||
self._input_to_max_reader: dict[int, int] = {}
|
||||
for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
|
||||
for (j, i), idx in self.input_replace.items():
|
||||
# only buffers that were different during capture but alias at jit time (e.g. feeding output back as input) need the copy.
|
||||
if self.jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self.jit_cache[j]):
|
||||
self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
|
||||
self._clear_inputs()
|
||||
|
||||
def _clear_inputs(self):
|
||||
@@ -218,7 +221,7 @@ class CapturedJit(Generic[ReturnType]):
|
||||
|
||||
# copy aliased inputs to prevent read-after-write hazard
|
||||
for i, ib in enumerate(input_buffers):
|
||||
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) > writer:
|
||||
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
|
||||
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
|
||||
|
||||
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
|
||||
|
||||
Reference in New Issue
Block a user