update variable names around jit [pr] (#14049)

lbs, st_vars_dtype_device and rawbuffers no more
chenyu
2026-01-06 22:32:41 -05:00
committed by GitHub
parent 2833c5a54b
commit 87f4bc5446
6 changed files with 55 additions and 56 deletions

View File

@@ -7,7 +7,7 @@ if __name__ == "__main__":
with open(fetch(sys.argv[1]), "rb") as f:
run_onnx_jit = pickle.load(f)
input_name = run_onnx_jit.captured.expected_names[0]
-device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1]
+device = run_onnx_jit.captured.expected_input_info[0][-1]
print(f"input goes into {input_name=} on {device=}")
hit = 0
for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))):
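For context, a minimal sketch (not part of this diff) of how the renamed field can be inspected after loading a pickled jit. It assumes, as the script above suggests, that captured.expected_names and captured.expected_input_info are parallel per-input lists, each info entry being a (view, variables, dtype, device) tuple:

import pickle, sys

with open(sys.argv[1], "rb") as f: run_onnx_jit = pickle.load(f)
# each expected_input_info entry describes one jit input: (view UOp, bound Variables, dtype, device)
for name, (view, variables, dtype, device) in zip(run_onnx_jit.captured.expected_names,
                                                  run_onnx_jit.captured.expected_input_info):
  print(f"{name}: {dtype} on {device}, {len(variables)} bound variables")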

View File

@@ -24,7 +24,7 @@ def _check_no_non_tensor_return(ret):
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
-def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
+def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
graphed_jit_cache: list[ExecItem] = []
@@ -36,10 +36,10 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
try:
if len(current_batch_devs) == 0: raise GraphException("no device for graph")
if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
-graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
+graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals)
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
-graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_rawbuffers), prg=graph_runner))
+graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner))
max_batch_size *= 2
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
except GraphException as e:
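As a toy illustration of the batching policy above (simplified, not tinygrad code): kernels accumulate into a batch, and each batch that graphs successfully doubles max_batch_size, so later batches get larger while early kernels can already start running:

def batch_sizes(n_kernels: int, max_batch_size: int = 32) -> list[int]:
  # pretend every batch graphs successfully and is flushed once it reaches max_batch_size
  sizes, current = [], 0
  for _ in range(n_kernels):
    current += 1
    if current == max_batch_size:
      sizes.append(current)
      current, max_batch_size = 0, max_batch_size * 2
  if current: sizes.append(current)
  return sizes

print(batch_sizes(200))  # [32, 64, 104]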
@@ -72,18 +72,18 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
if len(current_batch) > 0: flush_batch()
return graphed_jit_cache
-def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]:
+def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]:
input_replace: dict[tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
for i,a in enumerate(ji.bufs):
-if a in input_rawbuffers:
-input_replace[(j,i)] = input_rawbuffers.index(a)
+if a in input_buffers:
+input_replace[(j,i)] = input_buffers.index(a)
return input_replace
class GraphRunner(Runner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
-self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
+self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers)
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
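To make the mapping concrete, a simplified stand-in for get_input_replace that uses strings instead of Buffer objects (illustrative only): it records, for every (jit_cache index j, buffer slot i) that holds a jit input, which position in input_buffers it came from, so replays can patch pointers in place:

def toy_input_replace(jit_cache: list[list[str]], input_buffers: list[str]) -> dict[tuple[int, int], int]:
  input_replace: dict[tuple[int, int], int] = {}
  for j, bufs in enumerate(jit_cache):
    for i, b in enumerate(bufs):
      if b in input_buffers: input_replace[(j, i)] = input_buffers.index(b)
  return input_replace

print(toy_input_replace([["w0", "inA"], ["inB", "w1"]], ["inA", "inB"]))  # {(0, 1): 0, (1, 0): 1}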
@@ -125,19 +125,19 @@ class GraphRunner(Runner):
for j, (gl, lc) in self.launch_dims_replace.items():
yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
-def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any):
+def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any):
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
wait_nodes = []
-for i,rawbuf in enumerate(rawbufs):
-if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)])
+for i,buf in enumerate(bufs):
+if id(buf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(buf.base._buf)])
if i in write:
-if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf)))
+if id(buf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(buf.base._buf)))
-for i,rawbuf in enumerate(rawbufs):
-if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
-else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency)
+for i,buf in enumerate(bufs):
+if i in write: self.w_dependency_map[id(buf.base._buf)] = new_dependency
+else: self.r_dependency_map[id(buf.base._buf)].append(new_dependency)
return list({id(x):x for x in wait_nodes}.values())
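A self-contained toy model of the single-writer/multiple-readers bookkeeping above (illustrative, keyed by plain buffer ids rather than Buffer._buf): a new node waits on the previous writer of every buffer it touches, and a writer additionally waits on all outstanding readers:

from collections import defaultdict

w_dep: dict[int, str] = {}
r_dep: dict[int, list[str]] = defaultdict(list)

def access(bufs: list[int], write: list[int], node: str) -> list[str]:
  wait = []
  for i, b in enumerate(bufs):
    if b in w_dep: wait.append(w_dep[b])                     # wait on the last writer of this buffer
    if i in write and b in r_dep: wait.extend(r_dep.pop(b))  # a writer also waits on its readers
  for i, b in enumerate(bufs):
    if i in write: w_dep[b] = node
    else: r_dep[b].append(node)
  return list(dict.fromkeys(wait))

print(access([1, 2], write=[0], node="n0"))  # [] -- nothing has touched these buffers yet
print(access([2], write=[0], node="n1"))     # ['n0'] -- writing buffer 2 must wait on its reader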
@@ -168,12 +168,11 @@ class CapturedJit(Generic[ReturnType]):
input_replace: dict[tuple[int, int], int]
extra_view_inputs: list[tuple[int, int, str, int, DType]]
expected_names: list[int|str]
-expected_st_vars_dtype_device: list[tuple[UOp, tuple[Variable, ...], DType, str]]
+expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
def __reduce__(self):
# TODO: free_intermediates here? replan_buffers_memory_layout here?
-return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
-self.expected_names, self.expected_st_vars_dtype_device)
+return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
def __post_init__(self):
self._jit_cache: list[ExecItem] = self.jit_cache
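For background, a minimal sketch of why __reduce__ is written this way (toy class, not CapturedJit): returning (class, args) lets pickle rebuild the object from exactly the fields listed, which is what makes the save/reload workflow in the first file of this diff work:

import pickle

class ToyCaptured:
  def __init__(self, expected_names, expected_input_info):
    self.expected_names, self.expected_input_info = expected_names, expected_input_info
  def __reduce__(self):
    # on load, pickle calls ToyCaptured(*args) instead of trying to snapshot internal state
    return self.__class__, (self.expected_names, self.expected_input_info)

blob = pickle.dumps(ToyCaptured(["x"], [("view", (), "float32", "CPU")]))
print(pickle.loads(blob).expected_input_info)  # [('view', (), 'float32', 'CPU')]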
@@ -233,17 +232,17 @@ def _prepare_jit_inputs(args, kwargs):
it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else []
tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)]
if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors)
-lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
-if any(lb.base.op is Ops.CONST for lb in lbs):
+input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors])
+if any(u.base.op is Ops.CONST for u in input_uops):
raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()")
-input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb]
-for lb in lbs if lb.base.realized is not None])
+input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b:=u.base.realized, MultiBuffer) else [b]
+for u in input_uops if u.base.realized is not None])
if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT")
-st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs]
-_var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
+inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops]
+_var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
var_vals = {k.expr:v for k,v in _var_vals.items()}
-st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
-return input_buffers, var_vals, names, st_vars_dtype_device
+expected_input_info = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in inputs]
+return input_buffers, var_vals, names, expected_input_info
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False):
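A hypothetical illustration (all names invented) of what the per-input signature built by _prepare_jit_inputs buys: each input is keyed by everything except the concrete buffer contents, so a later call with a different view, dtype, device, or set of bound variables is rejected instead of being silently replayed with stale kernels:

from typing import NamedTuple

class InputInfo(NamedTuple):
  view: str          # stands in for the view UOp with the buffer substituted out
  variables: tuple   # sorted bound Variables
  dtype: str
  device: str

captured = [InputInfo("(4,4)", (), "float", "CPU")]
call     = [InputInfo("(4,4)", (), "half",  "CPU")]
if captured != call: print("args mismatch in JIT")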
@@ -284,7 +283,7 @@ class TinyJit(Generic[ReturnType]):
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
def __call__(self, *args, **kwargs) -> ReturnType:
-input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
+input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs)
if not JIT or self.cnt == 0:
# jit ignore
assert self.fxn is not None
@@ -342,14 +341,14 @@ class TinyJit(Generic[ReturnType]):
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
# set this for next run
-self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
+self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
if self.optimize: self.captured.replan_buffers_memory_layout()
elif self.cnt >= 2:
# jit exec
assert self.captured is not None
if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}")
-if self.captured.expected_st_vars_dtype_device != st_vars_dtype_device:
-raise JitError(f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}")
+if self.captured.expected_input_info != expected_input_info:
+raise JitError(f"args mismatch in JIT: {self.captured.expected_input_info=} != {expected_input_info=}")
ret = self.captured(input_buffers, var_vals)
self.cnt += 1
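A minimal usage sketch of the call-count behaviour referenced above, using the public tinygrad API: the first call runs eagerly, the next one captures, and later calls replay the capture, with names and expected_input_info checked on every replay:

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).realize()

for _ in range(4):               # cnt 0: run normally, cnt 1: capture, cnt >= 2: replay the capture
  out = step(Tensor.randn(4, 4))
# a different shape changes the per-input info, so this should raise JitError("args mismatch in JIT ...")
# step(Tensor.randn(8, 8))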

View File

@@ -8,13 +8,13 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException
class CUDAGraph(MultiGraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
-super().__init__(jit_cache, input_rawbuffers, var_vals)
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+super().__init__(jit_cache, input_buffers, var_vals)
# Check all jit items are compatible.
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
-self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
+self.jc_idx_with_updatable_bufs = dedup([x[0] for x in self.input_replace.keys()])
self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
@@ -31,29 +31,29 @@ class CUDAGraph(MultiGraphRunner):
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
-if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs:
+if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_bufs:
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
src_dev = cast(CUDADevice, Device[src.device])
node_from = cuda.CUgraphNode()
-deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from)
+deps = self._access_resources(bufs=[dest.base, src.base], write=[0], new_dependency=node_from)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
WidthInBytes=dest.nbytes, Height=1, Depth=1)
check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context))
-if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
+if j in self.jc_idx_with_updatable_bufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True)
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
-# Update rawbuffers in the c_args struct.
+def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
+# Update buffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():
-if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
+if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_buffers[input_idx]._buf)
else:
-if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf
-elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf
+if i == 0: self.updatable_nodes[j][1].destDevice = input_buffers[input_idx]._buf
+elif i == 1: self.updatable_nodes[j][1].srcDevice = input_buffers[input_idx]._buf
# Update var_vals in the c_args struct.
for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v)
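To make the replay-time patching concrete, a toy version in plain Python (no CUDA) of how input_replace drives the updates in __call__ above: each (node j, argument slot i) gets the device pointer of the caller's buffer at input_idx before the graph is launched again:

node_args = {0: {"f0": None, "f1": None}, 1: {"f0": None}}  # pretend per-node kernel argument structs
input_replace = {(0, 1): 0, (1, 0): 1}                      # (node, slot) -> position in input_buffers
input_buffers = [0xAAAA, 0xBBBB]                            # stand-ins for ._buf device pointers

for (j, i), input_idx in input_replace.items():
  node_args[j][f"f{i}"] = input_buffers[input_idx]
print(node_args)  # {0: {'f0': None, 'f1': 43690}, 1: {'f0': 48059}}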

View File

@@ -9,8 +9,8 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
-super().__init__(jit_cache, input_rawbuffers, var_vals)
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+super().__init__(jit_cache, input_buffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
# CPU Device is always last
@@ -189,10 +189,10 @@ class HCQGraph(MultiGraphRunner):
def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev]
-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
-# Map input rawbuffers
+def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
+# Map input buffers
for dev in self.devices:
-for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf)
+for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_buffers[idx_to_map]._buf)
# Wait and restore signals
self.kickoff_value += 1
@@ -203,9 +203,9 @@ class HCQGraph(MultiGraphRunner):
**{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()},
**{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}}
-# Update rawbuffers
+# Update buffers
for (j,i),input_idx in self.input_replace.items():
-hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr
+hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_buffers[input_idx]._buf.va_addr
for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {}))
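For comparison with the CUDA path, a toy sketch (illustrative only) of the HCQ approach above: instead of poking kernel-parameter structs, each updatable (j, i) slot is given a variable at capture time, and the buffer's device address is bound to that variable on every call:

input_replace_to_var = {(0, 1): "buf_0_1", (2, 0): "buf_2_0"}  # pretend variables made at capture time
input_replace = {(0, 1): 0, (2, 0): 1}
input_buffer_addrs = [0x1000, 0x2000]                          # stand-ins for ._buf.va_addr

hcq_var_vals = {}
for (j, i), input_idx in input_replace.items():
  hcq_var_vals[input_replace_to_var[(j, i)]] = input_buffer_addrs[input_idx]
print(hcq_var_vals)  # {'buf_0_1': 4096, 'buf_2_0': 8192}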

View File

@@ -10,8 +10,8 @@ from tinygrad.runtime.autogen import metal
from tinygrad.runtime.support import objc
class MetalGraph(GraphRunner):
-def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]):
-super().__init__(jit_cache, input_rawbuffers, var_vals)
+def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]):
+super().__init__(jit_cache, input_buffers, var_vals)
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
# create metal batch exec
@@ -39,7 +39,7 @@ class MetalGraph(GraphRunner):
all_pipelines.append(prg._prg.pipeline_state)
icb_command.setComputePipelineState(prg._prg.pipeline_state)
for i,b in enumerate(ji.bufs):
-if b is not None and b not in input_rawbuffers:
+if b is not None and b not in input_buffers:
icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
@@ -55,15 +55,15 @@ class MetalGraph(GraphRunner):
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
self.range = metal.NSRange(0, len(jit_cache))
-def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
+def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
# NOTE: old command buffer may not be inflight anymore
if self.command_buffer is not None and PROFILE: self.collect_timestamps()
-all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
+all_resources = dedup(self.all_resources + [input_buffers[input_idx]._buf.buf for input_idx in self.input_replace.values()])
for (j,i),input_idx in self.input_replace.items():
computeCommand = self.icb.indirectComputeCommandAtIndex(j)
-computeCommand.setKernelBuffer_offset_atIndex(input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i)
+computeCommand.setKernelBuffer_offset_atIndex(input_buffers[input_idx]._buf.buf, input_buffers[input_idx]._buf.offset, i)
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
computeCommand = self.icb.indirectComputeCommandAtIndex(j)

View File

@@ -28,7 +28,7 @@ class NullAllocator(Allocator['NullDevice']):
def _offset(self, buf, offset:int, size:int): pass
class NullGraph(MultiGraphRunner):
-def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-1
+def __call__(self, input_buffers, var_vals, wait=False) -> float|None: return 1e-1
class NullDevice(Compiled):
def __init__(self, device:str):