From b8d3bf8970fd400786da9635fe26f974399e933d Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 20 Apr 2026 23:03:30 +0300 Subject: [PATCH] run_linear in jit (#15827) * run_linear in jit * x * x * f * casts * ugh * f * x * x * simple --- .github/workflows/benchmark.yml | 4 +- tinygrad/engine/jit.py | 133 +++++++++++--------------------- tinygrad/engine/realize.py | 54 ++++++++----- 3 files changed, 85 insertions(+), 106 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 0b1849dcbb..9277b16886 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -626,7 +626,7 @@ jobs: - name: IR3 openpilot compile3 0.11.0 driving_vision run: BENCHMARK_LOG=ir3_openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM:IR3 FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.11.0 driving_policy - run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx + run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx - name: openpilot compile3 0.11.0 dmonitoring run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx - name: DEBUG=2 openpilot compile3 0.10.1 driving_vision @@ -634,7 +634,7 @@ jobs: - name: openpilot compile3 0.10.1 driving_vision run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.10.1 driving_policy - run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx + run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx - name: openpilot compile3 0.10.1 dmonitoring run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx - name: benchmark MobileNetV2 on DSP diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 0674eff8ce..cbe4811141 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -1,11 +1,11 @@ -from typing import TypeVar, Generic, Callable, cast, Any +from typing import TypeVar, Generic, Callable, Any import functools, collections from tinygrad.tensor import Tensor from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ from tinygrad.device import Buffer, Compiled, Device, MultiBuffer from tinygrad.dtype import DType, dtypes -from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite -from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates, pm_beam +from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite +from tinygrad.engine.realize import ExecItem, capturing, CompiledRunner, Runner, Estimates, pm_beam, run_linear, get_runner, graph_cache from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs from tinygrad.schedule import linear_to_schedule from tinygrad.nn.state import get_parameters @@ -24,7 +24,7 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]: def create_graph_call(batch:list[UOp]) -> UOp: # all external inputs are PARAMs - input_list = dedup(b for si in batch for b in si.src[1:] if b.op is Ops.PARAM) + input_list = dedup(u for si in batch for b in si.src[1:] for u in b.toposort() if u.op is Ops.PARAM) cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph") return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata)) @@ -59,14 +59,22 @@ def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp: if current_batch: flush_batch() return linear.replace(src=tuple(new_src)) -def jit_cache_bufs(jit_cache:list[ExecItem]): - for ei in jit_cache: - for b in ei.bufs: - if b is not None: yield b - if isinstance(ei.prg, GraphRunner): yield from jit_cache_bufs(ei.prg.jit_cache) - def _unwrap_beam(ast:UOp) -> UOp: return ast.src[0] if ast.op is Ops.BEAM else ast +def _call_outs_ins(call:UOp) -> tuple[set[int], set[int]]: + non_bind = [s for s in call.src[1:] if s.op is not Ops.BIND] + ast = _unwrap_beam(call.src[0]) + if ast.op in (Ops.SINK, Ops.PROGRAM): + prg = get_runner(non_bind[0].device if isinstance(non_bind[0].device, str) else non_bind[0].device[0], call.src[0]) + return set(prg.p.outs), set(prg.p.ins) + if ast.op in (Ops.COPY, Ops.BUFFER_VIEW): return {0}, {1} + if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return {0}, set(range(1, len(non_bind))) + return set(), set() + +def _copy_input(u:UOp) -> UOp: + run_linear(UOp(Ops.LINEAR, src=(u.copy_to_device(u.device).call(new:=UOp.new_buffer(u.device, u.arg, u.dtype), u, metadata=()),))) + return new + @track_rewrites(lambda linear,held_bufs,input_uops,ret=(): f"JIT {pluralize('call', len(linear.src))}") def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp: if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear") @@ -101,29 +109,12 @@ def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> if a in input_buffers: input_replace[(j,i)] = input_buffers.index(a) return input_replace -pm_params = PatternMatcher([(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="p"), lambda ctx,p: ctx[p.arg])]) - -def linear_to_jit_cache(linear:UOp, input_uops:list[UOp]) -> tuple[list[ExecItem], dict[tuple[int,int],int], list[tuple[int,int,str,int,DType]]]: - # substitute PARAMs with input buffer UOps before lowering - linear = graph_rewrite(linear, pm_params, ctx=input_uops, walk=True, enter_calls=True) - # convert to jit_cache - jit_cache = [ei.lower() for ei in linear_to_schedule(linear)] - for b in jit_cache_bufs(jit_cache): b.ensure_allocated() - # derive input_buffers from input_uops - input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None]) - # track view buffers whose base is an input buffer - extra_view_inputs: list[tuple[int, int, str, int, DType]] = [] - for ei in jit_cache: - for b in ei.bufs: - if b is not None and b._base is not None and b._base in input_buffers and b not in input_buffers: - extra_view_inputs.append((input_buffers.index(b._base), b.offset, b.device, b.size, b.dtype)) - input_buffers.append(b) - return jit_cache, get_input_replace(jit_cache, input_buffers), extra_view_inputs - class GraphRunner(Runner): def __init__(self, linear:UOp, input_buffers:list[Buffer]): self.jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])] - for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated() + for ei in self.jit_cache: + for b in ei.bufs: + if b is not None: b.ensure_allocated() self.input_replace = get_input_replace(self.jit_cache, input_buffers) if input_buffers else {} self.var_vals_replace:dict[int, list[tuple[int, int]]] = {} @@ -206,16 +197,6 @@ class MultiGraphRunner(GraphRunner): return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) \ and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1 -def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]: - if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins] - if isinstance(ei.prg, (BufferCopy, BufferXfer, EncDec)): return [cast(Buffer, ei.bufs[0])] - if isinstance(ei.prg, GraphRunner): return dedup([b for inner in ei.prg.jit_cache for b in get_out_buffers_for_ei(inner)]) - return [] - -def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]): - for ei in jit_cache: - if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei)) - ReturnType = TypeVar('ReturnType') @dataclass class CapturedJit(Generic[ReturnType]): @@ -225,47 +206,36 @@ class CapturedJit(Generic[ReturnType]): expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input def __reduce__(self): return self.__class__, (self.ret, self.linear, self.expected_names, self.expected_input_info) - def __post_init__(self): self._jit_cache = None - @property - def jit_cache(self) -> list[ExecItem]: return self._jit_cache if self._jit_cache is not None else [] - def _init(self, input_uops:list[UOp]): - self._jit_cache, self._input_replace, self._extra_view_inputs = linear_to_jit_cache(self.linear, input_uops) - self._output_to_writer = {b: j for j, ei in enumerate(self._jit_cache) for b in get_out_buffers_for_ei(ei)} - self._input_to_max_reader: dict[int, int] = {} - for (j, i), idx in self._input_replace.items(): - if self._jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self._jit_cache[j]): - self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j) - for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None + @functools.cached_property + def _written_uops(self) -> set[UOp]: + out: set[UOp] = set() + for call in self.linear.toposort(): + if call.op is not Ops.CALL: continue + non_bind = [s for s in call.src[1:] if s.op is not Ops.BIND] + outs, ins = _call_outs_ins(call) + out |= {non_bind[k] for k in outs - ins if non_bind[k].op in (Ops.BUFFER, Ops.BUFFER_VIEW)} + return out def __call__(self, input_uops:list[UOp], var_vals:dict[str, int]) -> ReturnType: - if self._jit_cache is None: self._init(input_uops) - assert self._jit_cache is not None - # derive input_buffers from input_uops (flatten MultiBuffer) - input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None]) - # recreate view buffers from input bases - for idx, offset, device, size, dtype in self._extra_view_inputs: - input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated()) - # copy aliased inputs to prevent read-after-write hazard - for i, ib in enumerate(input_buffers): - if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer: - input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview()) - for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx] - if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels") - for ei in self._jit_cache: ei.run(var_vals, jit=True) - for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None + concrete = tuple(_copy_input(u) if u in self._written_uops else u for u in input_uops) + if DEBUG >= 1 and len(self.linear.src) >= 10: print(f"jit execs {len(self.linear.src)} calls") + run_linear(self.linear, var_vals, input_uops=concrete, jit=True) return self.ret def free_intermediates(self): - depends: set[Buffer|None] = set([None]) - update_depends(depends, self.jit_cache) - arenas = {b._base for b in depends if b is not None and b._base is not None} - to_free = {b for b in depends if b is not None} | {b for b in jit_cache_bufs(self.jit_cache) if b._base in arenas} - for b in to_free: - if hasattr(b, '_buf'): b.deallocate() - for a in arenas: - if a.allocated_views == 0 and a.is_allocated(): a.deallocate() - self.__post_init__() + # drop graph runners + for call in self.linear.src: + if call.src[0].op is Ops.CUSTOM_FUNCTION and call.src[0].arg == "graph": graph_cache.pop(call.src[0], None) + bases: set[Buffer] = set() + for u in self._written_uops: + try: buf = u.buffer + except Exception: continue + for b in (buf.bufs if isinstance(buf, MultiBuffer) else [buf]): + if hasattr(b, '_buf'): b.deallocate() + if b._base is not None: bases.add(b._base) + for a in bases: + if a.is_allocated() and a.allocated_views == 0: a.deallocate() def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] @@ -306,13 +276,6 @@ class TinyJit(Generic[ReturnType]): assert self.captured is not None, "can't pickle an uncaptured JIT" return self.__class__, (None, self.captured) - # keep legacy code working - @property - def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None and self.captured._jit_cache is not None else [] - @property - def input_replace(self) -> dict[tuple[int, int], int]: - return self.captured._input_replace if self.captured is not None and self.captured._jit_cache is not None else {} - def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods def __call__(self, *args, **kwargs) -> ReturnType: @@ -337,16 +300,14 @@ class TinyJit(Generic[ReturnType]): _check_no_non_tensor_return(ret) if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buf_uops)} inputs") - # combine all captured linears into one, memory plan, and convert to ExecItems + # combine all captured linears into one, memory plan, and graph split big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears]))) del self._linears if self.prune: big_linear, onetime_linear = prune_linear(big_linear, set(input_buf_uops)) if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels") - for ei in (si.lower() for si in linear_to_schedule(onetime_linear)): - for b in ei.bufs: cast(Buffer, b).ensure_allocated() - ei.run(var_vals, jit=True) + run_linear(onetime_linear, var_vals) held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER} linear = jit_lower(big_linear, held_bufs, input_buf_uops) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index ea92fbd0ea..35382bb873 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,9 +1,9 @@ from typing import cast, Callable, Iterator -import time, pprint, random, itertools, math, contextlib +import time, pprint, random, itertools, math, contextlib, weakref from dataclasses import dataclass, replace, field from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey from tinygrad.helpers import BEAM, DEVECTORIZE, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events -from tinygrad.helpers import prod, unwrap, EMULATED_DTYPES +from tinygrad.helpers import prod, unwrap, EMULATED_DTYPES, flatten from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite from tinygrad.device import Device, Buffer, MultiBuffer from tinygrad.renderer import ProgramSpec, Estimates @@ -222,48 +222,55 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_ @dataclass class ExecContext: var_vals: dict[str, int] = field(default_factory=dict) + input_uops: tuple[UOp, ...] = () do_update_stats: bool = True + jit: bool = False + +def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp: + if b.op is Ops.BUFFER_VIEW and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg], *b.src[1:])) + return inputs[b.arg] if b.op is Ops.PARAM else b +def resolve_params(ctx:ExecContext, call:UOp) -> list[UOp]: return [_resolve(b, ctx.input_uops) for b in call.src[1:] if b.op is not Ops.BIND] @contextlib.contextmanager -def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimates:Estimates, bufs:list[Buffer], var_vals:dict[str, int], *, +def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimates:Estimates, bufs:list[Buffer], var_vals:dict[str, int], outputs=(0,), inputs=(1,), first_run=False): if PROFILE: cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), {"metadata": call.arg.metadata, "var_vals": var_vals, "bufs": [b.trace_num for b in bufs], "name": display_name, "outputs": outputs, "inputs": inputs})) timing: list[float|None] = [None] - st = time.perf_counter() + if DEBUG >= 2: st = time.perf_counter() yield timing if not ctx.do_update_stats: return - if timing[0] is None and DEBUG >= 2: + if DEBUG >= 2 and timing[0] is None: Device[device].synchronize() timing[0] = time.perf_counter() - st - update_stats(display_name, device, estimates, var_vals, timing[0], len(bufs), jit=False, metadata=call.arg.metadata, first_run=first_run) + update_stats(display_name, device, estimates, var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata, first_run=first_run) -def unwrap_multi(call:UOp) -> Iterator[tuple[list[Buffer], dict[str, int]]]: - bufs = [b.buffer for b in call.src[1:] if b.op is not Ops.BIND] +def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]: + bufs = [b.buffer for b in resolved] if not any(isinstance(b, MultiBuffer) for b in bufs): yield cast(list[Buffer], bufs), {} else: dnum = next((x.expr for x in call.src[0].variables() if x.expr == '_device_num'), None) for j, per_dev in enumerate(zip(*[cast(MultiBuffer, b).bufs for b in bufs])): yield list(per_dev), {dnum: j} if dnum else {} def exec_view(ctx:ExecContext, call, ast): - bufs = [b.buffer for b in call.src[1:] if b.op is not Ops.BIND] - bv = bufs[1].view(call.src[1].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize) + resolved = resolve_params(ctx, call) + bufs = [cast(Buffer, b.buffer) for b in resolved] + bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize) with track_stats(ctx, call, bv.device, colored(f"view {bv.nbytes:8d} @ {bv.offset:<10d}", "yellow"), Estimates(), [bv, bufs[1]], ctx.var_vals): - buffers[call.src[1]] = bv + buffers[resolved[0]] = bv def exec_copy(ctx:ExecContext, call, ast): - for bufs, device_vars in unwrap_multi(call): + for bufs, device_vars in unwrap_multi(call, resolve_params(ctx, call)): dest, src = bufs[0].ensure_allocated(), bufs[1].ensure_allocated() xfer = hasattr(alc:=Device[dest.device].allocator,'_transfer') and alc.supports_transfer and dest.device.split(":")[0]==src.device.split(":")[0] prg = (BufferXfer if xfer else BufferCopy)(dest.nbytes, dest.device, src.device) - name = colored(f"{'xfer' if xfer else 'copy'} {size_to_str(dest.nbytes):>8s}, {dest.device[:7]:>7s} <- {src.device[:7]:7s}", "yellow") - with track_stats(ctx, call, dest.device, name, Estimates(lds=dest.nbytes, mem=dest.nbytes), [dest, src], {**ctx.var_vals, **device_vars}): + with track_stats(ctx, call, dest.device, prg.display_name, Estimates(lds=dest.nbytes, mem=dest.nbytes), [dest, src], ctx.var_vals): prg.copy(dest, src) def exec_kernel(ctx:ExecContext, call, ast): sink = ast.src[0] if ast.op is Ops.BEAM else ast - for bufs, device_vars in unwrap_multi(call): + for bufs, device_vars in unwrap_multi(call, resolve_params(ctx, call)): var_vals = {**ctx.var_vals, **device_vars} prg = get_runner(bufs[0].device, ast) prg_bufs = [bufs[i].ensure_allocated() for i in prg.p.globals] @@ -283,12 +290,22 @@ def exec_kernel(ctx:ExecContext, call, ast): for i in prg.p.outs: np.testing.assert_allclose(prg_bufs[i].numpy(), cpu_bufs[i].numpy(), rtol=1e-3, atol=1e-3) def exec_encdec(ctx:ExecContext, call, ast): - bufs = [b.buffer.ensure_allocated() for b in call.src[1:] if b.op is not Ops.BIND] + bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(ctx, call)] shape, pos_var = tuple(s.arg for s in ast.src if s.op is Ops.CONST), ast.variables()[0].expr with track_stats(ctx, call, bufs[0].device, colored(f"enc/dec {size_to_str(bufs[0].nbytes)}", "yellow"), Estimates(lds=bufs[0].nbytes, mem=bufs[0].nbytes), bufs, ctx.var_vals): bufs[0].allocator._encode_decode(bufs[0]._buf, bufs[1]._buf, bufs[2]._buf, [x._buf for x in bufs[3:]], shape, ctx.var_vals[pos_var]) +graph_cache:weakref.WeakKeyDictionary[UOp, Runner] = weakref.WeakKeyDictionary() +def exec_graph(ctx:ExecContext, call, cf): + inputs = resolve_params(ctx, call) + bufs = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in (u.buffer for u in inputs)]) + if (runner:=graph_cache.get(cf)) is None: + sub = cf.substitute(dict(zip(cf.src[1:], inputs))) + graph_cache[cf] = runner = Device[cf.device if isinstance(cf.device, str) else cf.device[0]].graph(sub, bufs) + with track_stats(ctx, call, runner.device, runner.display_name, runner.estimates, bufs, ctx.var_vals) as t: + t[0] = runner(bufs, ctx.var_vals, wait=DEBUG >= 2) + pm_beam = PatternMatcher([ (UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True), lambda ctx,call,sink: call.replace(src=(UOp(Ops.BEAM, src=(sink,), arg=ctx), *call.src[1:]))), @@ -299,9 +316,10 @@ pm_exec = PatternMatcher([ (UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy), (UPat(Ops.CALL, src=(UPat((Ops.SINK, Ops.PROGRAM, Ops.BEAM), name="ast"),), name="call", allow_any_len=True), exec_kernel), (UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="ast"),), name="call", allow_any_len=True), exec_encdec), + (UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"),), name="call", allow_any_len=True), exec_graph), ]) -def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, do_update_stats=True): +def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), do_update_stats=True, jit=False): if BEAM >= 1: linear = graph_rewrite(linear, pm_beam, ctx=BEAM.value, name="add beam") - ctx = ExecContext(var_vals or {}, do_update_stats) + ctx = ExecContext(var_vals or {}, input_uops, do_update_stats, jit) for call in linear.src: pm_exec.rewrite(call, ctx)