jit: capturedjit is linear (#15743)

* jit: capturedjit is linear

* x

* new beam

* test

* imp

* clean

* spec

* linter
This commit is contained in:
nimlgen
2026-04-16 14:54:39 +03:00
committed by GitHub
parent d1cce7a476
commit c04f3eaa70
3 changed files with 97 additions and 107 deletions

View File

@@ -1,10 +1,10 @@
from typing import TypeVar, Generic, Callable, cast, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize, VIZ
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType, dtypes
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
from tinygrad.schedule import linear_to_schedule
@@ -22,14 +22,13 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
else: onetime.append(si)
return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime))
def create_graph_call(batch:list[UOp]) -> UOp:
  """Wrap a batch of schedule items into a single CUSTOM_FUNCTION "graph" call.

  The batch's external inputs are PARAM ops (the linear was parametrized earlier),
  and they become the arguments of the emitted call.
  """
  # all external inputs are PARAMs
  input_list = dedup(b for si in batch for b in si.src[1:] if b.op is Ops.PARAM)
  # the CUSTOM_FUNCTION carries the LINEAR of the whole batch plus its inputs
  cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph")
  # preserve per-kernel metadata on the call
  return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata))
def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:int=0) -> UOp:
def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
new_src: list[UOp] = []
current_batch: list[UOp] = []
current_batch_devs: list[Compiled] = []
@@ -38,7 +37,7 @@ def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:in
nonlocal current_batch, current_batch_devs, max_batch_size, new_src
if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): new_src.extend(current_batch)
else:
new_src.append(create_graph_call(current_batch, input_buffers))
new_src.append(create_graph_call(current_batch))
max_batch_size *= 2
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels")
current_batch, current_batch_devs = [], []
@@ -66,13 +65,26 @@ def jit_cache_bufs(jit_cache:list[ExecItem]):
if b is not None: yield b
if isinstance(ei.prg, GraphRunner): yield from jit_cache_bufs(ei.prg.jit_cache)
@track_rewrites(lambda linear,held_bufs,input_buffers=None,ret=(): f"JIT {pluralize('call', len(linear.src))}")
def jit_lower(linear:UOp, held_bufs:set[UOp], input_buffers:list[Buffer]|None=None) -> list[ExecItem]:
def _unwrap_beam(ast:UOp) -> UOp:
  """If ast is a BEAM wrapper (see pm_beam), return the wrapped op; otherwise return ast unchanged."""
  if ast.op is Ops.BEAM:
    return ast.src[0]
  return ast
# rewrite each CALL whose first source is a SINK so the SINK is wrapped in a BEAM node;
# the BEAM's arg is the beam width (ctx, set from JITBEAM in jit_lower)
pm_beam = PatternMatcher([
  (UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True),
   lambda ctx,call,sink: call.replace(src=(UOp(Ops.BEAM, src=(sink,), arg=ctx), *call.src[1:])))])
@track_rewrites(lambda linear,held_bufs,input_uops,ret=(): f"JIT {pluralize('call', len(linear.src))}")
def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp:
  """Rewrite a captured LINEAR for jit execution and return the rewritten LINEAR.

  Steps: parametrize input buffers as PARAMs, optionally wrap SINKs in BEAM
  (JITBEAM), plan memory, and split into graph batches (unless JIT >= 2).
  """
  if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear")
  # parametrize input buffers: map each input buffer UOp to a PARAM with the correct slot index
  linear = linear.substitute({u: UOp.param(i, u.dtype, u.shape, u.device) for i,u in enumerate(input_uops)}, walk=True)
  # wrap SINKs with BEAM if jitbeam is set
  if (jitbeam:=getenv("JITBEAM", BEAM.value)) >= 1: linear = graph_rewrite(linear, pm_beam, ctx=jitbeam, walk=True)
  linear = memory_plan_rewrite(linear, held_bufs)
  # skip graph batching when JIT >= 2
  if JIT < 2: linear = graph_split_rewrite(linear, max_batch_size=JIT_BATCH_SIZE.value)
  if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear")
  return linear
class GraphException(Exception): pass  # NOTE(review): appears to be raised by graph backends — confirm usage sites
class JitError(Exception): pass  # raised for invalid JIT usage (const/duplicate inputs, arg mismatch, nothing captured)
@@ -86,26 +98,37 @@ def _check_no_non_tensor_return(ret):
def graph_class(dev):
  """Return the device's graph class, unwrapping a functools.partial wrapper if one is used."""
  graph = dev.graph
  if isinstance(graph, functools.partial):
    return graph.func
  return graph
def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]:
  """Map each (ExecItem index, buffer slot) that holds an input buffer to that buffer's
  index in input_buffers, so inputs can be re-bound into the cache on every jit call."""
  input_replace: dict[tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
    for i,a in enumerate(ji.bufs):
      if a in input_buffers: input_replace[(j,i)] = input_buffers.index(a)
  return input_replace
# substitute each PARAM with the concrete input UOp at its slot index (p.arg) looked up in ctx
pm_params = PatternMatcher([(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="p"), lambda ctx,p: ctx[p.arg])])
def linear_to_jit_cache(linear:UOp, input_uops:list[UOp]) -> tuple[list[ExecItem], dict[tuple[int,int],int], list[tuple[int,int,str,int,DType]]]:
  """Lower a parametrized LINEAR into runnable ExecItems.

  Returns (jit_cache, input_replace, extra_view_inputs):
    jit_cache: lowered ExecItems for every schedule item in the linear
    input_replace: (jit_cache index, buffer slot) -> index into the input buffer list
    extra_view_inputs: (base input index, offset, device, size, dtype) for view
      buffers whose base is an input buffer; recreated at call time
  """
  # substitute PARAMs with input buffer UOps before lowering
  linear = graph_rewrite(linear, pm_params, ctx=input_uops, walk=True, enter_calls=True)
  # convert to jit_cache
  jit_cache = [ei.lower() for ei in linear_to_schedule(linear)]
  # ensure every buffer referenced by the cache is backed by allocated memory
  for b in jit_cache_bufs(jit_cache): b.ensure_allocated()
  # derive input_buffers from input_uops (a MultiBuffer expands to its component buffers)
  input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None])
  # track view buffers whose base is an input buffer
  extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
  for ei in jit_cache:
    for b in ei.bufs:
      if b is not None and b._base is not None and b._base in input_buffers and b not in input_buffers:
        extra_view_inputs.append((input_buffers.index(b._base), b.offset, b.device, b.size, b.dtype))
        input_buffers.append(b)
  return jit_cache, get_input_replace(jit_cache, input_buffers), extra_view_inputs
class GraphRunner(Runner):
def __init__(self, linear:UOp|None, input_buffers:list[Buffer]|None,
jit_cache:list[ExecItem]|None=None, input_replace:dict[tuple[int,int],int]|None=None):
# TODO: captured jit as linear?
if linear is not None:
jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
for b in jit_cache_bufs(jit_cache): b.ensure_allocated()
input_replace = get_input_replace(jit_cache, input_buffers) if input_buffers else {}
self.jit_cache, self.input_replace = unwrap(jit_cache), input_replace or {}
def __init__(self, linear:UOp, input_buffers:list[Buffer]):
self.jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
self.input_replace = get_input_replace(self.jit_cache, input_buffers) if input_buffers else {}
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
@@ -141,8 +164,6 @@ class GraphRunner(Runner):
assert self.jit_cache[0].prg is not None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), self.jit_cache[0].prg.device.split(":")[0], estimates.simplify())
def __reduce__(self): return self.__class__, (None, None, self.jit_cache, self.input_replace)
def updated_vars(self, var_vals: dict[str, int]):
vals = [var_vals[v] for v in self.vars]
for j, vidxs in self.var_vals_replace.items():
@@ -179,14 +200,15 @@ class GraphRunner(Runner):
@staticmethod
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
  # batchable if the call (ignoring a BEAM wrapper) is a SINK/PROGRAM and everything runs on one device
  return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner):
  @staticmethod
  def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
    # devices must all be the same type; COPY calls are also batchable here
    return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) \
      and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
@@ -202,33 +224,42 @@ ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):
ret: Any # includes the Tensors or any other returned object
jit_cache: list[ExecItem]
input_replace: dict[tuple[int, int], int]
extra_view_inputs: list[tuple[int, int, str, int, DType]]
linear: UOp
expected_names: list[int|str]
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
def __reduce__(self):
# TODO: free_intermediates here?
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
def __reduce__(self): return self.__class__, (self.ret, self.linear, self.expected_names, self.expected_input_info)
def __post_init__(self): self._jit_cache = None
@property
def jit_cache(self) -> list[ExecItem]: return self._jit_cache if self._jit_cache is not None else []
def __post_init__(self):
self._jit_cache: list[ExecItem] = self.jit_cache
self._input_replace: dict[tuple[int, int], int] = self.input_replace
self._first_run = True
self._needs_rebuild = False
# precompute read-after-write hazard detection
self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
def _init(self, input_uops:list[UOp]):
self._jit_cache, self._input_replace, self._extra_view_inputs = linear_to_jit_cache(self.linear, input_uops)
self._output_to_writer = {b: j for j, ei in enumerate(self._jit_cache) for b in get_out_buffers_for_ei(ei)}
self._input_to_max_reader: dict[int, int] = {}
for (j, i), idx in self.input_replace.items():
# only buffers that were different during capture but alias at jit time (e.g. feeding output back as input) need the copy.
if self.jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self.jit_cache[j]):
for (j, i), idx in self._input_replace.items():
if self._jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self._jit_cache[j]):
self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
self._clear_inputs()
def _clear_inputs(self):
  # drop references to the bound input buffers so they are not held between calls
  for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
def __call__(self, input_uops:list[UOp], var_vals:dict[str, int]) -> ReturnType:
  """Execute the captured jit: bind inputs, run the cached ExecItems, return the captured result."""
  # lazily lower the linear to ExecItems on the first call
  if self._jit_cache is None: self._init(input_uops)
  assert self._jit_cache is not None
  # derive input_buffers from input_uops (flatten MultiBuffer)
  input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None])
  # recreate view buffers from input bases
  for idx, offset, device, size, dtype in self._extra_view_inputs:
    input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
  # copy aliased inputs to prevent read-after-write hazard
  for i, ib in enumerate(input_buffers):
    if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
      input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
  # bind this call's input buffers into the cached ExecItems
  for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
  if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
  for ei in self._jit_cache: ei.run(var_vals, jit=True)
  # unbind inputs again so the buffers are not kept alive by the cache
  for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
  return self.ret
def free_intermediates(self):
depends: set[Buffer|None] = set([None])
update_depends(depends, self.jit_cache)
@@ -239,33 +270,6 @@ class CapturedJit(Generic[ReturnType]):
for a in arenas:
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
self.__post_init__()
self._needs_rebuild = True
# jit exec
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
# assign inputs
for idx, offset, device, size, dtype in self.extra_view_inputs:
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
# copy aliased inputs to prevent read-after-write hazard
for i, ib in enumerate(input_buffers):
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
# allocate intermediates if freed on first run
if self._first_run:
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
if self._needs_rebuild:
for ei in self.jit_cache:
if isinstance(ei.prg, GraphRunner): ei.prg = type(ei.prg)(None, None, ei.prg.jit_cache, ei.prg.input_replace)
self._first_run = self._needs_rebuild = False
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
for ei in self._jit_cache: ei.run(var_vals, jit=True)
self._clear_inputs()
return self.ret
def _prepare_jit_inputs(args, kwargs):
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
@@ -278,13 +282,14 @@ def _prepare_jit_inputs(args, kwargs):
input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
if any(u.base.op is Ops.CONST for u in input_uops):
raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=u.base.realized) is not None])
if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
# collect buffer UOps (including MultiBuffer)
input_buf_uops: list[UOp] = [u.base for u in input_uops if u.base.realized is not None]
if len(set(input_buf_uops)) != len(input_buf_uops): raise JitError("duplicate inputs to JIT")
inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops]
_var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
var_vals = {k.expr:v for k,v in _var_vals.items()}
expected_input_info = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in inputs]
return input_buffers, var_vals, names, expected_input_info
return input_buf_uops, var_vals, names, expected_input_info
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False):
@@ -307,14 +312,15 @@ class TinyJit(Generic[ReturnType]):
# keep legacy code working
@property
def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None and self.captured._jit_cache is not None else []
@property
def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}
def input_replace(self) -> dict[tuple[int, int], int]:
return self.captured._input_replace if self.captured is not None and self.captured._jit_cache is not None else {}
# descriptor protocol: bind the jitted callable to the instance so @TinyJit works on methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
def __call__(self, *args, **kwargs) -> ReturnType:
input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
input_buf_uops, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
if not JIT or self.cnt == 0:
# jit ignore
assert self.fxn is not None
@@ -333,46 +339,30 @@ class TinyJit(Generic[ReturnType]):
finally: capturing.clear()
if not len(self._linears): raise JitError("didn't JIT anything!")
_check_no_non_tensor_return(ret)
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buf_uops)} inputs")
# combine all captured linears into one, memory plan, and convert to ExecItems
big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears])))
del self._linears
if self.prune:
big_linear, onetime_linear = prune_linear(big_linear, {k for k,v in buffers.items() if isinstance(v, Buffer) and v in set(input_buffers)})
big_linear, onetime_linear = prune_linear(big_linear, set(input_buf_uops))
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
for ei in (si.lower() for si in linear_to_schedule(onetime_linear)):
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
ei.run(var_vals, jit=True)
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
jit_cache = jit_lower(big_linear, held_bufs, input_buffers)
# track inputs that are views of buffers
# TODO: eventually expected_buffers should live in ExecItem
extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
for item in jit_cache:
for b in item.bufs:
if b is not None and b._base is not None and b._base in input_buffers:
input_buffers.append(b)
extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
input_replace = get_input_replace(jit_cache, input_buffers)
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
# exec
for ei in jit_cache: ei.run(var_vals)
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
linear = jit_lower(big_linear, held_bufs, input_buf_uops)
self.captured = CapturedJit(ret, linear, names, expected_input_info)
ret = self.captured(input_buf_uops, var_vals)
elif self.cnt >= 2:
# jit exec
assert self.captured is not None
if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}")
if self.captured.expected_input_info != expected_input_info:
raise JitError(f"args mismatch in JIT: {self.captured.expected_input_info=} != {expected_input_info=}")
ret = self.captured(input_buffers, var_vals)
ret = self.captured(input_buf_uops, var_vals)
self.cnt += 1
return ret

View File

@@ -6,7 +6,7 @@ from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEn
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Ops, Variable
from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, _unwrap_beam
from tinygrad.runtime.ops_rdma import RDMACopyQueue
class HCQGraph(MultiGraphRunner):
@@ -326,4 +326,4 @@ class HCQGraph(MultiGraphRunner):
# MOCKGPU is not supported, since it can't execute commands in parallel
is_xfer = len(set(type(d) for d in all_devs)) == 1 and hasattr(alc:=all_devs[0].allocator, '_transfer') and alc.supports_transfer
return is_xfer or (all_devs[0].hw_copy_queue_t is not None and not getenv("MOCKGPU"))
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM)
return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM)

View File

@@ -92,7 +92,7 @@ _tensor_spec = PatternMatcher([
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
# BUFFER_VIEW on BUFFER is allowed if BUFFER is
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True),
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.PARAM)),)), lambda: True),
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
(UPat((Ops.CALL, Ops.FUNCTION), src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),