From 87f4bc544638e143442e3a7d71c9ec72476250f0 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 6 Jan 2026 22:32:41 -0500 Subject: [PATCH] update variable names around jit [pr] (#14049) lbs, st_vars_dtype_device and rawbuffers no more --- examples/test_pkl_imagenet.py | 2 +- tinygrad/engine/jit.py | 59 ++++++++++++++++----------------- tinygrad/runtime/graph/cuda.py | 22 ++++++------ tinygrad/runtime/graph/hcq.py | 14 ++++---- tinygrad/runtime/graph/metal.py | 12 +++---- tinygrad/runtime/ops_null.py | 2 +- 6 files changed, 55 insertions(+), 56 deletions(-) diff --git a/examples/test_pkl_imagenet.py b/examples/test_pkl_imagenet.py index 8110abf309..f714307cc2 100644 --- a/examples/test_pkl_imagenet.py +++ b/examples/test_pkl_imagenet.py @@ -7,7 +7,7 @@ if __name__ == "__main__": with open(fetch(sys.argv[1]), "rb") as f: run_onnx_jit = pickle.load(f) input_name = run_onnx_jit.captured.expected_names[0] - device = run_onnx_jit.captured.expected_st_vars_dtype_device[0][-1] + device = run_onnx_jit.captured.expected_input_info[0][-1] print(f"input goes into {input_name=} on {device=}") hit = 0 for i,(img,y) in enumerate(imagenet_dataloader(cnt=getenv("CNT", 100))): diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 252117e0d9..1d5d9edb73 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -24,7 +24,7 @@ def _check_no_non_tensor_return(ret): def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph -def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]: +def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], max_batch_size=0) -> list[ExecItem]: # Split JIT cache into batches for faster graph execution. # This allows the accelerator to run some batches while subsequent graphs are still being updated. graphed_jit_cache: list[ExecItem] = [] @@ -36,10 +36,10 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer] try: if len(current_batch_devs) == 0: raise GraphException("no device for graph") if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph") - graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals) + graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals) # clear jit inputs to allow their memory to be freed/reused for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None - graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_rawbuffers), prg=graph_runner)) + graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner)) max_batch_size *= 2 if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}") except GraphException as e: @@ -72,18 +72,18 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer] if len(current_batch) > 0: flush_batch() return graphed_jit_cache -def get_input_replace(jit_cache: list[ExecItem], input_rawbuffers:list[Buffer]) -> dict[tuple[int, int], int]: +def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer]) -> dict[tuple[int, int], int]: input_replace: dict[tuple[int, int], int] = {} for j,ji in enumerate(jit_cache): for i,a in enumerate(ji.bufs): - if a in input_rawbuffers: - input_replace[(j,i)] = input_rawbuffers.index(a) + if a in input_buffers: + input_replace[(j,i)] = input_buffers.index(a) return input_replace class GraphRunner(Runner): - def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]): + def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]): self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph - self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers) + self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers) self.var_vals_replace:dict[int, list[tuple[int, int]]] = {} self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {} self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {} @@ -125,19 +125,19 @@ class GraphRunner(Runner): for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1]) - def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependency:Any): + def _access_resources(self, bufs:list[Buffer], write:list[int], new_dependency:Any): # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource, # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers. wait_nodes = [] - for i,rawbuf in enumerate(rawbufs): - if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)]) + for i,buf in enumerate(bufs): + if id(buf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(buf.base._buf)]) if i in write: - if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf))) + if id(buf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(buf.base._buf))) - for i,rawbuf in enumerate(rawbufs): - if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency - else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency) + for i,buf in enumerate(bufs): + if i in write: self.w_dependency_map[id(buf.base._buf)] = new_dependency + else: self.r_dependency_map[id(buf.base._buf)].append(new_dependency) return list({id(x):x for x in wait_nodes}.values()) @@ -168,12 +168,11 @@ class CapturedJit(Generic[ReturnType]): input_replace: dict[tuple[int, int], int] extra_view_inputs: list[tuple[int, int, str, int, DType]] expected_names: list[int|str] - expected_st_vars_dtype_device: list[tuple[UOp, tuple[Variable, ...], DType, str]] + expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input def __reduce__(self): # TODO: free_intermediates here? replan_buffers_memory_layout here? - return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, - self.expected_names, self.expected_st_vars_dtype_device) + return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info) def __post_init__(self): self._jit_cache: list[ExecItem] = self.jit_cache @@ -233,17 +232,17 @@ def _prepare_jit_inputs(args, kwargs): it = x if isinstance(x, (tuple,list)) else x.values() if isinstance(x, dict) else [] tensors += [t for t in it if t.__class__ is Tensor and not any(t is y for y in tensors)] if len(unrealized_tensors := [x for x in tensors if not x.uop.is_realized]): Tensor.realize(*unrealized_tensors) - lbs: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors]) - if any(lb.base.op is Ops.CONST for lb in lbs): + input_uops: list[UOp] = flatten([t.uop.src if t.uop.op is Ops.MULTI else [t.uop] for t in tensors]) + if any(u.base.op is Ops.CONST for u in input_uops): raise JitError("JIT inputs cannot be const, create a buffer with .contiguous()") - input_buffers: list[Buffer] = flatten([rb.bufs if isinstance(rb:=lb.base.realized, MultiBuffer) else [rb] - for lb in lbs if lb.base.realized is not None]) + input_buffers: list[Buffer] = flatten([b.bufs if isinstance(b:=u.base.realized, MultiBuffer) else [b] + for u in input_uops if u.base.realized is not None]) if len(set(input_buffers)) != len(input_buffers): raise JitError("duplicate inputs to JIT") - st_varval_dtype_device = [(*(lb.substitute({lb.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), lb.dtype, lb.device) for lb in lbs] - _var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))]) + inputs = [(*(u.substitute({u.base:UOp(Ops.NOOP)}, extra_pm=mop_cleanup).unbind_all()), u.dtype, u.device) for u in input_uops] + _var_vals = merge_dicts([x[1] for x in inputs] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))]) var_vals = {k.expr:v for k,v in _var_vals.items()} - st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device] - return input_buffers, var_vals, names, st_vars_dtype_device + expected_input_info = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in inputs] + return input_buffers, var_vals, names, expected_input_info class TinyJit(Generic[ReturnType]): def __init__(self, fxn:Callable[..., ReturnType]|None, captured:CapturedJit|None=None, prune=False, optimize=False): @@ -284,7 +283,7 @@ class TinyJit(Generic[ReturnType]): def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods def __call__(self, *args, **kwargs) -> ReturnType: - input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs) + input_buffers, var_vals, names, expected_input_info = _prepare_jit_inputs(args, kwargs) if not JIT or self.cnt == 0: # jit ignore assert self.fxn is not None @@ -342,14 +341,14 @@ class TinyJit(Generic[ReturnType]): if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found") # set this for next run - self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device) + self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info) if self.optimize: self.captured.replan_buffers_memory_layout() elif self.cnt >= 2: # jit exec assert self.captured is not None if self.captured.expected_names != names: raise JitError(f"args mismatch in JIT: {self.captured.expected_names=} != {names}") - if self.captured.expected_st_vars_dtype_device != st_vars_dtype_device: - raise JitError(f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}") + if self.captured.expected_input_info != expected_input_info: + raise JitError(f"args mismatch in JIT: {self.captured.expected_input_info=} != {expected_input_info=}") ret = self.captured(input_buffers, var_vals) self.cnt += 1 diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 056b0b6766..e6bef5aee8 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -8,13 +8,13 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner from tinygrad.engine.jit import MultiGraphRunner, GraphException class CUDAGraph(MultiGraphRunner): - def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]): - super().__init__(jit_cache, input_rawbuffers, var_vals) + def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]): + super().__init__(jit_cache, input_buffers, var_vals) # Check all jit items are compatible. if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException - self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()]) + self.jc_idx_with_updatable_bufs = dedup([x[0] for x in self.input_replace.keys()]) self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy) self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) @@ -31,29 +31,29 @@ class CUDAGraph(MultiGraphRunner): kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, None, vargs) check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params))) - if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_rawbufs: + if j in self.launch_dims_replace or j in self.var_vals_replace or j in self.jc_idx_with_updatable_bufs: self.updatable_nodes[j] = (new_node, kern_params, c_args, False) elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] src_dev = cast(CUDADevice, Device[src.device]) node_from = cuda.CUgraphNode() - deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from) + deps = self._access_resources(bufs=[dest.base, src.base], write=[0], new_dependency=node_from) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, WidthInBytes=dest.nbytes, Height=1, Depth=1) check(cuda.cuGraphAddMemcpyNode(ctypes.byref(node_from), self.graph, c_deps, len(deps), ctypes.byref(cp_params), src_dev.context)) - if j in self.jc_idx_with_updatable_rawbufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True) + if j in self.jc_idx_with_updatable_bufs: self.updatable_nodes[j] = (node_from, cp_params, src_dev.context, True) self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0))) - def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: - # Update rawbuffers in the c_args struct. + def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: + # Update buffers in the c_args struct. for (j,i),input_idx in self.input_replace.items(): - if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf) + if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_buffers[input_idx]._buf) else: - if i == 0: self.updatable_nodes[j][1].destDevice = input_rawbuffers[input_idx]._buf - elif i == 1: self.updatable_nodes[j][1].srcDevice = input_rawbuffers[input_idx]._buf + if i == 0: self.updatable_nodes[j][1].destDevice = input_buffers[input_idx]._buf + elif i == 1: self.updatable_nodes[j][1].srcDevice = input_buffers[input_idx]._buf # Update var_vals in the c_args struct. for j, i, v in self.updated_vars(var_vals): setattr(self.updatable_nodes[j][2], f'v{i}', v) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 9220f4a9c8..2cfa438392 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -9,8 +9,8 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer from tinygrad.engine.jit import MultiGraphRunner class HCQGraph(MultiGraphRunner): - def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]): - super().__init__(jit_cache, input_rawbuffers, var_vals) + def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]): + super().__init__(jit_cache, input_buffers, var_vals) self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) # CPU Device is always last @@ -189,10 +189,10 @@ class HCQGraph(MultiGraphRunner): def _dev_copy_queues(self, dev): return [q for (d, _), q in self.copy_queues.items() if d == dev] - def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: - # Map input rawbuffers + def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: + # Map input buffers for dev in self.devices: - for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_rawbuffers[idx_to_map]._buf) + for idx_to_map in self.input_replace_map[dev]: cast(HCQAllocator, dev.allocator).map(input_buffers[idx_to_map]._buf) # Wait and restore signals self.kickoff_value += 1 @@ -203,9 +203,9 @@ class HCQGraph(MultiGraphRunner): **{var.expr: dev.timeline_value - 1 for dev, var in self.virt_timeline_vals.items()}, **{sig.base_buf.va_addr.expr: dev.timeline_signal.base_buf.va_addr for dev, sig in self.virt_timeline_signals.items()}} - # Update rawbuffers + # Update buffers for (j,i),input_idx in self.input_replace.items(): - hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_rawbuffers[input_idx]._buf.va_addr + hcq_var_vals[self.input_replace_to_var[(j,i)].expr] = input_buffers[input_idx]._buf.va_addr for dev in self.devices: self.comp_queues[dev].submit(dev, hcq_var_vals_local:=hcq_var_vals|self.fixedvars.get(dev, {})) diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py index 928a7bea4c..e7e767b3ef 100644 --- a/tinygrad/runtime/graph/metal.py +++ b/tinygrad/runtime/graph/metal.py @@ -10,8 +10,8 @@ from tinygrad.runtime.autogen import metal from tinygrad.runtime.support import objc class MetalGraph(GraphRunner): - def __init__(self, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[str, int]): - super().__init__(jit_cache, input_rawbuffers, var_vals) + def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int]): + super().__init__(jit_cache, input_buffers, var_vals) if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException # create metal batch exec @@ -39,7 +39,7 @@ class MetalGraph(GraphRunner): all_pipelines.append(prg._prg.pipeline_state) icb_command.setComputePipelineState(prg._prg.pipeline_state) for i,b in enumerate(ji.bufs): - if b is not None and b not in input_rawbuffers: + if b is not None and b not in input_buffers: icb_command.setKernelBuffer_offset_atIndex(b._buf.buf, b._buf.offset, i) all_resources.append(b._buf.buf) for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i) @@ -55,15 +55,15 @@ class MetalGraph(GraphRunner): for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var] self.range = metal.NSRange(0, len(jit_cache)) - def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: + def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None: if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer) # NOTE: old command buffer may not be inflight anymore if self.command_buffer is not None and PROFILE: self.collect_timestamps() - all_resources = dedup(self.all_resources + [input_rawbuffers[input_idx]._buf.buf for input_idx in self.input_replace.values()]) + all_resources = dedup(self.all_resources + [input_buffers[input_idx]._buf.buf for input_idx in self.input_replace.values()]) for (j,i),input_idx in self.input_replace.items(): computeCommand = self.icb.indirectComputeCommandAtIndex(j) - computeCommand.setKernelBuffer_offset_atIndex(input_rawbuffers[input_idx]._buf.buf, input_rawbuffers[input_idx]._buf.offset, i) + computeCommand.setKernelBuffer_offset_atIndex(input_buffers[input_idx]._buf.buf, input_buffers[input_idx]._buf.offset, i) for j, global_dims, local_dims in self.updated_launch_dims(var_vals): computeCommand = self.icb.indirectComputeCommandAtIndex(j) diff --git a/tinygrad/runtime/ops_null.py b/tinygrad/runtime/ops_null.py index d77773654c..0fb11ad572 100644 --- a/tinygrad/runtime/ops_null.py +++ b/tinygrad/runtime/ops_null.py @@ -28,7 +28,7 @@ class NullAllocator(Allocator['NullDevice']): def _offset(self, buf, offset:int, size:int): pass class NullGraph(MultiGraphRunner): - def __call__(self, input_rawbuffers, var_vals, wait=False) -> float|None: return 1e-1 + def __call__(self, input_buffers, var_vals, wait=False) -> float|None: return 1e-1 class NullDevice(Compiled): def __init__(self, device:str):