mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
run_linear in jit (#15827)
* run_linear in jit * x * x * f * casts * ugh * f * x * x * simple
This commit is contained in:
4
.github/workflows/benchmark.yml
vendored
4
.github/workflows/benchmark.yml
vendored
@@ -626,7 +626,7 @@ jobs:
|
||||
- name: IR3 openpilot compile3 0.11.0 driving_vision
|
||||
run: BENCHMARK_LOG=ir3_openpilot_0_11_0_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM:IR3 FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: openpilot compile3 0.11.0 driving_policy
|
||||
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
|
||||
run: BENCHMARK_LOG=openpilot_0_11_0_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.11.0 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_11_0_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.11.0/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: DEBUG=2 openpilot compile3 0.10.1 driving_vision
|
||||
@@ -634,7 +634,7 @@ jobs:
|
||||
- name: openpilot compile3 0.10.1 driving_vision
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
|
||||
- name: openpilot compile3 0.10.1 driving_policy
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.10.1 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: benchmark MobileNetV2 on DSP
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import TypeVar, Generic, Callable, cast, Any
|
||||
from typing import TypeVar, Generic, Callable, Any
|
||||
import functools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ
|
||||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
||||
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates, pm_beam
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
||||
from tinygrad.engine.realize import ExecItem, capturing, CompiledRunner, Runner, Estimates, pm_beam, run_linear, get_runner, graph_cache
|
||||
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
|
||||
from tinygrad.schedule import linear_to_schedule
|
||||
from tinygrad.nn.state import get_parameters
|
||||
@@ -24,7 +24,7 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
|
||||
|
||||
def create_graph_call(batch:list[UOp]) -> UOp:
|
||||
# all external inputs are PARAMs
|
||||
input_list = dedup(b for si in batch for b in si.src[1:] if b.op is Ops.PARAM)
|
||||
input_list = dedup(u for si in batch for b in si.src[1:] for u in b.toposort() if u.op is Ops.PARAM)
|
||||
cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph")
|
||||
return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata))
|
||||
|
||||
@@ -59,14 +59,22 @@ def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
|
||||
if current_batch: flush_batch()
|
||||
return linear.replace(src=tuple(new_src))
|
||||
|
||||
def jit_cache_bufs(jit_cache:list[ExecItem]):
|
||||
for ei in jit_cache:
|
||||
for b in ei.bufs:
|
||||
if b is not None: yield b
|
||||
if isinstance(ei.prg, GraphRunner): yield from jit_cache_bufs(ei.prg.jit_cache)
|
||||
|
||||
def _unwrap_beam(ast:UOp) -> UOp: return ast.src[0] if ast.op is Ops.BEAM else ast
|
||||
|
||||
def _call_outs_ins(call:UOp) -> tuple[set[int], set[int]]:
|
||||
non_bind = [s for s in call.src[1:] if s.op is not Ops.BIND]
|
||||
ast = _unwrap_beam(call.src[0])
|
||||
if ast.op in (Ops.SINK, Ops.PROGRAM):
|
||||
prg = get_runner(non_bind[0].device if isinstance(non_bind[0].device, str) else non_bind[0].device[0], call.src[0])
|
||||
return set(prg.p.outs), set(prg.p.ins)
|
||||
if ast.op in (Ops.COPY, Ops.BUFFER_VIEW): return {0}, {1}
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec": return {0}, set(range(1, len(non_bind)))
|
||||
return set(), set()
|
||||
|
||||
def _copy_input(u:UOp) -> UOp:
|
||||
run_linear(UOp(Ops.LINEAR, src=(u.copy_to_device(u.device).call(new:=UOp.new_buffer(u.device, u.arg, u.dtype), u, metadata=()),)))
|
||||
return new
|
||||
|
||||
@track_rewrites(lambda linear,held_bufs,input_uops,ret=(): f"JIT {pluralize('call', len(linear.src))}")
|
||||
def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp:
|
||||
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear")
|
||||
@@ -101,29 +109,12 @@ def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) ->
|
||||
if a in input_buffers: input_replace[(j,i)] = input_buffers.index(a)
|
||||
return input_replace
|
||||
|
||||
pm_params = PatternMatcher([(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="p"), lambda ctx,p: ctx[p.arg])])
|
||||
|
||||
def linear_to_jit_cache(linear:UOp, input_uops:list[UOp]) -> tuple[list[ExecItem], dict[tuple[int,int],int], list[tuple[int,int,str,int,DType]]]:
|
||||
# substitute PARAMs with input buffer UOps before lowering
|
||||
linear = graph_rewrite(linear, pm_params, ctx=input_uops, walk=True, enter_calls=True)
|
||||
# convert to jit_cache
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(linear)]
|
||||
for b in jit_cache_bufs(jit_cache): b.ensure_allocated()
|
||||
# derive input_buffers from input_uops
|
||||
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None])
|
||||
# track view buffers whose base is an input buffer
|
||||
extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
|
||||
for ei in jit_cache:
|
||||
for b in ei.bufs:
|
||||
if b is not None and b._base is not None and b._base in input_buffers and b not in input_buffers:
|
||||
extra_view_inputs.append((input_buffers.index(b._base), b.offset, b.device, b.size, b.dtype))
|
||||
input_buffers.append(b)
|
||||
return jit_cache, get_input_replace(jit_cache, input_buffers), extra_view_inputs
|
||||
|
||||
class GraphRunner(Runner):
|
||||
def __init__(self, linear:UOp, input_buffers:list[Buffer]):
|
||||
self.jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
|
||||
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
|
||||
for ei in self.jit_cache:
|
||||
for b in ei.bufs:
|
||||
if b is not None: b.ensure_allocated()
|
||||
self.input_replace = get_input_replace(self.jit_cache, input_buffers) if input_buffers else {}
|
||||
|
||||
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
|
||||
@@ -206,16 +197,6 @@ class MultiGraphRunner(GraphRunner):
|
||||
return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) \
|
||||
and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
|
||||
|
||||
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
|
||||
if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
|
||||
if isinstance(ei.prg, (BufferCopy, BufferXfer, EncDec)): return [cast(Buffer, ei.bufs[0])]
|
||||
if isinstance(ei.prg, GraphRunner): return dedup([b for inner in ei.prg.jit_cache for b in get_out_buffers_for_ei(inner)])
|
||||
return []
|
||||
|
||||
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
|
||||
for ei in jit_cache:
|
||||
if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
@dataclass
|
||||
class CapturedJit(Generic[ReturnType]):
|
||||
@@ -225,47 +206,36 @@ class CapturedJit(Generic[ReturnType]):
|
||||
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
|
||||
|
||||
def __reduce__(self): return self.__class__, (self.ret, self.linear, self.expected_names, self.expected_input_info)
|
||||
def __post_init__(self): self._jit_cache = None
|
||||
@property
|
||||
def jit_cache(self) -> list[ExecItem]: return self._jit_cache if self._jit_cache is not None else []
|
||||
|
||||
def _init(self, input_uops:list[UOp]):
|
||||
self._jit_cache, self._input_replace, self._extra_view_inputs = linear_to_jit_cache(self.linear, input_uops)
|
||||
self._output_to_writer = {b: j for j, ei in enumerate(self._jit_cache) for b in get_out_buffers_for_ei(ei)}
|
||||
self._input_to_max_reader: dict[int, int] = {}
|
||||
for (j, i), idx in self._input_replace.items():
|
||||
if self._jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self._jit_cache[j]):
|
||||
self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
|
||||
for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
|
||||
@functools.cached_property
|
||||
def _written_uops(self) -> set[UOp]:
|
||||
out: set[UOp] = set()
|
||||
for call in self.linear.toposort():
|
||||
if call.op is not Ops.CALL: continue
|
||||
non_bind = [s for s in call.src[1:] if s.op is not Ops.BIND]
|
||||
outs, ins = _call_outs_ins(call)
|
||||
out |= {non_bind[k] for k in outs - ins if non_bind[k].op in (Ops.BUFFER, Ops.BUFFER_VIEW)}
|
||||
return out
|
||||
|
||||
def __call__(self, input_uops:list[UOp], var_vals:dict[str, int]) -> ReturnType:
|
||||
if self._jit_cache is None: self._init(input_uops)
|
||||
assert self._jit_cache is not None
|
||||
# derive input_buffers from input_uops (flatten MultiBuffer)
|
||||
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None])
|
||||
# recreate view buffers from input bases
|
||||
for idx, offset, device, size, dtype in self._extra_view_inputs:
|
||||
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
|
||||
# copy aliased inputs to prevent read-after-write hazard
|
||||
for i, ib in enumerate(input_buffers):
|
||||
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
|
||||
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
|
||||
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
|
||||
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
||||
for ei in self._jit_cache: ei.run(var_vals, jit=True)
|
||||
for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
|
||||
concrete = tuple(_copy_input(u) if u in self._written_uops else u for u in input_uops)
|
||||
if DEBUG >= 1 and len(self.linear.src) >= 10: print(f"jit execs {len(self.linear.src)} calls")
|
||||
run_linear(self.linear, var_vals, input_uops=concrete, jit=True)
|
||||
return self.ret
|
||||
|
||||
def free_intermediates(self):
|
||||
depends: set[Buffer|None] = set([None])
|
||||
update_depends(depends, self.jit_cache)
|
||||
arenas = {b._base for b in depends if b is not None and b._base is not None}
|
||||
to_free = {b for b in depends if b is not None} | {b for b in jit_cache_bufs(self.jit_cache) if b._base in arenas}
|
||||
for b in to_free:
|
||||
if hasattr(b, '_buf'): b.deallocate()
|
||||
for a in arenas:
|
||||
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
|
||||
self.__post_init__()
|
||||
# drop graph runners
|
||||
for call in self.linear.src:
|
||||
if call.src[0].op is Ops.CUSTOM_FUNCTION and call.src[0].arg == "graph": graph_cache.pop(call.src[0], None)
|
||||
bases: set[Buffer] = set()
|
||||
for u in self._written_uops:
|
||||
try: buf = u.buffer
|
||||
except Exception: continue
|
||||
for b in (buf.bufs if isinstance(buf, MultiBuffer) else [buf]):
|
||||
if hasattr(b, '_buf'): b.deallocate()
|
||||
if b._base is not None: bases.add(b._base)
|
||||
for a in bases:
|
||||
if a.is_allocated() and a.allocated_views == 0: a.deallocate()
|
||||
|
||||
def _prepare_jit_inputs(args, kwargs):
|
||||
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
|
||||
@@ -306,13 +276,6 @@ class TinyJit(Generic[ReturnType]):
|
||||
assert self.captured is not None, "can't pickle an uncaptured JIT"
|
||||
return self.__class__, (None, self.captured)
|
||||
|
||||
# keep legacy code working
|
||||
@property
|
||||
def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None and self.captured._jit_cache is not None else []
|
||||
@property
|
||||
def input_replace(self) -> dict[tuple[int, int], int]:
|
||||
return self.captured._input_replace if self.captured is not None and self.captured._jit_cache is not None else {}
|
||||
|
||||
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
|
||||
|
||||
def __call__(self, *args, **kwargs) -> ReturnType:
|
||||
@@ -337,16 +300,14 @@ class TinyJit(Generic[ReturnType]):
|
||||
_check_no_non_tensor_return(ret)
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buf_uops)} inputs")
|
||||
|
||||
# combine all captured linears into one, memory plan, and convert to ExecItems
|
||||
# combine all captured linears into one, memory plan, and graph split
|
||||
big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears])))
|
||||
del self._linears
|
||||
|
||||
if self.prune:
|
||||
big_linear, onetime_linear = prune_linear(big_linear, set(input_buf_uops))
|
||||
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
|
||||
for ei in (si.lower() for si in linear_to_schedule(onetime_linear)):
|
||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||
ei.run(var_vals, jit=True)
|
||||
run_linear(onetime_linear, var_vals)
|
||||
|
||||
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
|
||||
linear = jit_lower(big_linear, held_bufs, input_buf_uops)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast, Callable, Iterator
|
||||
import time, pprint, random, itertools, math, contextlib
|
||||
import time, pprint, random, itertools, math, contextlib, weakref
|
||||
from dataclasses import dataclass, replace, field
|
||||
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
|
||||
from tinygrad.helpers import BEAM, DEVECTORIZE, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
|
||||
from tinygrad.helpers import prod, unwrap, EMULATED_DTYPES
|
||||
from tinygrad.helpers import prod, unwrap, EMULATED_DTYPES, flatten
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite
|
||||
from tinygrad.device import Device, Buffer, MultiBuffer
|
||||
from tinygrad.renderer import ProgramSpec, Estimates
|
||||
@@ -222,48 +222,55 @@ def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_
|
||||
@dataclass
|
||||
class ExecContext:
|
||||
var_vals: dict[str, int] = field(default_factory=dict)
|
||||
input_uops: tuple[UOp, ...] = ()
|
||||
do_update_stats: bool = True
|
||||
jit: bool = False
|
||||
|
||||
def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
|
||||
if b.op is Ops.BUFFER_VIEW and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg], *b.src[1:]))
|
||||
return inputs[b.arg] if b.op is Ops.PARAM else b
|
||||
def resolve_params(ctx:ExecContext, call:UOp) -> list[UOp]: return [_resolve(b, ctx.input_uops) for b in call.src[1:] if b.op is not Ops.BIND]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimates:Estimates, bufs:list[Buffer], var_vals:dict[str, int], *,
|
||||
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimates:Estimates, bufs:list[Buffer], var_vals:dict[str, int],
|
||||
outputs=(0,), inputs=(1,), first_run=False):
|
||||
if PROFILE: cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), {"metadata": call.arg.metadata, "var_vals": var_vals,
|
||||
"bufs": [b.trace_num for b in bufs], "name": display_name, "outputs": outputs, "inputs": inputs}))
|
||||
timing: list[float|None] = [None]
|
||||
st = time.perf_counter()
|
||||
if DEBUG >= 2: st = time.perf_counter()
|
||||
yield timing
|
||||
if not ctx.do_update_stats: return
|
||||
if timing[0] is None and DEBUG >= 2:
|
||||
if DEBUG >= 2 and timing[0] is None:
|
||||
Device[device].synchronize()
|
||||
timing[0] = time.perf_counter() - st
|
||||
update_stats(display_name, device, estimates, var_vals, timing[0], len(bufs), jit=False, metadata=call.arg.metadata, first_run=first_run)
|
||||
update_stats(display_name, device, estimates, var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata, first_run=first_run)
|
||||
|
||||
def unwrap_multi(call:UOp) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
|
||||
bufs = [b.buffer for b in call.src[1:] if b.op is not Ops.BIND]
|
||||
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
|
||||
bufs = [b.buffer for b in resolved]
|
||||
if not any(isinstance(b, MultiBuffer) for b in bufs): yield cast(list[Buffer], bufs), {}
|
||||
else:
|
||||
dnum = next((x.expr for x in call.src[0].variables() if x.expr == '_device_num'), None)
|
||||
for j, per_dev in enumerate(zip(*[cast(MultiBuffer, b).bufs for b in bufs])): yield list(per_dev), {dnum: j} if dnum else {}
|
||||
|
||||
def exec_view(ctx:ExecContext, call, ast):
|
||||
bufs = [b.buffer for b in call.src[1:] if b.op is not Ops.BIND]
|
||||
bv = bufs[1].view(call.src[1].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize)
|
||||
resolved = resolve_params(ctx, call)
|
||||
bufs = [cast(Buffer, b.buffer) for b in resolved]
|
||||
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize)
|
||||
with track_stats(ctx, call, bv.device, colored(f"view {bv.nbytes:8d} @ {bv.offset:<10d}", "yellow"), Estimates(), [bv, bufs[1]], ctx.var_vals):
|
||||
buffers[call.src[1]] = bv
|
||||
buffers[resolved[0]] = bv
|
||||
|
||||
def exec_copy(ctx:ExecContext, call, ast):
|
||||
for bufs, device_vars in unwrap_multi(call):
|
||||
for bufs, device_vars in unwrap_multi(call, resolve_params(ctx, call)):
|
||||
dest, src = bufs[0].ensure_allocated(), bufs[1].ensure_allocated()
|
||||
xfer = hasattr(alc:=Device[dest.device].allocator,'_transfer') and alc.supports_transfer and dest.device.split(":")[0]==src.device.split(":")[0]
|
||||
prg = (BufferXfer if xfer else BufferCopy)(dest.nbytes, dest.device, src.device)
|
||||
name = colored(f"{'xfer' if xfer else 'copy'} {size_to_str(dest.nbytes):>8s}, {dest.device[:7]:>7s} <- {src.device[:7]:7s}", "yellow")
|
||||
with track_stats(ctx, call, dest.device, name, Estimates(lds=dest.nbytes, mem=dest.nbytes), [dest, src], {**ctx.var_vals, **device_vars}):
|
||||
with track_stats(ctx, call, dest.device, prg.display_name, Estimates(lds=dest.nbytes, mem=dest.nbytes), [dest, src], ctx.var_vals):
|
||||
prg.copy(dest, src)
|
||||
|
||||
def exec_kernel(ctx:ExecContext, call, ast):
|
||||
sink = ast.src[0] if ast.op is Ops.BEAM else ast
|
||||
|
||||
for bufs, device_vars in unwrap_multi(call):
|
||||
for bufs, device_vars in unwrap_multi(call, resolve_params(ctx, call)):
|
||||
var_vals = {**ctx.var_vals, **device_vars}
|
||||
prg = get_runner(bufs[0].device, ast)
|
||||
prg_bufs = [bufs[i].ensure_allocated() for i in prg.p.globals]
|
||||
@@ -283,12 +290,22 @@ def exec_kernel(ctx:ExecContext, call, ast):
|
||||
for i in prg.p.outs: np.testing.assert_allclose(prg_bufs[i].numpy(), cpu_bufs[i].numpy(), rtol=1e-3, atol=1e-3)
|
||||
|
||||
def exec_encdec(ctx:ExecContext, call, ast):
|
||||
bufs = [b.buffer.ensure_allocated() for b in call.src[1:] if b.op is not Ops.BIND]
|
||||
bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(ctx, call)]
|
||||
shape, pos_var = tuple(s.arg for s in ast.src if s.op is Ops.CONST), ast.variables()[0].expr
|
||||
with track_stats(ctx, call, bufs[0].device, colored(f"enc/dec {size_to_str(bufs[0].nbytes)}", "yellow"),
|
||||
Estimates(lds=bufs[0].nbytes, mem=bufs[0].nbytes), bufs, ctx.var_vals):
|
||||
bufs[0].allocator._encode_decode(bufs[0]._buf, bufs[1]._buf, bufs[2]._buf, [x._buf for x in bufs[3:]], shape, ctx.var_vals[pos_var])
|
||||
|
||||
graph_cache:weakref.WeakKeyDictionary[UOp, Runner] = weakref.WeakKeyDictionary()
|
||||
def exec_graph(ctx:ExecContext, call, cf):
|
||||
inputs = resolve_params(ctx, call)
|
||||
bufs = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in (u.buffer for u in inputs)])
|
||||
if (runner:=graph_cache.get(cf)) is None:
|
||||
sub = cf.substitute(dict(zip(cf.src[1:], inputs)))
|
||||
graph_cache[cf] = runner = Device[cf.device if isinstance(cf.device, str) else cf.device[0]].graph(sub, bufs)
|
||||
with track_stats(ctx, call, runner.device, runner.display_name, runner.estimates, bufs, ctx.var_vals) as t:
|
||||
t[0] = runner(bufs, ctx.var_vals, wait=DEBUG >= 2)
|
||||
|
||||
pm_beam = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True),
|
||||
lambda ctx,call,sink: call.replace(src=(UOp(Ops.BEAM, src=(sink,), arg=ctx), *call.src[1:]))),
|
||||
@@ -299,9 +316,10 @@ pm_exec = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy),
|
||||
(UPat(Ops.CALL, src=(UPat((Ops.SINK, Ops.PROGRAM, Ops.BEAM), name="ast"),), name="call", allow_any_len=True), exec_kernel),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="ast"),), name="call", allow_any_len=True), exec_encdec),
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"),), name="call", allow_any_len=True), exec_graph),
|
||||
])
|
||||
|
||||
def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), do_update_stats=True, jit=False):
|
||||
if BEAM >= 1: linear = graph_rewrite(linear, pm_beam, ctx=BEAM.value, name="add beam")
|
||||
ctx = ExecContext(var_vals or {}, do_update_stats)
|
||||
ctx = ExecContext(var_vals or {}, input_uops, do_update_stats, jit)
|
||||
for call in linear.src: pm_exec.rewrite(call, ctx)
|
||||
|
||||
Reference in New Issue
Block a user