diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index c797654eac..15f1611f9c 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -71,10 +71,11 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
 
 class GraphRunner(Runner):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-    self.jit_cache = jit_cache
+    self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
     self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
     self.var_vals_replace:Dict[int, List[int]] = {}
     self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
+    self.launch_dims_base:Dict[int, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}
 
     op_estimate: sint = 0
     mem_estimate: sint = 0
@@ -95,13 +96,16 @@ class GraphRunner(Runner):
       if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]
 
       global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
-      if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+      if global_dim_idx is not None or local_dim_idx is not None:
+        self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+        assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
+        self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))
 
     # used in MultiGraphRunner. the ints are id() of _bufs
     self.w_dependency_map: Dict[int, Any] = {}
     self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
 
-    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
+    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
                      ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
 
   def updated_vars(self, var_vals: Dict[Variable, int]):
@@ -111,7 +115,8 @@ class GraphRunner(Runner):
 
   def updated_launch_dims(self, var_vals: Dict[Variable, int]):
     dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
-    for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
+    for j, (gl, lc) in self.launch_dims_replace.items():
+      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
 
   def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any):
     # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index b5cb96b4cd..c0c54ae167 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -20,7 +20,7 @@ class CUDAGraph(MultiGraphRunner):
 
     self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
 
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       if isinstance(ji.prg, CompiledRunner):
         global_size, local_size = ji.prg.p.launch_dims(var_vals)
 
@@ -61,9 +61,8 @@ class CUDAGraph(MultiGraphRunner):
 
     # Update launch dims in the kern_params struct.
     for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      prg = cast(CompiledRunner, self.jit_cache[j].prg)
-      node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
-      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
+      node = self.updatable_nodes[j][1]
+      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims # type: ignore[misc]
 
     # Update graph nodes with the updated structs.
     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index f727d8b466..4dc3b75e33 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -14,7 +14,7 @@ class HCQGraph(MultiGraphRunner):
 
     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
-    for ji in self.jit_cache:
+    for ji in jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
       kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
     self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
@@ -23,7 +23,7 @@ class HCQGraph(MultiGraphRunner):
     self.ji_args: Dict[int, HCQArgsState] = {}
 
     kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
       kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg._prg.kernargs_alloc_size, 16)
       self.ji_args[j] = ji.prg._prg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
@@ -41,7 +41,7 @@ class HCQGraph(MultiGraphRunner):
     self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
     self.kickoff_value: int = 0
 
-    self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
+    self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
     self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
 
     last_j: Dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None)
@@ -50,7 +50,7 @@ class HCQGraph(MultiGraphRunner):
 
     for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
 
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
      enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
      enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
      out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
@@ -107,7 +107,7 @@ class HCQGraph(MultiGraphRunner):
       self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
                            .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)
 
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
       for i in range(len(sync_signals)):
        self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index c0a1779715..6016399b1e 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -29,7 +29,7 @@ class MetalGraph(GraphRunner):
     msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)
 
     self.icb = msg(self.dev.sysdevice, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
-                   icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
+                   icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
     if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
     icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
     self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
@@ -37,7 +37,7 @@ class MetalGraph(GraphRunner):
     if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
     all_resources = [self.int_buf.buf] if len(self.vars) else []
     all_pipelines = []
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
      icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
      all_pipelines.append(prg._prg.pipeline_state)
@@ -56,7 +56,7 @@ class MetalGraph(GraphRunner):
     self.all_pipelines = dedup(all_pipelines)
     self.command_buffer: Any = None
     if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
-    self.range = to_struct(0, len(self.jit_cache))
+    self.range = to_struct(0, len(jit_cache))
 
   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
 
@@ -69,11 +69,8 @@ class MetalGraph(GraphRunner):
           input_rawbuffers[input_idx]._buf.offset, i)
 
     for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      prg = cast(CompiledRunner, self.jit_cache[j].prg)
-      global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
       computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
-      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
-          to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
+      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_dims), to_struct(*local_dims))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
 
     command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
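
For reference, the net effect of the `GraphRunner.updated_launch_dims` change is that backends now always receive concrete launch dims: the freshly inferred symbolic dims where an index was recorded in `launch_dims_replace`, otherwise the dims captured at graph-capture time in `launch_dims_base`, which is why the per-backend `global_dims or prg.p.global_size` fallbacks can be dropped. Below is a minimal standalone sketch of that fallback pattern; it reuses the `launch_dims_replace`/`launch_dims_base` names from the patch but strips the class, the `sym_infer` step, and tinygrad's real types, so it is an illustration under simplified assumptions rather than the library's API.

```python
from typing import Dict, Iterator, Optional, Tuple

# Simplified stand-in for GraphRunner's launch-dim bookkeeping (illustrative only).
# launch_dims_replace[j]: per-kernel indices into the freshly inferred symbolic dims,
#   with None on whichever axis set (global/local) is fully static.
# launch_dims_base[j]: the static dims captured when the graph was built.
launch_dims_replace: Dict[int, Tuple[Optional[int], Optional[int]]] = {0: (0, None)}
launch_dims_base: Dict[int, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {0: ((1, 1, 1), (32, 1, 1))}

def updated_launch_dims(dims) -> Iterator[Tuple[int, Tuple[int, ...], Tuple[int, ...]]]:
  # After the patch, callers always get concrete tuples: inferred dims where an
  # index exists, the captured base dims otherwise (previously this yielded None).
  for j, (gl, lc) in launch_dims_replace.items():
    yield j, (dims[gl] if gl is not None else launch_dims_base[j][0]), \
             (dims[lc] if lc is not None else launch_dims_base[j][1])

# One symbolic global dim inferred to (16, 1, 1); local dims fall back to (32, 1, 1).
print(list(updated_launch_dims([(16, 1, 1)])))  # -> [(0, (16, 1, 1), (32, 1, 1))]
```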