diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index e6e9afd2b4..48292175de 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -5,31 +5,30 @@ from tinygrad.helpers import dedup
 from tinygrad.runtime.support.c import init_c_var
 from tinygrad.device import Buffer, Device
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
-from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
+from tinygrad.engine.realize import BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException
 
 class CUDAGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
-               orig_valid_positions: dict[int, set[int]]|None = None):
-    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
 
     # Check all jit items are compatible.
-    if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException
+    if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in self.jit_cache): raise GraphException
 
     self.jc_idx_with_updatable_bufs = dedup([x[0] for x in self.input_replace.keys()])
     self.updatable_nodes: dict[int, tuple[Any, Any, Any, bool]] = {} # dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)
 
     self.graph = init_c_var(cuda.CUgraph, lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
 
-    for j,ji in enumerate(jit_cache):
+    for j,ji in enumerate(self.jit_cache):
       if isinstance(ji.prg, CompiledRunner):
-        global_size, local_size = ji.prg.p.launch_dims(var_vals)
+        global_size, local_size = ji.prg.p.launch_dims({v: 0 for v in self.vars})
 
         new_node = cuda.CUgraphNode()
         deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node)
         c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
 
-        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals.get(x.expr, ji.fixedvars.get(x.expr)) for x in ji.prg.p.vars])
+        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [ji.fixedvars.get(x.expr, 0) for x in ji.prg.p.vars])
         kern_params = cuda.CUDA_KERNEL_NODE_PARAMS_v1(ji.prg._prg.prg, *global_size, *local_size, 0, ctypes.cast(0, ctypes.POINTER(ctypes.c_void_p)), vargs)
         check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
 
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index 5d88111424..ae91f3ce30 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -9,16 +9,15 @@ from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, Buffer
 from tinygrad.engine.jit import MultiGraphRunner
 
 class HCQGraph(MultiGraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
-               orig_valid_positions: dict[int, set[int]]|None = None):
-    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
-    self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self.devices = list(set(cast(HCQCompiled, d) for ji in self.jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
 
     # CPU Device is always last
     self.devices = sorted(self.devices, key=lambda x: 1 if x._is_cpu() else 0)
 
     # Replace input buffers with variables.
-    self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
+    self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in self.jit_cache]
     self.input_replace_to_var: dict[tuple[int, int], Variable] = {}
 
     for (j,i), input_idx in self.input_replace.items():
@@ -27,7 +26,7 @@ class HCQGraph(MultiGraphRunner):
 
     # Allocate kernel args.
     kernargs_size: dict[Compiled, int] = collections.defaultdict(int)
-    for ji in jit_cache:
+    for ji in self.jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
       kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
     self.kernargs_bufs: dict[Compiled, HCQBuffer] = {d:d.allocator._alloc(max(sz, 1), BufferSpec(cpu_access=True)) for d,sz in kernargs_size.items()}
@@ -36,7 +35,7 @@
 
     self.ji_args: dict[int, HCQArgsState] = {}
     kargs_alloc: dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size) for dev,buf in self.kernargs_bufs.items()}
-    for j,ji in enumerate(jit_cache):
+    for j,ji in enumerate(self.jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
       argsbuf = self.kernargs_bufs[ji.prg.dev].offset(kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
 
@@ -73,7 +72,7 @@
     self.input_replace_map: dict[HCQCompiled, set[int]] = collections.defaultdict(set)
     self.device_vars: dict[HCQCompiled, dict[str, int]] = {}
 
-    for j,ji in enumerate(jit_cache):
+    for j,ji in enumerate(self.jit_cache):
       if is_exec_prg:=isinstance(ji.prg, CompiledRunner): enqueue_dev: HCQCompiled = ji.prg.dev
       else: # For copy ops prioritize enqeueuing on the dest device, so reverse the buffers.
@@ -138,7 +137,7 @@
       last_j[enqueue_queue] = j
 
     # Check which signals are used in the profile graph.
-    self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(jit_cache) * 2)]
+    self.prof_signal_is_used = [any(ent.st_id == j or ent.en_id == j for ent in self.prof_graph_entries) for j in range(len(self.jit_cache) * 2)]
 
     # Build hardware queues.
     self.copy_to_devs: dict[HCQCompiled, set[HCQCompiled]] = {dev: set() for dev in self.devices}
@@ -152,7 +151,7 @@
       self.comp_queues[dev].memory_barrier().wait(self.virt_timeline_signals[dev], self.virt_timeline_vals[dev]) \
                            .wait(self.signals['KICK'], self.kickoff_var).signal(self.signals[dev], self.kickoff_var)
 
-    for j,ji in enumerate(jit_cache):
+    for j,ji in enumerate(self.jit_cache):
       enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
 
       # Lazy allocate signals
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index 23ada58668..815338b4ad 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -10,9 +10,8 @@ from tinygrad.runtime.autogen import metal
 from tinygrad.runtime.support import objc
 
 class MetalGraph(GraphRunner):
-  def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
-               orig_valid_positions: dict[int, set[int]]|None = None):
-    super().__init__(jit_cache, input_buffers, var_vals, orig_valid_positions)
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
 
     # create metal batch exec
     icb_descriptor = metal.MTLIndirectCommandBufferDescriptor.new()
@@ -21,19 +20,19 @@ class MetalGraph(GraphRunner):
     icb_descriptor.setInheritPipelineState(False)
     icb_descriptor.setMaxKernelBufferBindCount(31)
 
-    self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(jit_cache),
+    self.icb = self.dev.sysdevice.newIndirectCommandBufferWithDescriptor_maxCommandCount_options(icb_descriptor, len(self.jit_cache),
       metal.MTLResourceCPUCacheModeDefaultCache)
     if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
     # TODO: needs categories
     icb_label = bytes(objc.msg("UTF8String", ctypes.c_char_p)(objc.msg("description")(self.icb).retained())).decode()
     self.needs_icb_fix = int((m := re.search(r'AGXG(\d+)XFamily', icb_label)) is None or int(m.group(1)) < 15) # not required on M3+
 
-    self.fixedvars = merge_dicts([ji.fixedvars for ji in jit_cache])
+    self.fixedvars = merge_dicts([ji.fixedvars for ji in self.jit_cache])
     self.varlist = self.vars + list(self.fixedvars.keys())
     if len(self.varlist): self.int_buf = self.dev.allocator.alloc(len(self.varlist)*dtypes.int32.itemsize)
 
     all_pipelines, all_resources = [], [self.int_buf.buf] if len(self.varlist) else []
-    for j,ji in enumerate(jit_cache):
+    for j,ji in enumerate(self.jit_cache):
       prg: CompiledRunner = cast(CompiledRunner, ji.prg)
       icb_command = self.icb.indirectComputeCommandAtIndex(j).retained()
       all_pipelines.append(prg._prg.pipeline_state)
@@ -44,7 +43,7 @@
          all_resources.append(b._buf.buf)
       for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex(self.int_buf.buf, self.varlist.index(v.expr)*4, len(ji.bufs)+i)
 
-      global_size, local_size = prg.p.launch_dims(var_vals)
+      global_size, local_size = prg.p.launch_dims({v: 0 for v in self.vars})
       icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup(metal.MTLSize(*global_size), metal.MTLSize(*local_size))
       icb_command.setBarrier()
 
@@ -53,7 +52,7 @@
     self.command_buffer: Any = None
     if len(self.varlist): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
     for var in self.fixedvars: self.int_buf_view[self.varlist.index(var)] = self.fixedvars[var]
-    self.range = metal.NSRange(0, len(jit_cache))
+    self.range = metal.NSRange(0, len(self.jit_cache))
 
   def __call__(self, input_buffers: list[Buffer], var_vals: dict[str, int], wait=False) -> float|None:
     if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
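
Note (not part of the patch): every hunk above makes the same change. The explicit __init__(jit_cache, input_buffers, var_vals, orig_valid_positions) signatures are replaced with plain *args/**kwargs forwarding, and the constructor bodies now read self.jit_cache, self.vars, self.input_replace, and per-item ji.fixedvars instead of closing over the constructor arguments (launch dims are resolved with {v: 0 for v in self.vars} rather than var_vals). A minimal sketch of that pattern follows; the attribute names mirror how the new code uses them, while the class names and the way self.vars is derived are illustrative assumptions, not tinygrad's actual MultiGraphRunner implementation.

# Hypothetical sketch of the *args/**kwargs forwarding pattern used in this patch.
class BaseGraphRunnerSketch:
  def __init__(self, jit_cache, input_buffers, var_vals, orig_valid_positions=None):
    self.jit_cache = jit_cache       # subclasses iterate self.jit_cache, not the argument
    self.vars = sorted(var_vals)     # assumption: base collects the symbolic variable names
    self.input_replace = {}          # assumption: (jit index, buffer index) -> input position

class GraphRunnerSketch(BaseGraphRunnerSketch):
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)            # forward everything unchanged to the base class
    for j, ji in enumerate(self.jit_cache):      # read state back through self.* instead of locals
      pass                                       # per-item node/queue setup would go here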