jit: pass init params straight to base (#15496)

* jit: pass init params straight to base

* linter
This commit is contained in:
nimlgen
2026-03-26 21:59:10 +08:00
committed by GitHub
parent ec5b7a249e
commit de24b3fe37
3 changed files with 23 additions and 26 deletions

View File

@@ -5,31 +5,30 @@ from tinygrad.helpers import dedup
from tinygrad.runtime.support.c import init_c_var
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.realize import BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException
class CUDAGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
orig_valid_positions: dict[int, set[int]]|None = None):
super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check all jit items are compatible.
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in self.jit_cache): raise GraphException
self.jc_idx_with_updatable_bufs = dedup([x[0] for x in self.input_replace.keys()])
self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
self.graph = init_c_var(cuda.CUgraph, lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
for j,ji in enumerate(jit_cache):
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.p.launch_dims(var_vals)
global_size, local_size = ji.prg.p.launch_dims({v: 0 for v in self.vars})
new_node = cuda.CUgraphNode()
deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [ji.fixedvars.get(x.expr, 0) for x in ji.prg.p.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, ctypes.cast(0, ctypes.POINTER(ctypes.c_void_p)),
vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

View File

@@ -9,16 +9,15 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
orig_valid_positions: dict[int, set[int]]|None = None):
super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.devices = list(set(cast(HCQCompiled, d) for ji in self.jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
# CPU Device is always last
self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)
# Replace input buffers with variables.
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in self.jit_cache]
self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
for (j,i), input_idx in self.input_replace.items():
@@ -27,7 +26,7 @@ class HCQGraph(MultiGraphRunner):
# Allocate kernel args.
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
for ji in jit_cache:
for ji in self.jit_cache:
if not isinstance(ji.prg, CompiledRunner): continue
kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
@@ -36,7 +35,7 @@ class HCQGraph(MultiGraphRunner):
self.ji_args: dict[int, HCQArgsState] = {}
kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
for j,ji in enumerate(jit_cache):
for j,ji in enumerate(self.jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
@@ -73,7 +72,7 @@ class HCQGraph(MultiGraphRunner):
self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
self.device_vars: dict[HCQCompiled, dict[str, int]] = {}
for j,ji in enumerate(jit_cache):
for j,ji in enumerate(self.jit_cache):
if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
else:
# For copy ops prioritize enqueuing on the dest device, so reverse the buffers.
@@ -138,7 +137,7 @@ class HCQGraph(MultiGraphRunner):
last_j[enqueue_queue] = j
# Check which signals are used in the profile graph.
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]
self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.jit_cache) * 2)]
# Build hardware queues.
self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
@@ -152,7 +151,7 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
.wait(self.signals['KICK'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
for j,ji in enumerate(jit_cache):
for j,ji in enumerate(self.jit_cache):
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
# Lazy allocate signals

View File

@@ -10,9 +10,8 @@ from tinygrad.runtime.autogen import metal
from tinygrad.runtime.support import objc
class MetalGraph(GraphRunner):
def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
orig_valid_positions: dict[int, set[int]]|None = None):
super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# create metal batch exec
icb_descriptor = metal.MTLIndirectCommandBufferDescriptor.new()
@@ -21,19 +20,19 @@ class MetalGraph(GraphRunner):
icb_descriptor.setInheritPipelineState(False)
icb_descriptor.setMaxKernelBufferBindCount(31)
self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(jit_cache),
self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(self.jit_cache),
metal.MTLResourceCPUCacheModeDefaultCache)
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
# TODO: needs categories
icb_label = bytes(objc.msg("UTF8String", ctypes.c_char_p)(objc.msg("description")(self.icb).retained())).decode()
self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+
self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
self.fixedvars = merge_dicts([ji.fixedvars for ji in self.jit_cache])
self.varlist = self.vars + list(self.fixedvars.keys())
if len(self.varlist): self.int_buf = self.dev.allocator.alloc(len(self.varlist)*dtypes.int32.itemsize)
all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
for j,ji in enumerate(jit_cache):
for j,ji in enumerate(self.jit_cache):
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
icb_command = self.icb.indirectComputeCommandAtIndex(j).retained()
all_pipelines.append(prg._prg.pipeline_state)
@@ -44,7 +43,7 @@ class MetalGraph(GraphRunner):
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals)
global_size, local_size = prg.p.launch_dims({v: 0 for v in self.vars})
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
icb_command.setBarrier()
@@ -53,7 +52,7 @@ class MetalGraph(GraphRunner):
self.command_buffer: Any = None
if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
self.range = metal.NSRange(0, len(jit_cache))
self.range = metal.NSRange(0, len(self.jit_cache))
def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)