diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index b95f9d66a0..845c6265ca 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -1,10 +1,10 @@ from typing import TypeVar, Generic, Callable, cast, Any import functools, collections from tinygrad.tensor import Tensor -from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize, VIZ +from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ from tinygrad.device import Buffer, Compiled, Device, MultiBuffer from tinygrad.dtype import DType, dtypes -from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite +from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs from tinygrad.schedule import linear_to_schedule @@ -22,14 +22,13 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]: else: onetime.append(si) return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime)) -def create_graph_call(batch:list[UOp], input_buffers:set[Buffer]) -> UOp: - def bufs_for(b): return b.buffer.bufs if isinstance(b.buffer, MultiBuffer) else [b.buffer] - - input_list = dedup(b for si in batch for b in si.src[1:] if b.op in (Ops.BUFFER, Ops.BUFFER_VIEW) and not input_buffers.isdisjoint(bufs_for(b))) +def create_graph_call(batch:list[UOp]) -> UOp: + # all external inputs are PARAMs + input_list = dedup(b for si in batch for b in si.src[1:] if b.op is Ops.PARAM) cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph") return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata)) -def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:int=0) -> UOp: +def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp: new_src: list[UOp] = [] current_batch: list[UOp] = [] current_batch_devs: list[Compiled] = [] @@ -38,7 +37,7 @@ def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:in nonlocal current_batch, current_batch_devs, max_batch_size, new_src if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): new_src.extend(current_batch) else: - new_src.append(create_graph_call(current_batch, input_buffers)) + new_src.append(create_graph_call(current_batch)) max_batch_size *= 2 if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels") current_batch, current_batch_devs = [], [] @@ -66,13 +65,26 @@ def jit_cache_bufs(jit_cache:list[ExecItem]): if b is not None: yield b if isinstance(ei.prg, GraphRunner): yield from jit_cache_bufs(ei.prg.jit_cache) -@track_rewrites(lambda linear,held_bufs,input_buffers=None,ret=(): f"JIT {pluralize('call', len(linear.src))}") -def jit_lower(linear:UOp, held_bufs:set[UOp], input_buffers:list[Buffer]|None=None) -> list[ExecItem]: +def _unwrap_beam(ast:UOp) -> UOp: return ast.src[0] if ast.op is Ops.BEAM else ast + +pm_beam = PatternMatcher([ + (UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True), + lambda ctx,call,sink: call.replace(src=(UOp(Ops.BEAM, src=(sink,), arg=ctx), *call.src[1:])))]) + +@track_rewrites(lambda linear,held_bufs,input_uops,ret=(): f"JIT {pluralize('call', len(linear.src))}") +def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp: if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear") + + # parametrize input buffers: map each input buffer UOp to a PARAM with the correct slot index + linear = linear.substitute({u: UOp.param(i, u.dtype, u.shape, u.device) for i,u in enumerate(input_uops)}, walk=True) + + # wrap SINKs with BEAM if jitbeam is set + if (jitbeam:=getenv("JITBEAM", BEAM.value)) >= 1: linear = graph_rewrite(linear, pm_beam, ctx=jitbeam, walk=True) + linear = memory_plan_rewrite(linear, held_bufs) - if JIT < 2: linear = graph_split_rewrite(linear, set(input_buffers or []), max_batch_size=JIT_BATCH_SIZE.value) + if JIT < 2: linear = graph_split_rewrite(linear, max_batch_size=JIT_BATCH_SIZE.value) if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear") - return [ei.lower() for ei in linear_to_schedule(linear)] + return linear class GraphException(Exception): pass class JitError(Exception): pass @@ -86,26 +98,37 @@ def _check_no_non_tensor_return(ret): def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph -def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer], - orig_valid_positions: dict[int, set[int]]|None = None) -> dict[tuple[int, int], int]: +def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]: input_replace: dict[tuple[int, int], int] = {} for j,ji in enumerate(jit_cache): for i,a in enumerate(ji.bufs): - if a in input_buffers: - # filter out positions that weren't valid inputs in the original capture (prevents aliasing bugs) - if orig_valid_positions is not None and i not in orig_valid_positions.get(id(ji), set()): continue - input_replace[(j,i)] = input_buffers.index(a) + if a in input_buffers: input_replace[(j,i)] = input_buffers.index(a) return input_replace +pm_params = PatternMatcher([(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="p"), lambda ctx,p: ctx[p.arg])]) + +def linear_to_jit_cache(linear:UOp, input_uops:list[UOp]) -> tuple[list[ExecItem], dict[tuple[int,int],int], list[tuple[int,int,str,int,DType]]]: + # substitute PARAMs with input buffer UOps before lowering + linear = graph_rewrite(linear, pm_params, ctx=input_uops, walk=True, enter_calls=True) + # convert to jit_cache + jit_cache = [ei.lower() for ei in linear_to_schedule(linear)] + for b in jit_cache_bufs(jit_cache): b.ensure_allocated() + # derive input_buffers from input_uops + input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None]) + # track view buffers whose base is an input buffer + extra_view_inputs: list[tuple[int, int, str, int, DType]] = [] + for ei in jit_cache: + for b in ei.bufs: + if b is not None and b._base is not None and b._base in input_buffers and b not in input_buffers: + extra_view_inputs.append((input_buffers.index(b._base), b.offset, b.device, b.size, b.dtype)) + input_buffers.append(b) + return jit_cache, get_input_replace(jit_cache, input_buffers), extra_view_inputs + class GraphRunner(Runner): - def __init__(self, linear:UOp|None, input_buffers:list[Buffer]|None, - jit_cache:list[ExecItem]|None=None, input_replace:dict[tuple[int,int],int]|None=None): - # TODO: captured jit as linear? - if linear is not None: - jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])] - for b in jit_cache_bufs(jit_cache): b.ensure_allocated() - input_replace = get_input_replace(jit_cache, input_buffers) if input_buffers else {} - self.jit_cache, self.input_replace = unwrap(jit_cache), input_replace or {} + def __init__(self, linear:UOp, input_buffers:list[Buffer]): + self.jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])] + for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated() + self.input_replace = get_input_replace(self.jit_cache, input_buffers) if input_buffers else {} self.var_vals_replace:dict[int, list[tuple[int, int]]] = {} self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {} @@ -141,8 +164,6 @@ class GraphRunner(Runner): assert self.jit_cache[0].prg is not None super().__init__(colored(f"", "cyan"), self.jit_cache[0].prg.device.split(":")[0], estimates.simplify()) - def __reduce__(self): return self.__class__, (None, None, self.jit_cache, self.input_replace) - def updated_vars(self, var_vals: dict[str, int]): vals = [var_vals[v] for v in self.vars] for j, vidxs in self.var_vals_replace.items(): @@ -179,14 +200,15 @@ class GraphRunner(Runner): @staticmethod def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: - return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1 + return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1 # a marker for your graph supporting multiple devices of the same type class MultiGraphRunner(GraphRunner): @staticmethod def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: # Devices must be the same type - return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1 + return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) \ + and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1 def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]: if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins] @@ -202,33 +224,42 @@ ReturnType = TypeVar('ReturnType') @dataclass class CapturedJit(Generic[ReturnType]): ret: Any # includes the Tensors or any other returned object - jit_cache: list[ExecItem] - input_replace: dict[tuple[int, int], int] - extra_view_inputs: list[tuple[int, int, str, int, DType]] + linear: UOp expected_names: list[int|str] expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input - def __reduce__(self): - # TODO: free_intermediates here? - return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info) + def __reduce__(self): return self.__class__, (self.ret, self.linear, self.expected_names, self.expected_input_info) + def __post_init__(self): self._jit_cache = None + @property + def jit_cache(self) -> list[ExecItem]: return self._jit_cache if self._jit_cache is not None else [] - def __post_init__(self): - self._jit_cache: list[ExecItem] = self.jit_cache - self._input_replace: dict[tuple[int, int], int] = self.input_replace - self._first_run = True - self._needs_rebuild = False - # precompute read-after-write hazard detection - self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)} + def _init(self, input_uops:list[UOp]): + self._jit_cache, self._input_replace, self._extra_view_inputs = linear_to_jit_cache(self.linear, input_uops) + self._output_to_writer = {b: j for j, ei in enumerate(self._jit_cache) for b in get_out_buffers_for_ei(ei)} self._input_to_max_reader: dict[int, int] = {} - for (j, i), idx in self.input_replace.items(): - # only buffers that were different during capture but alias at jit time (e.g. feeding output back as input) need the copy. - if self.jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self.jit_cache[j]): + for (j, i), idx in self._input_replace.items(): + if self._jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self._jit_cache[j]): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j) - self._clear_inputs() - - def _clear_inputs(self): for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None + def __call__(self, input_uops:list[UOp], var_vals:dict[str, int]) -> ReturnType: + if self._jit_cache is None: self._init(input_uops) + assert self._jit_cache is not None + # derive input_buffers from input_uops (flatten MultiBuffer) + input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None]) + # recreate view buffers from input bases + for idx, offset, device, size, dtype in self._extra_view_inputs: + input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated()) + # copy aliased inputs to prevent read-after-write hazard + for i, ib in enumerate(input_buffers): + if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer: + input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview()) + for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx] + if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels") + for ei in self._jit_cache: ei.run(var_vals, jit=True) + for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None + return self.ret + def free_intermediates(self): depends: set[Buffer|None] = set([None]) update_depends(depends, self.jit_cache) @@ -239,33 +270,6 @@ class CapturedJit(Generic[ReturnType]): for a in arenas: if a.allocated_views == 0 and a.is_allocated(): a.deallocate() self.__post_init__() - self._needs_rebuild = True - - # jit exec - def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType: - # assign inputs - for idx, offset, device, size, dtype in self.extra_view_inputs: - input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated()) - - # copy aliased inputs to prevent read-after-write hazard - for i, ib in enumerate(input_buffers): - if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer: - input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview()) - - for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx] - - # allocate intermediates if freed on first run - if self._first_run: - for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated() - if self._needs_rebuild: - for ei in self.jit_cache: - if isinstance(ei.prg, GraphRunner): ei.prg = type(ei.prg)(None, None, ei.prg.jit_cache, ei.prg.input_replace) - self._first_run = self._needs_rebuild = False - - if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels") - for ei in self._jit_cache: ei.run(var_vals, jit=True) - self._clear_inputs() - return self.ret def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] @@ -278,13 +282,14 @@ def _prepare_jit_inputs(args, kwargs): input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors]) if any(u.base.op is Ops.CONST for u in input_uops): raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()") - input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=u.base.realized) is not None]) - if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT") + # collect buffer UOps (including MultiBuffer) + input_buf_uops: list[UOp] = [u.base for u in input_uops if u.base.realized is not None] + if len(set(input_buf_uops)) != len(input_buf_uops): raise JitError("duplicate inputs to JIT") inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops] _var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))]) var_vals = {k.expr:v for k,v in _var_vals.items()} expected_input_info = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in inputs] - return input_buffers, var_vals, names, expected_input_info + return input_buf_uops, var_vals, names, expected_input_info class TinyJit(Generic[ReturnType]): def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False): @@ -307,14 +312,15 @@ class TinyJit(Generic[ReturnType]): # keep legacy code working @property - def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else [] + def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None and self.captured._jit_cache is not None else [] @property - def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {} + def input_replace(self) -> dict[tuple[int, int], int]: + return self.captured._input_replace if self.captured is not None and self.captured._jit_cache is not None else {} def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods def __call__(self, *args, **kwargs) -> ReturnType: - input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs) + input_buf_uops, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs) if not JIT or self.cnt == 0: # jit ignore assert self.fxn is not None @@ -333,46 +339,30 @@ class TinyJit(Generic[ReturnType]): finally: capturing.clear() if not len(self._linears): raise JitError("didn't JIT anything!") _check_no_non_tensor_return(ret) - if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs") + if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buf_uops)} inputs") # combine all captured linears into one, memory plan, and convert to ExecItems big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears]))) del self._linears if self.prune: - big_linear, onetime_linear = prune_linear(big_linear, {k for k,v in buffers.items() if isinstance(v, Buffer) and v in set(input_buffers)}) + big_linear, onetime_linear = prune_linear(big_linear, set(input_buf_uops)) if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels") for ei in (si.lower() for si in linear_to_schedule(onetime_linear)): for b in ei.bufs: cast(Buffer, b).ensure_allocated() ei.run(var_vals, jit=True) held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER} - with Context(BEAM=getenv("JITBEAM", BEAM.value)): - jit_cache = jit_lower(big_linear, held_bufs, input_buffers) - - # track inputs that are views of buffers - # TODO: eventually expected_buffers should live in ExecItem - extra_view_inputs: list[tuple[int, int, str, int, DType]] = [] - for item in jit_cache: - for b in item.bufs: - if b is not None and b._base is not None and b._base in input_buffers: - input_buffers.append(b) - extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype)) - - input_replace = get_input_replace(jit_cache, input_buffers) - if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found") - - # exec - for ei in jit_cache: ei.run(var_vals) - - self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info) + linear = jit_lower(big_linear, held_bufs, input_buf_uops) + self.captured = CapturedJit(ret, linear, names, expected_input_info) + ret = self.captured(input_buf_uops, var_vals) elif self.cnt >= 2: # jit exec assert self.captured is not None if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}") if self.captured.expected_input_info != expected_input_info: raise JitError(f"args mismatch in JIT: {self.captured.expected_input_info=} != {expected_input_info=}") - ret = self.captured(input_buffers, var_vals) + ret = self.captured(input_buf_uops, var_vals) self.cnt += 1 return ret diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 687d8735ca..710676f1bc 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -6,7 +6,7 @@ from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEn from tinygrad.dtype import dtypes from tinygrad.uop.ops import UOp, Ops, Variable from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy -from tinygrad.engine.jit import GraphRunner, MultiGraphRunner +from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, _unwrap_beam from tinygrad.runtime.ops_rdma import RDMACopyQueue class HCQGraph(MultiGraphRunner): @@ -326,4 +326,4 @@ class HCQGraph(MultiGraphRunner): # MOCKGPU is not supported, since it can't execute commands in parallel is_xfer = len(set(type(d) for d in all_devs)) == 1 and hasattr(alc:=all_devs[0].allocator, '_transfer') and alc.supports_transfer return is_xfer or (all_devs[0].hw_copy_queue_t is not None and not getenv("MOCKGPU")) - return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) + return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM) diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index a29c97efc1..6f871e09b1 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -92,7 +92,7 @@ _tensor_spec = PatternMatcher([ lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)), # BUFFER_VIEW on BUFFER is allowed if BUFFER is - (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True), + (UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.PARAM)),)), lambda: True), # KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER (UPat((Ops.CALL, Ops.FUNCTION), src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),