diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py
index 534026268c..6d3ea32e23 100644
--- a/test/test_jit_footguns.py
+++ b/test/test_jit_footguns.py
@@ -9,7 +9,6 @@
 SILENT MISMATCHES (highest priority - wrong results, no error):
   class_method_shared_across_instances  EASY  could check if first arg is self and warn
   slice_assign_requires_realize         MED   assign graph not connected to read during JIT replay
   output_buffer_reuse                   MED   performance tradeoff, could add option or better docs
-  graph_input_output_aliasing           MED   GraphRunner skips aliased buffers but only input_replace updated
   python_constants_frozen               HARD  inherent to tracing JITs
   conditional_branches_frozen           HARD  inherent to tracing JITs
@@ -93,7 +92,7 @@ class TestJitFootguns(unittest.TestCase):
     y = Tensor([100]).contiguous().realize()
     for _ in range(3): y = step(y) # should be (100+1)*2=202, (202+1)*2=406, (406+1)*2=814
-    self.assertEqual(y.item(), 1406) # TODO: should be 814
+    self.assertEqual(y.item(), 814) # fails with 1406 if bug exists (uses 350 instead of 100)
 
   def test_multiple_outputs_same_intermediate(self):
     """Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 24fb059b8a..c6eff5c5b7 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -218,7 +218,6 @@ class Transformer:
   def generate(self, tokens:list[int], start_pos=0):
     v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
     t = Tensor([tokens[start_pos:]], dtype="int32")
-    self.forward_jit.reset() # TODO: why is this required? root cause the issue and make it not be needed
     while len(tokens) < self.max_context:
       t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
       next_id = int(t.item())
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index d2c3610610..584826bbb7 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -24,7 +24,8 @@ def _check_no_non_tensor_return(ret):
 
 def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
 
-def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
+def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+                       orig_valid_positions: dict[int, set[int]]|None = None, max_batch_size=0) -> list[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
   graphed_jit_cache: list[ExecItem] = []
@@ -36,7 +37,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], v
     try:
       if len(current_batch_devs) == 0: raise GraphException("no device for graph")
       if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
-      graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals)
+      graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals, orig_valid_positions=orig_valid_positions)
      # clear jit inputs to allow their memory to be freed/reused
      for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
      graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner))
@@ -72,18 +73,22 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], v
   if len(current_batch) > 0: flush_batch()
   return graphed_jit_cache
 
-def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]:
+def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer],
+                      orig_valid_positions: dict[int, set[int]]|None = None) -> dict[tuple[int, int], int]:
   input_replace: dict[tuple[int, int], int] = {}
   for j,ji in enumerate(jit_cache):
     for i,a in enumerate(ji.bufs):
       if a in input_buffers:
+        # filter out positions that weren't valid inputs in the original capture (prevents aliasing bugs)
+        if orig_valid_positions is not None and i not in orig_valid_positions.get(id(ji), set()): continue
         input_replace[(j,i)] = input_buffers.index(a)
   return input_replace
 
 class GraphRunner(Runner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
     self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
-    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers)
+    self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers, orig_valid_positions)
     self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
     self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
     self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
@@ -225,8 +230,14 @@ class CapturedJit(Generic[ReturnType]):
           if b is not None: b.ensure_allocated()
       # create graph if needed
       if JIT < 2:
-        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, max_batch_size=JIT_BATCH_SIZE.value)
-        self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+        # build a map from ExecItem object to the buffer positions that are valid inputs (from original input_replace)
+        orig_valid_positions: dict[int, set[int]] = {} # id(ExecItem) -> set of valid buffer indices
+        for (j, i) in self.input_replace: orig_valid_positions.setdefault(id(self.jit_cache[j]), set()).add(i)
+        self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, orig_valid_positions, max_batch_size=JIT_BATCH_SIZE.value)
+        # recompute input_replace: GraphRunner items have all positions valid, non-GraphRunner items use orig_valid_positions
+        valid_positions = {id(ji): set(range(len(ji.bufs))) if isinstance(ji.prg, GraphRunner) else orig_valid_positions.get(id(ji), set())
+                           for ji in self._jit_cache}
+        self._input_replace = get_input_replace(self._jit_cache, input_buffers, valid_positions)
       self._first_run = False
       if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index ae9d48381a..e6e9afd2b4 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -9,8 +9,9 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException
 
 class CUDAGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
-    super().__init__(jit_cache, input_buffers, var_vals)
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
+    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
 
     # Check all jit items are compatible.
     if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index 2cfa438392..8c611554f7 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -9,8 +9,9 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
 from tinygrad.engine.jit import MultiGraphRunner
 
 class HCQGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
-    super().__init__(jit_cache, input_buffers, var_vals)
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
+    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
     self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
 
     # CPU Device is always last
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index e7e767b3ef..c1931295a6 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -10,8 +10,9 @@ from tinygrad.runtime.autogen import metal
 from tinygrad.runtime.support import objc
 
 class MetalGraph(GraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
-    super().__init__(jit_cache, input_buffers, var_vals)
+  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
+               orig_valid_positions: dict[int, set[int]]|None = None):
+    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
     # create metal batch exec
@@ -39,7 +40,7 @@ class MetalGraph(GraphRunner):
       all_pipelines.append(prg._prg.pipeline_state)
       icb_command.setComputePipelineState(prg._prg.pipeline_state)
       for i,b in enumerate(ji.bufs):
-        if b is not None and b not in input_buffers:
+        if b is not None and (j,i) not in self.input_replace:
          icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
          all_resources.append(b._buf.buf)
      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
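
For reference, a minimal sketch (not part of the patch) of the input/output aliasing scenario the updated test asserts on. It assumes tinygrad's public Tensor/TinyJit API; the `step` helper mirrors the jitted function used in test_jit_footguns.py. The 814 vs 1406 values come from the test comment above: with the aliasing bug, the replayed graph reads a stale buffer instead of the new input.

# sketch only: feed a jitted function its own output and check the chained result
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  # output is derived directly from the input, so on replay the JIT's input and
  # output buffers can alias; orig_valid_positions keeps input_replace from
  # rewriting buffer slots that weren't inputs in the original capture
  return ((x + 1) * 2).realize()

y = Tensor([100]).contiguous().realize()
for _ in range(3):
  y = step(y)  # expected chain: 202 -> 406 -> 814
assert y.item() == 814  # the buggy replay produced 1406 instead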