run_linear in jit (#15827)

* run_linear in jit

* x

* x

* f

* casts

* ugh

* f

* x

* x

* simple
This commit is contained in:
nimlgen
2026-04-20 23:03:30 +03:00
committed by GitHub
parent e00cc8ae5e
commit b8d3bf8970
3 changed files with 85 additions and 106 deletions

View File

@@ -626,7 +626,7 @@ jobs:
- name: IR3 openpilot compile3 0.11.0 driving_vision
run: BENCHMARK_LOG=ir3_openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM:IR3 FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.11.0 driving_policy
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.11.0 dmonitoring
run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx
- name: DEBUG=2 openpilot compile3 0.10.1 driving_vision
@@ -634,7 +634,7 @@ jobs:
- name: openpilot compile3 0.10.1 driving_vision
run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
- name: openpilot compile3 0.10.1 driving_policy
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
- name: openpilot compile3 0.10.1 dmonitoring
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
- name: benchmark MobileNetV2 on DSP

View File

@@ -1,11 +1,11 @@
from typing import TypeVar, Generic, Callable, cast, Any
from typing import TypeVar, Generic, Callable, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType, dtypes
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates, pm_beam
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.engine.realize import ExecItem, capturing, CompiledRunner, Runner, Estimates, pm_beam, run_linear, get_runner, graph_cache
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
from tinygrad.schedule import linear_to_schedule
from tinygrad.nn.state import get_parameters
@@ -24,7 +24,7 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
def create_graph_call(batch:list[UOp]) -> UOp:
  """Wrap a batch of calls into one call of a "graph" CUSTOM_FUNCTION.

  The CUSTOM_FUNCTION's sources are a LINEAR holding the batch followed by the external inputs;
  the returned call passes the same inputs as arguments and carries the metadata of every call in the batch.
  """
  # all external inputs are PARAMs: collect every PARAM found by walking the toposort of each call argument.
  # (a previous assignment that only looked at direct PARAM arguments was immediately overwritten and is dropped)
  input_list = dedup(u for si in batch for b in si.src[1:] for u in b.toposort() if u.op is Ops.PARAM)
  cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph")
  return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata))
@@ -59,14 +59,22 @@ def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
if current_batch: flush_batch()
return linear.replace(src=tuple(new_src))
def jit_cache_bufs(jit_cache:list[ExecItem]):
  """Yield every non-None buffer referenced by `jit_cache`.

  After an item's own buffers, the buffers of a GraphRunner program's nested `jit_cache` are yielded recursively.
  """
  for item in jit_cache:
    yield from (buf for buf in item.bufs if buf is not None)
    if not isinstance(item.prg, GraphRunner): continue
    yield from jit_cache_bufs(item.prg.jit_cache)
def _unwrap_beam(ast:UOp) -> UOp:
  """Return the first source of a BEAM node, or `ast` unchanged when it is not a BEAM."""
  if ast.op is Ops.BEAM: return ast.src[0]
  return ast
def _call_outs_ins(call:UOp) -> tuple[set[int], set[int]]:
  """Return `(outs, ins)`: positions among the call's non-BIND arguments treated as outputs and as inputs.

  - SINK / PROGRAM (optionally wrapped in BEAM): taken from the runner's `p.outs` and `p.ins`
  - COPY / BUFFER_VIEW: argument 0 is the output, argument 1 the input
  - "encdec" CUSTOM_FUNCTION: argument 0 is the output, every other argument is an input
  - anything else: two empty sets
  """
  args = [s for s in call.src[1:] if s.op is not Ops.BIND]
  fxn = _unwrap_beam(call.src[0])
  if fxn.op in (Ops.SINK, Ops.PROGRAM):
    dev = args[0].device
    # the runner is looked up on the first device when the argument's device is not a plain string
    runner = get_runner(dev if isinstance(dev, str) else dev[0], call.src[0])
    return set(runner.p.outs), set(runner.p.ins)
  if fxn.op in (Ops.COPY, Ops.BUFFER_VIEW):
    return {0}, {1}
  if fxn.op is Ops.CUSTOM_FUNCTION and fxn.arg == "encdec":
    return {0}, set(range(1, len(args)))
  return set(), set()
def _copy_input(u:UOp) -> UOp:
  """Run a single-call LINEAR built from `u.copy_to_device(u.device)` targeting a new buffer UOp, and return that UOp.

  The new buffer UOp is created with `u`'s device, arg and dtype.
  """
  copy_op = u.copy_to_device(u.device)
  fresh = UOp.new_buffer(u.device, u.arg, u.dtype)
  run_linear(UOp(Ops.LINEAR, src=(copy_op.call(fresh, u, metadata=()),)))
  return fresh
@track_rewrites(lambda linear,held_bufs,input_uops,ret=(): f"JIT {pluralize('call', len(linear.src))}")
def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp:
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear")
@@ -101,29 +109,12 @@ def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) ->
if a in input_buffers: input_replace[(j,i)] = input_buffers.index(a)
return input_replace
# rewrite rule used by linear_to_jit_cache: a PARAM whose second source is a DEVICE is replaced by ctx[p.arg],
# where ctx is the list of input UOps handed to graph_rewrite
pm_params = PatternMatcher([(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="p"), lambda ctx,p: ctx[p.arg])])
def linear_to_jit_cache(linear:UOp, input_uops:list[UOp]) -> tuple[list[ExecItem], dict[tuple[int,int],int], list[tuple[int,int,str,int,DType]]]:
  """Lower `linear` into a list of ExecItems and work out which of their buffers are inputs.

  Returns `(jit_cache, input_replace, extra_view_inputs)`:
  - `jit_cache`: the lowered schedule items, with all their buffers allocated
  - `input_replace`: the result of `get_input_replace` over the input buffers, including the views found below
  - `extra_view_inputs`: one `(base input index, offset, device, size, dtype)` tuple per view buffer whose base is an input
  """
  # PARAMs are swapped for the concrete input UOps before anything is lowered
  resolved = graph_rewrite(linear, pm_params, ctx=input_uops, walk=True, enter_calls=True)

  jit_cache = [si.lower() for si in linear_to_schedule(resolved)]
  for buf in jit_cache_bufs(jit_cache): buf.ensure_allocated()

  # one entry per Buffer behind each input UOp; a MultiBuffer contributes each of its bufs
  input_buffers: list[Buffer] = []
  for u in input_uops:
    if (root := buffers[u]) is None: continue
    input_buffers.extend(root.bufs if isinstance(root, MultiBuffer) else [root])

  # views over an input buffer are appended as extra inputs, remembered by the index of their base
  extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
  for item in jit_cache:
    for buf in item.bufs:
      if buf is None or buf._base is None: continue
      if buf._base not in input_buffers or buf in input_buffers: continue
      extra_view_inputs.append((input_buffers.index(buf._base), buf.offset, buf.device, buf.size, buf.dtype))
      input_buffers.append(buf)

  return jit_cache, get_input_replace(jit_cache, input_buffers), extra_view_inputs
class GraphRunner(Runner):
def __init__(self, linear:UOp, input_buffers:list[Buffer]):
self.jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
for ei in self.jit_cache:
for b in ei.bufs:
if b is not None: b.ensure_allocated()
self.input_replace = get_input_replace(self.jit_cache, input_buffers) if input_buffers else {}
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
@@ -206,16 +197,6 @@ class MultiGraphRunner(GraphRunner):
return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) \
and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
  """Return the buffers `ei` is considered to produce, based on the type of its program.

  - CompiledRunner: the buffers at `p.outs` positions that are not also in `p.ins`
  - BufferCopy / BufferXfer / EncDec: the first buffer
  - GraphRunner: the deduplicated outputs of every item in its nested `jit_cache`
  - anything else: an empty list
  """
  runner = ei.prg
  if isinstance(runner, CompiledRunner):
    spec = runner.p
    return [cast(Buffer, ei.bufs[idx]) for idx in spec.outs if idx not in spec.ins]
  if isinstance(runner, (BufferCopy, BufferXfer, EncDec)):
    return [cast(Buffer, ei.bufs[0])]
  if isinstance(runner, GraphRunner):
    nested: list[Buffer] = []
    for inner in runner.jit_cache: nested.extend(get_out_buffers_for_ei(inner))
    return dedup(nested)
  return []
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
  """Grow `depends` in place: whenever an item uses a buffer already in the set, add that item's output buffers.

  Items are visited once, in `jit_cache` order.
  """
  for item in jit_cache:
    if depends.isdisjoint(item.bufs): continue
    depends.update(get_out_buffers_for_ei(item))
ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):
@@ -225,47 +206,36 @@ class CapturedJit(Generic[ReturnType]):
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
def __reduce__(self):
  """Pickle as the class plus the four constructor fields (ret, linear, expected_names, expected_input_info)."""
  cls = self.__class__
  ctor_args = (self.ret, self.linear, self.expected_names, self.expected_input_info)
  return cls, ctor_args
def __post_init__(self):
  """Start without a lowered cache; `_jit_cache` stays None until `_init` assigns it."""
  self._jit_cache = None
@property
def jit_cache(self) -> list[ExecItem]:
  """The lowered ExecItems, or an empty list while `_jit_cache` is still None."""
  cache = self._jit_cache
  if cache is None: return []
  return cache
def _init(self, input_uops:list[UOp]):
  """Lower `self.linear` for `input_uops` and precompute the input bookkeeping.

  Sets `_jit_cache`, `_input_replace`, `_extra_view_inputs`, `_output_to_writer` (output buffer -> index of the
  last item producing it) and `_input_to_max_reader` (input index -> highest item index using it in a slot that is
  not one of that item's outputs), then clears every input slot in the cache to None.
  """
  cache, replace, views = linear_to_jit_cache(self.linear, input_uops)
  self._jit_cache, self._input_replace, self._extra_view_inputs = cache, replace, views

  # later items overwrite earlier ones, so each buffer maps to its last writer
  writers = {}
  for pos, item in enumerate(cache):
    for buf in get_out_buffers_for_ei(item): writers[buf] = pos
  self._output_to_writer = writers

  last_reader: dict[int, int] = {}
  for (j, i), idx in replace.items():
    item = cache[j]
    # a slot holding one of the item's own outputs does not count as a read
    if item.bufs[i] in get_out_buffers_for_ei(item): continue
    last_reader[idx] = max(last_reader.get(idx, -1), j)
  self._input_to_max_reader = last_reader

  # input slots are emptied here
  for j, i in replace: cache[j].bufs[i] = None
@functools.cached_property
def _written_uops(self) -> set[UOp]:
  """BUFFER / BUFFER_VIEW arguments that some CALL in `self.linear` has as an output but not as an input.

  Output and input positions come from `_call_outs_ins` and index the call's non-BIND arguments.
  The result is computed once and cached on the instance.
  """
  written: set[UOp] = set()
  for node in self.linear.toposort():
    if node.op is Ops.CALL:
      args = [s for s in node.src[1:] if s.op is not Ops.BIND]
      outs, ins = _call_outs_ins(node)
      for k in outs - ins:
        if args[k].op in (Ops.BUFFER, Ops.BUFFER_VIEW): written.add(args[k])
  return written
def __call__(self, input_uops:list[UOp], var_vals:dict[str, int]) -> ReturnType:
if self._jit_cache is None: self._init(input_uops)
assert self._jit_cache is not None
# derive input_buffers from input_uops (flatten MultiBuffer)
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None])
# recreate view buffers from input bases
for idx, offset, device, size, dtype in self._extra_view_inputs:
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
# copy aliased inputs to prevent read-after-write hazard
for i, ib in enumerate(input_buffers):
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
for ei in self._jit_cache: ei.run(var_vals, jit=True)
for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
concrete = tuple(_copy_input(u) if u in self._written_uops else u for u in input_uops)
if DEBUG >= 1 and len(self.linear.src) >= 10: print(f"jit execs {len(self.linear.src)} calls")
run_linear(self.linear, var_vals, input_uops=concrete, jit=True)
return self.ret
def free_intermediates(self):
depends: set[Buffer|None] = set([None])
update_depends(depends, self.jit_cache)
arenas = {b._base for b in depends if b is not None and b._base is not None}
to_free = {b for b in depends if b is not None} | {b for b in jit_cache_bufs(self.jit_cache) if b._base in arenas}
for b in to_free:
if hasattr(b, '_buf'): b.deallocate()
for a in arenas:
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
self.__post_init__()
# drop graph runners
for call in self.linear.src:
if call.src[0].op is Ops.CUSTOM_FUNCTION and call.src[0].arg == "graph": graph_cache.pop(call.src[0], None)
bases: set[Buffer] = set()
for u in self._written_uops:
try: buf = u.buffer
except Exception: continue
for b in (buf.bufs if isinstance(buf, MultiBuffer) else [buf]):
if hasattr(b, '_buf'): b.deallocate()
if b._base is not None: bases.add(b._base)
for a in bases:
if a.is_allocated() and a.allocated_views == 0: a.deallocate()
def _prepare_jit_inputs(args, kwargs):
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
@@ -306,13 +276,6 @@ class TinyJit(Generic[ReturnType]):
assert self.captured is not None, "can't pickle an uncaptured JIT"
return self.__class__, (None, self.captured)
# keep legacy code working
@property
def jit_cache(self) -> list[ExecItem]:
  """The captured jit's `_jit_cache`, or an empty list when nothing is captured or it is still None."""
  captured = self.captured
  if captured is None or captured._jit_cache is None: return []
  return captured._jit_cache
@property
def input_replace(self) -> dict[tuple[int, int], int]:
  """The captured jit's `_input_replace`, or an empty dict when nothing is captured or its `_jit_cache` is None."""
  captured = self.captured
  if captured is None or captured._jit_cache is None: return {}
  return captured._input_replace
def __get__(self, obj, objtype):
  """Descriptor hook: return `self.__call__` with `obj` pre-bound as the first positional argument."""
  # add support for instance methods
  bound_call = functools.partial(self.__call__, obj)
  return bound_call
def __call__(self, *args, **kwargs) -> ReturnType:
@@ -337,16 +300,14 @@ class TinyJit(Generic[ReturnType]):
_check_no_non_tensor_return(ret)
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buf_uops)} inputs")
# combine all captured linears into one, memory plan, and convert to ExecItems
# combine all captured linears into one, memory plan, and graph split
big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears])))
del self._linears
if self.prune:
big_linear, onetime_linear = prune_linear(big_linear, set(input_buf_uops))
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
for ei in (si.lower() for si in linear_to_schedule(onetime_linear)):
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
ei.run(var_vals, jit=True)
run_linear(onetime_linear, var_vals)
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
linear = jit_lower(big_linear, held_bufs, input_buf_uops)

View File

@@ -1,9 +1,9 @@
from typing import cast, Callable, Iterator
import time, pprint, random, itertools, math, contextlib
import time, pprint, random, itertools, math, contextlib, weakref
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import BEAM, DEVECTORIZE, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
from tinygrad.helpers import prod, unwrap, EMULATED_DTYPES
from tinygrad.helpers import prod, unwrap, EMULATED_DTYPES, flatten
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.renderer import ProgramSpec, Estimates
@@ -222,48 +222,55 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
# per-run state: run_linear builds one ExecContext and passes it as ctx to every pm_exec handler
@dataclass
class ExecContext:
# variable name -> int value, as given to run_linear (empty dict by default)
var_vals: dict[str, int] = field(default_factory=dict)
# UOps that _resolve / resolve_params substitute for PARAMs, indexed by the PARAM's arg
input_uops: tuple[UOp, ...] = ()
# when False, track_stats returns right after the wrapped body without calling update_stats
do_update_stats: bool = True
# jit flag given to run_linear; stored here for the exec handlers
jit: bool = False
def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
  """Replace a PARAM (or the PARAM base of a BUFFER_VIEW) with the matching entry of `inputs`.

  A PARAM becomes `inputs[b.arg]`; a BUFFER_VIEW whose first source is a PARAM is rebuilt with that source
  swapped for `inputs[param.arg]`; anything else is returned unchanged.
  """
  if b.op is Ops.PARAM:
    return inputs[b.arg]
  if b.op is Ops.BUFFER_VIEW:
    base = b.src[0]
    if base.op is Ops.PARAM:
      return b.replace(src=(inputs[base.arg], *b.src[1:]))
  return b
def resolve_params(ctx:ExecContext, call:UOp) -> list[UOp]:
  """Return the call's non-BIND arguments, each passed through `_resolve` against `ctx.input_uops`."""
  resolved: list[UOp] = []
  for arg in call.src[1:]:
    if arg.op is Ops.BIND: continue
    resolved.append(_resolve(arg, ctx.input_uops))
  return resolved
@contextlib.contextmanager
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimates:Estimates, bufs:list[Buffer], var_vals:dict[str, int], *,
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimates:Estimates, bufs:list[Buffer], var_vals:dict[str, int],
outputs=(0,), inputs=(1,), first_run=False):
if PROFILE: cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), {"metadata": call.arg.metadata, "var_vals": var_vals,
"bufs": [b.trace_num for b in bufs], "name": display_name, "outputs": outputs, "inputs": inputs}))
timing: list[float|None] = [None]
st = time.perf_counter()
if DEBUG >= 2: st = time.perf_counter()
yield timing
if not ctx.do_update_stats: return
if timing[0] is None and DEBUG >= 2:
if DEBUG >= 2 and timing[0] is None:
Device[device].synchronize()
timing[0] = time.perf_counter() - st
update_stats(display_name, device, estimates, var_vals, timing[0], len(bufs), jit=False, metadata=call.arg.metadata, first_run=first_run)
update_stats(display_name, device, estimates, var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata, first_run=first_run)
def unwrap_multi(call:UOp) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
bufs = [b.buffer for b in call.src[1:] if b.op is not Ops.BIND]
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
bufs = [b.buffer for b in resolved]
if not any(isinstance(b, MultiBuffer) for b in bufs): yield cast(list[Buffer], bufs), {}
else:
dnum = next((x.expr for x in call.src[0].variables() if x.expr == '_device_num'), None)
for j, per_dev in enumerate(zip(*[cast(MultiBuffer, b).bufs for b in bufs])): yield list(per_dev), {dnum: j} if dnum else {}
def exec_view(ctx:ExecContext, call, ast):
bufs = [b.buffer for b in call.src[1:] if b.op is not Ops.BIND]
bv = bufs[1].view(call.src[1].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize)
resolved = resolve_params(ctx, call)
bufs = [cast(Buffer, b.buffer) for b in resolved]
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize)
with track_stats(ctx, call, bv.device, colored(f"view {bv.nbytes:8d} @ {bv.offset:<10d}", "yellow"), Estimates(), [bv, bufs[1]], ctx.var_vals):
buffers[call.src[1]] = bv
buffers[resolved[0]] = bv
def exec_copy(ctx:ExecContext, call, ast):
for bufs, device_vars in unwrap_multi(call):
for bufs, device_vars in unwrap_multi(call, resolve_params(ctx, call)):
dest, src = bufs[0].ensure_allocated(), bufs[1].ensure_allocated()
xfer = hasattr(alc:=Device[dest.device].allocator,'_transfer') and alc.supports_transfer and dest.device.split(":")[0]==src.device.split(":")[0]
prg = (BufferXfer if xfer else BufferCopy)(dest.nbytes, dest.device, src.device)
name = colored(f"{'xfer' if xfer else 'copy'} {size_to_str(dest.nbytes):>8s}, {dest.device[:7]:>7s} <- {src.device[:7]:7s}", "yellow")
with track_stats(ctx, call, dest.device, name, Estimates(lds=dest.nbytes, mem=dest.nbytes), [dest, src], {**ctx.var_vals, **device_vars}):
with track_stats(ctx, call, dest.device, prg.display_name, Estimates(lds=dest.nbytes, mem=dest.nbytes), [dest, src], ctx.var_vals):
prg.copy(dest, src)
def exec_kernel(ctx:ExecContext, call, ast):
sink = ast.src[0] if ast.op is Ops.BEAM else ast
for bufs, device_vars in unwrap_multi(call):
for bufs, device_vars in unwrap_multi(call, resolve_params(ctx, call)):
var_vals = {**ctx.var_vals, **device_vars}
prg = get_runner(bufs[0].device, ast)
prg_bufs = [bufs[i].ensure_allocated() for i in prg.p.globals]
@@ -283,12 +290,22 @@ def exec_kernel(ctx:ExecContext, call, ast):
for i in prg.p.outs: np.testing.assert_allclose(prg_bufs[i].numpy(), cpu_bufs[i].numpy(), rtol=1e-3, atol=1e-3)
def exec_encdec(ctx:ExecContext, call, ast):
bufs = [b.buffer.ensure_allocated() for b in call.src[1:] if b.op is not Ops.BIND]
bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(ctx, call)]
shape, pos_var = tuple(s.arg for s in ast.src if s.op is Ops.CONST), ast.variables()[0].expr
with track_stats(ctx, call, bufs[0].device, colored(f"enc/dec {size_to_str(bufs[0].nbytes)}", "yellow"),
Estimates(lds=bufs[0].nbytes, mem=bufs[0].nbytes), bufs, ctx.var_vals):
bufs[0].allocator._encode_decode(bufs[0]._buf, bufs[1]._buf, bufs[2]._buf, [x._buf for x in bufs[3:]], shape, ctx.var_vals[pos_var])
# graph runners created by exec_graph, keyed by the "graph" CUSTOM_FUNCTION UOp.
# weak keys: an entry goes away once its UOp key is no longer referenced anywhere else
graph_cache:weakref.WeakKeyDictionary[UOp, Runner] = weakref.WeakKeyDictionary()
def exec_graph(ctx:ExecContext, call, cf):
  """Execute a "graph" CUSTOM_FUNCTION call, building its runner on first use.

  The runner is cached in `graph_cache` under `cf`. It is built from `cf` with its input sources substituted by
  the resolved call arguments, and is then called with the flattened buffers and `ctx.var_vals`.
  """
  inputs = resolve_params(ctx, call)

  # one flat list of buffers; a MultiBuffer contributes each of its bufs
  bufs = []
  for u in inputs:
    buf = u.buffer
    bufs.extend(buf.bufs if isinstance(buf, MultiBuffer) else [buf])

  runner = graph_cache.get(cf)
  if runner is None:
    sub = cf.substitute(dict(zip(cf.src[1:], inputs)))
    dev = cf.device
    # when the device is not a plain string, its first element is used
    runner = Device[dev if isinstance(dev, str) else dev[0]].graph(sub, bufs)
    graph_cache[cf] = runner

  with track_stats(ctx, call, runner.device, runner.display_name, runner.estimates, bufs, ctx.var_vals) as t:
    # the runner's return value is stored as the timing slot
    t[0] = runner(bufs, ctx.var_vals, wait=DEBUG >= 2)
pm_beam = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True),
lambda ctx,call,sink: call.replace(src=(UOp(Ops.BEAM, src=(sink,), arg=ctx), *call.src[1:]))),
@@ -299,9 +316,10 @@ pm_exec = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy),
(UPat(Ops.CALL, src=(UPat((Ops.SINK, Ops.PROGRAM, Ops.BEAM), name="ast"),), name="call", allow_any_len=True), exec_kernel),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="ast"),), name="call", allow_any_len=True), exec_encdec),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"),), name="call", allow_any_len=True), exec_graph),
])
def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, do_update_stats=True):
def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), do_update_stats=True, jit=False):
if BEAM >= 1: linear = graph_rewrite(linear, pm_beam, ctx=BEAM.value, name="add beam")
ctx = ExecContext(var_vals or {}, do_update_stats)
ctx = ExecContext(var_vals or {}, input_uops, do_update_stats, jit)
for call in linear.src: pm_exec.rewrite(call, ctx)