mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
jit: capturedjit is linear (#15743)
* jit: capturedjit is linear * x * new beam * test * imp * clean * spec * linter
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from typing import TypeVar, Generic, Callable, cast, Any
|
||||
import functools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize, VIZ
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ
|
||||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
||||
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
|
||||
from tinygrad.schedule import linear_to_schedule
|
||||
@@ -22,14 +22,13 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
|
||||
else: onetime.append(si)
|
||||
return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime))
|
||||
|
||||
def create_graph_call(batch:list[UOp], input_buffers:set[Buffer]) -> UOp:
|
||||
def bufs_for(b): return b.buffer.bufs if isinstance(b.buffer, MultiBuffer) else [b.buffer]
|
||||
|
||||
input_list = dedup(b for si in batch for b in si.src[1:] if b.op in (Ops.BUFFER, Ops.BUFFER_VIEW) and not input_buffers.isdisjoint(bufs_for(b)))
|
||||
def create_graph_call(batch:list[UOp]) -> UOp:
|
||||
# all external inputs are PARAMs
|
||||
input_list = dedup(b for si in batch for b in si.src[1:] if b.op is Ops.PARAM)
|
||||
cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph")
|
||||
return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata))
|
||||
|
||||
def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:int=0) -> UOp:
|
||||
def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
|
||||
new_src: list[UOp] = []
|
||||
current_batch: list[UOp] = []
|
||||
current_batch_devs: list[Compiled] = []
|
||||
@@ -38,7 +37,7 @@ def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:in
|
||||
nonlocal current_batch, current_batch_devs, max_batch_size, new_src
|
||||
if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): new_src.extend(current_batch)
|
||||
else:
|
||||
new_src.append(create_graph_call(current_batch, input_buffers))
|
||||
new_src.append(create_graph_call(current_batch))
|
||||
max_batch_size *= 2
|
||||
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels")
|
||||
current_batch, current_batch_devs = [], []
|
||||
@@ -66,13 +65,26 @@ def jit_cache_bufs(jit_cache:list[ExecItem]):
|
||||
if b is not None: yield b
|
||||
if isinstance(ei.prg, GraphRunner): yield from jit_cache_bufs(ei.prg.jit_cache)
|
||||
|
||||
@track_rewrites(lambda linear,held_bufs,input_buffers=None,ret=(): f"JIT {pluralize('call', len(linear.src))}")
|
||||
def jit_lower(linear:UOp, held_bufs:set[UOp], input_buffers:list[Buffer]|None=None) -> list[ExecItem]:
|
||||
def _unwrap_beam(ast:UOp) -> UOp: return ast.src[0] if ast.op is Ops.BEAM else ast
|
||||
|
||||
pm_beam = PatternMatcher([
|
||||
(UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True),
|
||||
lambda ctx,call,sink: call.replace(src=(UOp(Ops.BEAM, src=(sink,), arg=ctx), *call.src[1:])))])
|
||||
|
||||
@track_rewrites(lambda linear,held_bufs,input_uops,ret=(): f"JIT {pluralize('call', len(linear.src))}")
|
||||
def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp:
|
||||
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear")
|
||||
|
||||
# parametrize input buffers: map each input buffer UOp to a PARAM with the correct slot index
|
||||
linear = linear.substitute({u: UOp.param(i, u.dtype, u.shape, u.device) for i,u in enumerate(input_uops)}, walk=True)
|
||||
|
||||
# wrap SINKs with BEAM if jitbeam is set
|
||||
if (jitbeam:=getenv("JITBEAM", BEAM.value)) >= 1: linear = graph_rewrite(linear, pm_beam, ctx=jitbeam, walk=True)
|
||||
|
||||
linear = memory_plan_rewrite(linear, held_bufs)
|
||||
if JIT < 2: linear = graph_split_rewrite(linear, set(input_buffers or []), max_batch_size=JIT_BATCH_SIZE.value)
|
||||
if JIT < 2: linear = graph_split_rewrite(linear, max_batch_size=JIT_BATCH_SIZE.value)
|
||||
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear")
|
||||
return [ei.lower() for ei in linear_to_schedule(linear)]
|
||||
return linear
|
||||
|
||||
class GraphException(Exception): pass
|
||||
class JitError(Exception): pass
|
||||
@@ -86,26 +98,37 @@ def _check_no_non_tensor_return(ret):
|
||||
|
||||
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
|
||||
|
||||
def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer],
|
||||
orig_valid_positions: dict[int, set[int]]|None = None) -> dict[tuple[int, int], int]:
|
||||
def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]:
|
||||
input_replace: dict[tuple[int, int], int] = {}
|
||||
for j,ji in enumerate(jit_cache):
|
||||
for i,a in enumerate(ji.bufs):
|
||||
if a in input_buffers:
|
||||
# filter out positions that weren't valid inputs in the original capture (prevents aliasing bugs)
|
||||
if orig_valid_positions is not None and i not in orig_valid_positions.get(id(ji), set()): continue
|
||||
input_replace[(j,i)] = input_buffers.index(a)
|
||||
if a in input_buffers: input_replace[(j,i)] = input_buffers.index(a)
|
||||
return input_replace
|
||||
|
||||
pm_params = PatternMatcher([(UPat(Ops.PARAM, src=(UPat(), UPat(Ops.DEVICE)), name="p"), lambda ctx,p: ctx[p.arg])])
|
||||
|
||||
def linear_to_jit_cache(linear:UOp, input_uops:list[UOp]) -> tuple[list[ExecItem], dict[tuple[int,int],int], list[tuple[int,int,str,int,DType]]]:
|
||||
# substitute PARAMs with input buffer UOps before lowering
|
||||
linear = graph_rewrite(linear, pm_params, ctx=input_uops, walk=True, enter_calls=True)
|
||||
# convert to jit_cache
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(linear)]
|
||||
for b in jit_cache_bufs(jit_cache): b.ensure_allocated()
|
||||
# derive input_buffers from input_uops
|
||||
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None])
|
||||
# track view buffers whose base is an input buffer
|
||||
extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
|
||||
for ei in jit_cache:
|
||||
for b in ei.bufs:
|
||||
if b is not None and b._base is not None and b._base in input_buffers and b not in input_buffers:
|
||||
extra_view_inputs.append((input_buffers.index(b._base), b.offset, b.device, b.size, b.dtype))
|
||||
input_buffers.append(b)
|
||||
return jit_cache, get_input_replace(jit_cache, input_buffers), extra_view_inputs
|
||||
|
||||
class GraphRunner(Runner):
|
||||
def __init__(self, linear:UOp|None, input_buffers:list[Buffer]|None,
|
||||
jit_cache:list[ExecItem]|None=None, input_replace:dict[tuple[int,int],int]|None=None):
|
||||
# TODO: captured jit as linear?
|
||||
if linear is not None:
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
|
||||
for b in jit_cache_bufs(jit_cache): b.ensure_allocated()
|
||||
input_replace = get_input_replace(jit_cache, input_buffers) if input_buffers else {}
|
||||
self.jit_cache, self.input_replace = unwrap(jit_cache), input_replace or {}
|
||||
def __init__(self, linear:UOp, input_buffers:list[Buffer]):
|
||||
self.jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
|
||||
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
|
||||
self.input_replace = get_input_replace(self.jit_cache, input_buffers) if input_buffers else {}
|
||||
|
||||
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
|
||||
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
|
||||
@@ -141,8 +164,6 @@ class GraphRunner(Runner):
|
||||
assert self.jit_cache[0].prg is not None
|
||||
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), self.jit_cache[0].prg.device.split(":")[0], estimates.simplify())
|
||||
|
||||
def __reduce__(self): return self.__class__, (None, None, self.jit_cache, self.input_replace)
|
||||
|
||||
def updated_vars(self, var_vals: dict[str, int]):
|
||||
vals = [var_vals[v] for v in self.vars]
|
||||
for j, vidxs in self.var_vals_replace.items():
|
||||
@@ -179,14 +200,15 @@ class GraphRunner(Runner):
|
||||
|
||||
@staticmethod
|
||||
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
|
||||
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
|
||||
return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
|
||||
|
||||
# a marker for your graph supporting multiple devices of the same type
|
||||
class MultiGraphRunner(GraphRunner):
|
||||
@staticmethod
|
||||
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
|
||||
# Devices must be the same type
|
||||
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
|
||||
return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) \
|
||||
and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
|
||||
|
||||
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
|
||||
if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
|
||||
@@ -202,33 +224,42 @@ ReturnType = TypeVar('ReturnType')
|
||||
@dataclass
|
||||
class CapturedJit(Generic[ReturnType]):
|
||||
ret: Any # includes the Tensors or any other returned object
|
||||
jit_cache: list[ExecItem]
|
||||
input_replace: dict[tuple[int, int], int]
|
||||
extra_view_inputs: list[tuple[int, int, str, int, DType]]
|
||||
linear: UOp
|
||||
expected_names: list[int|str]
|
||||
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
|
||||
|
||||
def __reduce__(self):
|
||||
# TODO: free_intermediates here?
|
||||
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
|
||||
def __reduce__(self): return self.__class__, (self.ret, self.linear, self.expected_names, self.expected_input_info)
|
||||
def __post_init__(self): self._jit_cache = None
|
||||
@property
|
||||
def jit_cache(self) -> list[ExecItem]: return self._jit_cache if self._jit_cache is not None else []
|
||||
|
||||
def __post_init__(self):
|
||||
self._jit_cache: list[ExecItem] = self.jit_cache
|
||||
self._input_replace: dict[tuple[int, int], int] = self.input_replace
|
||||
self._first_run = True
|
||||
self._needs_rebuild = False
|
||||
# precompute read-after-write hazard detection
|
||||
self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
|
||||
def _init(self, input_uops:list[UOp]):
|
||||
self._jit_cache, self._input_replace, self._extra_view_inputs = linear_to_jit_cache(self.linear, input_uops)
|
||||
self._output_to_writer = {b: j for j, ei in enumerate(self._jit_cache) for b in get_out_buffers_for_ei(ei)}
|
||||
self._input_to_max_reader: dict[int, int] = {}
|
||||
for (j, i), idx in self.input_replace.items():
|
||||
# only buffers that were different during capture but alias at jit time (e.g. feeding output back as input) need the copy.
|
||||
if self.jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self.jit_cache[j]):
|
||||
for (j, i), idx in self._input_replace.items():
|
||||
if self._jit_cache[j].bufs[i] not in get_out_buffers_for_ei(self._jit_cache[j]):
|
||||
self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
|
||||
self._clear_inputs()
|
||||
|
||||
def _clear_inputs(self):
|
||||
for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
|
||||
|
||||
def __call__(self, input_uops:list[UOp], var_vals:dict[str, int]) -> ReturnType:
|
||||
if self._jit_cache is None: self._init(input_uops)
|
||||
assert self._jit_cache is not None
|
||||
# derive input_buffers from input_uops (flatten MultiBuffer)
|
||||
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=buffers[u]) is not None])
|
||||
# recreate view buffers from input bases
|
||||
for idx, offset, device, size, dtype in self._extra_view_inputs:
|
||||
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
|
||||
# copy aliased inputs to prevent read-after-write hazard
|
||||
for i, ib in enumerate(input_buffers):
|
||||
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
|
||||
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
|
||||
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
|
||||
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
||||
for ei in self._jit_cache: ei.run(var_vals, jit=True)
|
||||
for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
|
||||
return self.ret
|
||||
|
||||
def free_intermediates(self):
|
||||
depends: set[Buffer|None] = set([None])
|
||||
update_depends(depends, self.jit_cache)
|
||||
@@ -239,33 +270,6 @@ class CapturedJit(Generic[ReturnType]):
|
||||
for a in arenas:
|
||||
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
|
||||
self.__post_init__()
|
||||
self._needs_rebuild = True
|
||||
|
||||
# jit exec
|
||||
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
|
||||
# assign inputs
|
||||
for idx, offset, device, size, dtype in self.extra_view_inputs:
|
||||
input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
|
||||
|
||||
# copy aliased inputs to prevent read-after-write hazard
|
||||
for i, ib in enumerate(input_buffers):
|
||||
if (writer := self._output_to_writer.get(ib)) is not None and self._input_to_max_reader.get(i, -1) >= writer:
|
||||
input_buffers[i] = Buffer(ib.device, ib.size, ib.dtype).ensure_allocated().copyin(ib.as_memoryview())
|
||||
|
||||
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
|
||||
|
||||
# allocate intermediates if freed on first run
|
||||
if self._first_run:
|
||||
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
|
||||
if self._needs_rebuild:
|
||||
for ei in self.jit_cache:
|
||||
if isinstance(ei.prg, GraphRunner): ei.prg = type(ei.prg)(None, None, ei.prg.jit_cache, ei.prg.input_replace)
|
||||
self._first_run = self._needs_rebuild = False
|
||||
|
||||
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
||||
for ei in self._jit_cache: ei.run(var_vals, jit=True)
|
||||
self._clear_inputs()
|
||||
return self.ret
|
||||
|
||||
def _prepare_jit_inputs(args, kwargs):
|
||||
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
|
||||
@@ -278,13 +282,14 @@ def _prepare_jit_inputs(args, kwargs):
|
||||
input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
|
||||
if any(u.base.op is Ops.CONST for u in input_uops):
|
||||
raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
|
||||
input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for u in input_uops if (b:=u.base.realized) is not None])
|
||||
if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
|
||||
# collect buffer UOps (including MultiBuffer)
|
||||
input_buf_uops: list[UOp] = [u.base for u in input_uops if u.base.realized is not None]
|
||||
if len(set(input_buf_uops)) != len(input_buf_uops): raise JitError("duplicate inputs to JIT")
|
||||
inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops]
|
||||
_var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
|
||||
var_vals = {k.expr:v for k,v in _var_vals.items()}
|
||||
expected_input_info = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in inputs]
|
||||
return input_buffers, var_vals, names, expected_input_info
|
||||
return input_buf_uops, var_vals, names, expected_input_info
|
||||
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False):
|
||||
@@ -307,14 +312,15 @@ class TinyJit(Generic[ReturnType]):
|
||||
|
||||
# keep legacy code working
|
||||
@property
|
||||
def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
|
||||
def jit_cache(self) -> list[ExecItem]: return self.captured._jit_cache if self.captured is not None and self.captured._jit_cache is not None else []
|
||||
@property
|
||||
def input_replace(self) -> dict[tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}
|
||||
def input_replace(self) -> dict[tuple[int, int], int]:
|
||||
return self.captured._input_replace if self.captured is not None and self.captured._jit_cache is not None else {}
|
||||
|
||||
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
|
||||
|
||||
def __call__(self, *args, **kwargs) -> ReturnType:
|
||||
input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
|
||||
input_buf_uops, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
|
||||
if not JIT or self.cnt == 0:
|
||||
# jit ignore
|
||||
assert self.fxn is not None
|
||||
@@ -333,46 +339,30 @@ class TinyJit(Generic[ReturnType]):
|
||||
finally: capturing.clear()
|
||||
if not len(self._linears): raise JitError("didn't JIT anything!")
|
||||
_check_no_non_tensor_return(ret)
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buffers)} inputs")
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self._linears)} linears with {len(input_buf_uops)} inputs")
|
||||
|
||||
# combine all captured linears into one, memory plan, and convert to ExecItems
|
||||
big_linear = UOp(Ops.LINEAR, src=tuple(flatten([l.src for l in self._linears])))
|
||||
del self._linears
|
||||
|
||||
if self.prune:
|
||||
big_linear, onetime_linear = prune_linear(big_linear, {k for k,v in buffers.items() if isinstance(v, Buffer) and v in set(input_buffers)})
|
||||
big_linear, onetime_linear = prune_linear(big_linear, set(input_buf_uops))
|
||||
if DEBUG >= 1: print(f"pruned from {len(big_linear.src) + len(onetime_linear.src)} -> {len(big_linear.src)} kernels")
|
||||
for ei in (si.lower() for si in linear_to_schedule(onetime_linear)):
|
||||
for b in ei.bufs: cast(Buffer, b).ensure_allocated()
|
||||
ei.run(var_vals, jit=True)
|
||||
|
||||
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
jit_cache = jit_lower(big_linear, held_bufs, input_buffers)
|
||||
|
||||
# track inputs that are views of buffers
|
||||
# TODO: eventually expected_buffers should live in ExecItem
|
||||
extra_view_inputs: list[tuple[int, int, str, int, DType]] = []
|
||||
for item in jit_cache:
|
||||
for b in item.bufs:
|
||||
if b is not None and b._base is not None and b._base in input_buffers:
|
||||
input_buffers.append(b)
|
||||
extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))
|
||||
|
||||
input_replace = get_input_replace(jit_cache, input_buffers)
|
||||
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
|
||||
|
||||
# exec
|
||||
for ei in jit_cache: ei.run(var_vals)
|
||||
|
||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
|
||||
linear = jit_lower(big_linear, held_bufs, input_buf_uops)
|
||||
self.captured = CapturedJit(ret, linear, names, expected_input_info)
|
||||
ret = self.captured(input_buf_uops, var_vals)
|
||||
elif self.cnt >= 2:
|
||||
# jit exec
|
||||
assert self.captured is not None
|
||||
if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}")
|
||||
if self.captured.expected_input_info != expected_input_info:
|
||||
raise JitError(f"args mismatch in JIT: {self.captured.expected_input_info=} != {expected_input_info=}")
|
||||
ret = self.captured(input_buffers, var_vals)
|
||||
ret = self.captured(input_buf_uops, var_vals)
|
||||
|
||||
self.cnt += 1
|
||||
return ret
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEn
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, Variable
|
||||
from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy
|
||||
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner
|
||||
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, _unwrap_beam
|
||||
from tinygrad.runtime.ops_rdma import RDMACopyQueue
|
||||
|
||||
class HCQGraph(MultiGraphRunner):
|
||||
@@ -326,4 +326,4 @@ class HCQGraph(MultiGraphRunner):
|
||||
# MOCKGPU is not supported, since it can't execute commands in parallel
|
||||
is_xfer = len(set(type(d) for d in all_devs)) == 1 and hasattr(alc:=all_devs[0].allocator, '_transfer') and alc.supports_transfer
|
||||
return is_xfer or (all_devs[0].hw_copy_queue_t is not None and not getenv("MOCKGPU"))
|
||||
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM)
|
||||
return _unwrap_beam(new_call.src[0]).op in (Ops.SINK, Ops.PROGRAM)
|
||||
|
||||
@@ -92,7 +92,7 @@ _tensor_spec = PatternMatcher([
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
||||
|
||||
# BUFFER_VIEW on BUFFER is allowed if BUFFER is
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),)), lambda: True),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat((Ops.BUFFER, Ops.PARAM)),)), lambda: True),
|
||||
|
||||
# KERNEL can attach to an AFTER to describe the compute required to realize a BUFFER
|
||||
(UPat((Ops.CALL, Ops.FUNCTION), src=UPat((Ops.BUFFER, Ops.AFTER, Ops.MSELECT, Ops.MSTACK, Ops.BIND))), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user