jit graph handle input==output aliasing (#14287)
A buffer position that wasn't an input during capture should never become an input during execution, but the graph runner cannot tell this from jit_cache and input_buffers alone.
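For context, a minimal repro sketch of the footgun being fixed, modeled on the new test in this diff (a hypothetical standalone script; names and values are illustrative):

    from tinygrad import Tensor, TinyJit

    @TinyJit
    def step(x: Tensor) -> Tensor:
      # captured once; on later calls the graph runner only patches the tracked input slots
      return ((x + 1) * 2).realize()

    y = Tensor([100]).contiguous().realize()
    for _ in range(3):
      y = step(y)  # the previous output buffer is fed straight back in, so input == output alias
    print(y.item())  # expected (100+1)*2 -> 202 -> 406 -> 814; the test shows 1406 when the bug is present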
@@ -9,7 +9,6 @@ SILENT MISMATCHES (highest priority - wrong results, no error):
 class_method_shared_across_instances EASY could check if first arg is self and warn
 slice_assign_requires_realize MED assign graph not connected to read during JIT replay
 output_buffer_reuse MED performance tradeoff, could add option or better docs
-graph_input_output_aliasing MED GraphRunner skips aliased buffers but only input_replace updated
 python_constants_frozen HARD inherent to tracing JITs
 conditional_branches_frozen HARD inherent to tracing JITs

@@ -93,7 +92,7 @@ class TestJitFootguns(unittest.TestCase):
     y = Tensor([100]).contiguous().realize()
     for _ in range(3):
       y = step(y)  # should be (100+1)*2=202, (202+1)*2=406, (406+1)*2=814
-    self.assertEqual(y.item(), 1406)  # TODO: should be 814
+    self.assertEqual(y.item(), 814)  # fails with 1406 if bug exists (uses 350 instead of 100)

   def test_multiple_outputs_same_intermediate(self):
     """Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""

@@ -218,7 +218,6 @@ class Transformer:
   def generate(self, tokens:list[int], start_pos=0):
     v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
     t = Tensor([tokens[start_pos:]], dtype="int32")
-    self.forward_jit.reset()  # TODO: why is this required? root cause the issue and make it not be needed
     while len(tokens) < self.max_context:
       t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
       next_id = int(t.item())

@@ -24,7 +24,8 @@ def _check_no_non_tensor_return(ret):

 def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph

-def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
+def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+                       orig_valid_positions: dict[int, set[int]]|None = None, max_batch_size=0) -> list[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
   graphed_jit_cache: list[ExecItem] = []

@@ -36,7 +37,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], v
     try:
       if len(current_batch_devs) == 0: raise GraphException("no device for graph")
       if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
-      graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals)
+      graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals, orig_valid_positions=orig_valid_positions)
      # clear jit inputs to allow their memory to be freed/reused
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
       graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner))

@@ -72,18 +73,22 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], v
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache

-def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]:
+def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer],
+                      orig_valid_positions: dict[int, set[int]]|None = None) -> dict[tuple[int, int], int]:
   input_replace: dict[tuple[int, int], int] = {}
   for j,ji in enumerate(jit_cache):
     for i,a in enumerate(ji.bufs):
       if a in input_buffers:
+        # filter out positions that weren't valid inputs in the original capture (prevents aliasing bugs)
+        if orig_valid_positions is not None and i not in orig_valid_positions.get(id(ji), set()): continue
         input_replace[(j,i)] = input_buffers.index(a)
   return input_replace

 class GraphRunner(Runner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
     self.jit_cache = jit_cache  # NOTE: this is not used, but you have to keep these objects alive for the Graph
-    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers)
+    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers, orig_valid_positions)
     self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
     self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
     self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
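
To make the new filter concrete, here is a toy, self-contained sketch of the get_input_replace logic above (ToyItem and the object() buffers are stand-ins, not real tinygrad types):

    # toy model of the filtered get_input_replace: buffer identity decides matches,
    # orig_valid_positions restricts which (item, slot) pairs may be patched at run time
    from dataclasses import dataclass

    @dataclass
    class ToyItem:
      bufs: list

    def toy_get_input_replace(jit_cache, input_buffers, orig_valid_positions=None):
      input_replace = {}
      for j, ji in enumerate(jit_cache):
        for i, a in enumerate(ji.bufs):
          if a in input_buffers:
            # skip slots that were not input slots during capture (prevents aliasing bugs)
            if orig_valid_positions is not None and i not in orig_valid_positions.get(id(ji), set()): continue
            input_replace[(j, i)] = input_buffers.index(a)
      return input_replace

    inp, out = object(), object()
    ji = ToyItem(bufs=[out, inp])        # slot 0 is the kernel's output, slot 1 its input
    orig_valid = {id(ji): {1}}           # only slot 1 was an input at capture time
    # later the caller feeds the old output back in, so `out` now shows up among the input buffers
    print(toy_get_input_replace([ji], [out, inp]))              # {(0, 0): 0, (0, 1): 1} -- output slot wrongly patchable
    print(toy_get_input_replace([ji], [out, inp], orig_valid))  # {(0, 1): 1} -- only the real input slot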
@@ -225,8 +230,14 @@ class CapturedJit(Generic[ReturnType]):
           if b is not None: b.ensure_allocated()
       # create graph if needed
       if JIT < 2:
-        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=JIT_BATCH_SIZE.value)
-        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+        # build a map from ExecItem object to the buffer positions that are valid inputs (from original input_replace)
+        orig_valid_positions: dict[int, set[int]] = {}  # id(ExecItem) -> set of valid buffer indices
+        for (j, i) in self.input_replace: orig_valid_positions.setdefault(id(self.jit_cache[j]), set()).add(i)
+        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, orig_valid_positions, max_batch_size=JIT_BATCH_SIZE.value)
+        # recompute input_replace: GraphRunner items have all positions valid, non-GraphRunner items use orig_valid_positions
+        valid_positions = {id(ji): set(range(len(ji.bufs))) if isinstance(ji.prg, GraphRunner) else orig_valid_positions.get(id(ji), set())
+                           for ji in self._jit_cache}
+        self._input_replace = get_input_replace(self._jit_cache, input_buffers, valid_positions)
       self._first_run = False

     if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
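
A small sketch of the bookkeeping CapturedJit now performs around apply_graph_to_jit, using toy placeholders for the ExecItems (illustrative only; the real code is in the diff above):

    # toy model: derive orig_valid_positions from the captured input_replace map
    jit_cache = ["item0", "item1", "item2"]            # stand-ins for ExecItems
    captured_input_replace = {(0, 1): 0, (2, 3): 1}    # hypothetical (j, i) -> input index entries

    orig_valid_positions: dict[int, set[int]] = {}
    for (j, i) in captured_input_replace:
      orig_valid_positions.setdefault(id(jit_cache[j]), set()).add(i)

    # item0 may only be patched at slot 1, item2 only at slot 3, item1 not at all
    print(orig_valid_positions == {id(jit_cache[0]): {1}, id(jit_cache[2]): {3}})  # True

After graphing, a GraphRunner ExecItem's bufs are exactly input_buffers, so every slot of that item is a valid input, while any kernels left ungraphed keep only their originally captured slots; that is what the valid_positions recomputation above encodes.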
@@ -9,8 +9,9 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException

 class CUDAGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
-    super().__init__(jit_cache, input_buffers, var_vals)
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
+    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)

     # Check all jit items are compatible.
     if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException

@@ -9,8 +9,9 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
 from tinygrad.engine.jit import MultiGraphRunner

 class HCQGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
-    super().__init__(jit_cache, input_buffers, var_vals)
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
+    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
     self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

     # CPU Device is always last

@@ -10,8 +10,9 @@ from tinygrad.runtime.autogen import metal
 from tinygrad.runtime.support import objc

 class MetalGraph(GraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
-    super().__init__(jit_cache, input_buffers, var_vals)
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
+    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

     # create metal batch exec

@@ -39,7 +40,7 @@ class MetalGraph(GraphRunner):
       all_pipelines.append(prg._prg.pipeline_state)
       icb_command.setComputePipelineState(prg._prg.pipeline_state)
      for i,b in enumerate(ji.bufs):
-        if b is not None and b not in input_buffers:
+        if b is not None and (j,i) not in self.input_replace:
           icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
           all_resources.append(b._buf.buf)
       for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
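
The Metal change above presumably matters exactly because of the new filtering in get_input_replace: a slot can now hold a buffer that is in input_buffers while its (j, i) position is not a patched input slot, and the old membership test would leave that slot unset waiting for a patch that never comes, whereas checking self.input_replace bakes the captured buffer into the command buffer. A toy illustration of how the two predicates diverge (plain Python, illustrative values):

    input_buffers = ["bufA", "bufB"]
    input_replace = {(0, 1): 0}       # only kernel 0, slot 1 is patched from input_buffers[0]
    j, i, b = 0, 0, "bufA"            # kernel 0, slot 0 happens to hold the same buffer

    print(b not in input_buffers)       # False -> old check: slot skipped, never encoded
    print((j, i) not in input_replace)  # True  -> new check: slot is encoded with its captured buffer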