remove jit_cache from self in GraphRunner [pr] (#7817)

* remove jit_cache from self in GraphRunner [pr]

* add back unused
George Hotz authored 2024-11-21 13:26:37 +08:00, committed by GitHub
parent e9ae2ccd09, commit df6f1815ad
4 changed files with 21 additions and 20 deletions

@@ -71,10 +71,11 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
class GraphRunner(Runner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
- self.jit_cache = jit_cache
+ self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
self.var_vals_replace:Dict[int, List[int]] = {}
self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
self.launch_dims_base:Dict[int, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}
op_estimate: sint = 0
mem_estimate: sint = 0
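
For context on the constructor above (unchanged by this diff): get_input_replace builds the mapping stored in input_replace, recording for every (kernel index, buffer slot) whose buffer is one of the JIT inputs which input to splice in at call time. A minimal sketch of that mapping, written against plain lists instead of tinygrad's ExecItem/Buffer types:

from typing import Dict, List, Tuple

def get_input_replace_sketch(jit_bufs: List[List[object]], input_rawbuffers: List[object]) -> Dict[Tuple[int, int], int]:
  # jit_bufs[j] stands in for jit_cache[j].bufs; membership mirrors the Buffer comparison in the real helper
  input_replace: Dict[Tuple[int, int], int] = {}
  for j, bufs in enumerate(jit_bufs):
    for i, b in enumerate(bufs):
      if b in input_rawbuffers: input_replace[(j, i)] = input_rawbuffers.index(b)
  return input_replace
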
@@ -95,13 +96,16 @@ class GraphRunner(Runner):
if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]
global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
- if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+ if global_dim_idx is not None or local_dim_idx is not None:
+   self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+   assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
+   self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))
# used in MultiGraphRunner. the ints are id() of _bufs
self.w_dependency_map: Dict[int, Any] = {}
self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
- super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
+ super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
def updated_vars(self, var_vals: Dict[Variable, int]):
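
The re-added assignment at the top of __init__ is only a keep-alive: the captured device graph ends up holding raw pointers into buffers owned by the ExecItems, so dropping the Python reference could let those owners be collected while the graph still points at their memory. A tiny illustration of the failure mode this guards against, using made-up stand-in classes rather than tinygrad's:

import ctypes

class FakeBuffer:
  def __init__(self, size: int): self._buf = (ctypes.c_uint8 * size)()  # storage is freed when this object dies

class FakeGraph:
  def __init__(self, exec_items):
    # a graph capture records only raw addresses into the buffers...
    self.addrs = [ctypes.addressof(it._buf) for it in exec_items]
    # ...so something must keep the owning Python objects alive, which is what self.jit_cache does above
    self.exec_items = exec_items
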
@@ -111,7 +115,8 @@ class GraphRunner(Runner):
def updated_launch_dims(self, var_vals: Dict[Variable, int]):
dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
- for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
+ for j, (gl, lc) in self.launch_dims_replace.items():
+   yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])
def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any):
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
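
Taken together with launch_dims_base above, updated_launch_dims now always yields complete (global, local) tuples: inferred from var_vals when a dimension is symbolic, otherwise the concrete dims recorded at build time. A standalone sketch of that contract (simplified types, no tinygrad imports):

from typing import Dict, Iterator, List, Optional, Tuple

def updated_launch_dims_sketch(inferred_dims: List[Tuple[int, ...]],
                               launch_dims_replace: Dict[int, Tuple[Optional[int], Optional[int]]],
                               launch_dims_base: Dict[int, Tuple[Tuple[int, ...], Tuple[int, ...]]]) -> Iterator[Tuple[int, Tuple[int, ...], Tuple[int, ...]]]:
  for j, (gl, lc) in launch_dims_replace.items():
    # fall back to the recorded base dims instead of yielding None, so callers never need 'or prg.p.global_size'
    yield j, (inferred_dims[gl] if gl is not None else launch_dims_base[j][0]), \
             (inferred_dims[lc] if lc is not None else launch_dims_base[j][1])
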

@@ -20,7 +20,7 @@ class CUDAGraph(MultiGraphRunner):
self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
- for j,ji in enumerate(self.jit_cache):
+ for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, CompiledRunner):
global_size, local_size = ji.prg.p.launch_dims(var_vals)
@@ -61,9 +61,8 @@ class CUDAGraph(MultiGraphRunner):
# Update launch dims in the kern_params struct.
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
- prg = cast(CompiledRunner, self.jit_cache[j].prg)
- node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
- node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
+ node = self.updatable_nodes[j][1]
+ node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims # type: ignore[misc]
# Update graph nodes with the updated structs.
for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
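
Because the generator now substitutes the base dims itself, the CUDA update loop no longer needs the prg.p.global_size / prg.p.local_size fallback and can unpack whatever it is handed straight into the node params. A hedged stand-in (a plain Python object instead of the real ctypes kernel-node struct):

class FakeNodeParams:  # stand-in for the CUDA kernel-node params struct
  blockDimX = blockDimY = blockDimZ = gridDimX = gridDimY = gridDimZ = 1

def update_cuda_launch_dims(nodes, updates):
  # updates yields (j, global_dims, local_dims); both are always full tuples after this PR
  for j, global_dims, local_dims in updates:
    node = nodes[j]
    node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims
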

@@ -14,7 +14,7 @@ class HCQGraph(MultiGraphRunner):
# Allocate kernel args.
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
- for ji in self.jit_cache:
+ for ji in jit_cache:
if not isinstance(ji.prg, CompiledRunner): continue
kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
@@ -23,7 +23,7 @@ class HCQGraph(MultiGraphRunner):
self.ji_args: Dict[int, HCQArgsState] = {}
kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
- for j,ji in enumerate(self.jit_cache):
+ for j,ji in enumerate(jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg._prg.kernargs_alloc_size, 16)
self.ji_args[j] = ji.prg._prg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
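
For reference, the kernarg bookkeeping around these two loops packs every kernel's arguments into a single buffer per device, bumping a pointer by the 16-byte rounded size of each kernel's args. A self-contained sketch of that packing arithmetic (device names and sizes are made up):

import collections
from typing import Dict, List, Tuple

def round_up(x: int, a: int) -> int: return (x + a - 1) // a * a

kernels: List[Tuple[str, int]] = [("NV:0", 40), ("NV:0", 24), ("NV:1", 100)]  # (device, kernarg size) per kernel

kernargs_size: Dict[str, int] = collections.defaultdict(int)
for dev, sz in kernels: kernargs_size[dev] += round_up(sz, 16)  # total allocation per device

offsets, ptrs = [], {dev: 0 for dev in kernargs_size}
for dev, sz in kernels:
  offsets.append((dev, ptrs[dev]))   # where this kernel's args live inside its device buffer
  ptrs[dev] += round_up(sz, 16)
assert all(ptrs[dev] == kernargs_size[dev] for dev in ptrs)
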
@@ -41,7 +41,7 @@ class HCQGraph(MultiGraphRunner):
self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
self.kickoff_value: int = 0
- self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
+ self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []
last_j: Dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None)
@@ -50,7 +50,7 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
- for j,ji in enumerate(self.jit_cache):
+ for j,ji in enumerate(jit_cache):
enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
@@ -107,7 +107,7 @@ class HCQGraph(MultiGraphRunner):
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)
- for j,ji in enumerate(self.jit_cache):
+ for j,ji in enumerate(jit_cache):
enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]
for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)

@@ -29,7 +29,7 @@ class MetalGraph(GraphRunner):
msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)
self.icb = msg(self.dev.sysdevice, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
- icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
+ icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
@@ -37,7 +37,7 @@ class MetalGraph(GraphRunner):
if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
all_resources = [self.int_buf.buf] if len(self.vars) else []
all_pipelines = []
- for j,ji in enumerate(self.jit_cache):
+ for j,ji in enumerate(jit_cache):
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
all_pipelines.append(prg._prg.pipeline_state)
@@ -56,7 +56,7 @@ class MetalGraph(GraphRunner):
self.all_pipelines = dedup(all_pipelines)
self.command_buffer: Any = None
if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
- self.range = to_struct(0, len(self.jit_cache))
+ self.range = to_struct(0, len(jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
@@ -69,11 +69,8 @@ class MetalGraph(GraphRunner):
input_rawbuffers[input_idx]._buf.offset, i)
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
- prg = cast(CompiledRunner, self.jit_cache[j].prg)
- global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
- msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
-   to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
+ msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_dims), to_struct(*local_dims))
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
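
As on the CUDA side, the Metal loop now passes the yielded dims straight through to the indirect command; the remaining per-call variable state is the write of var_vals into the int32 view, as the unchanged loop above does. A standalone illustration of that cast-view write, using a plain bytearray instead of the MTLBuffer-backed view and hypothetical variable names:

var_vals = {"start_pos": 17, "seq_len": 128, "bs": 4}   # made-up symbolic variables and values
buf = bytearray(4 * len(var_vals))                      # stands in for the int_buf allocation
int_buf_view = memoryview(buf).cast('i')                # same cast('i') idea as _as_buffer(self.int_buf).cast('i')
for j, var in enumerate(var_vals): int_buf_view[j] = var_vals[var]
assert list(int_buf_view) == [17, 128, 4]
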