mirror of https://github.com/tinygrad/tinygrad.git
remove jit_cache from self in GraphRunner [pr] (#7817)
* remove jit_cache from self in GraphRunner [pr]
* add back unused
@@ -71,10 +71,11 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])

 class GraphRunner(Runner):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-    self.jit_cache = jit_cache
+    self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
     self.input_replace:Dict[Tuple[int, int], int] = get_input_replace(jit_cache, input_rawbuffers)
     self.var_vals_replace:Dict[int, List[int]] = {}
     self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
+    self.launch_dims_base:Dict[int, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}

     op_estimate: sint = 0
     mem_estimate: sint = 0
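A minimal sketch of the keep-alive pattern this change settles on (hypothetical class, not tinygrad code): the list is stored but never read back; pinning it on self keeps the captured programs and buffers from being garbage-collected while the graph can still be launched.

    # hypothetical illustration of the keep-alive pattern, not tinygrad code
    class GraphSketch:
      def __init__(self, jit_cache: list):
        self._jit_cache = jit_cache   # never read again; only pins the objects
        self.count = len(jit_cache)   # everything else derives from the argument

    g = GraphSketch([object(), object()])
    assert g.count == 2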
@@ -95,13 +96,16 @@ class GraphRunner(Runner):
       if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]

       global_dim_idx, local_dim_idx = find_symbolic_dim(ji.prg.p.global_size), find_symbolic_dim(ji.prg.p.local_size)
-      if global_dim_idx is not None or local_dim_idx is not None: self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+      if global_dim_idx is not None or local_dim_idx is not None:
+        self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
+        assert ji.prg.p.global_size is not None and ji.prg.p.local_size is not None
+        self.launch_dims_base[j] = (tuple(ji.prg.p.global_size), tuple(ji.prg.p.local_size))

     # used in MultiGraphRunner. the ints are id() of _bufs
     self.w_dependency_map: Dict[int, Any] = {}
     self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)

-    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
+    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
                      ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))

   def updated_vars(self, var_vals: Dict[Variable, int]):
@@ -111,7 +115,8 @@ class GraphRunner(Runner):

   def updated_launch_dims(self, var_vals: Dict[Variable, int]):
     dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims]
-    for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None)
+    for j, (gl, lc) in self.launch_dims_replace.items():
+      yield j, (dims[gl] if gl is not None else self.launch_dims_base[j][0]), (dims[lc] if lc is not None else self.launch_dims_base[j][1])

   def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any):
     # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
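A standalone sketch of the new updated_launch_dims contract (hypothetical names; resolved_dims plays the role of the sym_infer results): a non-symbolic dimension now falls back to the base dims recorded at capture time, so callers always receive concrete tuples instead of None.

    # sketch only: mirrors the new fallback logic, not tinygrad's actual method
    from typing import Dict, List, Optional, Tuple

    def updated_launch_dims_sketch(replace: Dict[int, Tuple[Optional[int], Optional[int]]],
                                   base: Dict[int, Tuple[Tuple[int, ...], Tuple[int, ...]]],
                                   resolved_dims: List[Tuple[int, ...]]):
      for j, (gl, lc) in replace.items():
        # every yielded dim is now a concrete tuple, never None
        yield j, (resolved_dims[gl] if gl is not None else base[j][0]), \
                 (resolved_dims[lc] if lc is not None else base[j][1])

    # kernel 0: symbolic global size (index 0 into resolved_dims), static local size
    out = list(updated_launch_dims_sketch({0: (0, None)}, {0: ((8, 1, 1), (32, 1, 1))}, [(128, 1, 1)]))
    assert out == [(0, (128, 1, 1), (32, 1, 1))]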
@@ -20,7 +20,7 @@ class CUDAGraph(MultiGraphRunner):

     self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))

-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       if isinstance(ji.prg, CompiledRunner):
         global_size, local_size = ji.prg.p.launch_dims(var_vals)
@@ -61,9 +61,8 @@ class CUDAGraph(MultiGraphRunner):

     # Update launch dims in the kern_params struct.
     for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      prg = cast(CompiledRunner, self.jit_cache[j].prg)
-      node, global_size, local_size = self.updatable_nodes[j][1], global_dims or prg.p.global_size, local_dims or prg.p.local_size
-      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_size, *global_size # type: ignore[misc]
+      node = self.updatable_nodes[j][1]
+      node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims # type: ignore[misc]

     # Update graph nodes with the updated structs.
     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
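Dropping the `global_dims or prg.p.global_size` fallback relies on the contract sketched above. A minimal sketch of the direct splat into the node struct (hypothetical node object, not the CUDA driver type):

    # hypothetical stand-in for the CUDA kernel-node params struct
    class FakeNode:
      blockDimX = blockDimY = blockDimZ = gridDimX = gridDimY = gridDimZ = 1

    node, global_dims, local_dims = FakeNode(), (64, 1, 1), (32, 1, 1)
    # dims arrive as concrete tuples, so they unpack straight into the six fields
    node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = *local_dims, *global_dims
    assert (node.gridDimX, node.blockDimX) == (64, 32)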
@@ -14,7 +14,7 @@ class HCQGraph(MultiGraphRunner):

     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
-    for ji in self.jit_cache:
+    for ji in jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
       kernargs_size[ji.prg.dev] += round_up(ji.prg._prg.kernargs_alloc_size, 16)
     self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferSpec(cpu_access=True)) for dev,sz in kernargs_size.items()}
@@ -23,7 +23,7 @@ class HCQGraph(MultiGraphRunner):
     self.ji_args: Dict[int, HCQArgsState] = {}

     kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
       kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg._prg.kernargs_alloc_size, 16)
       self.ji_args[j] = ji.prg._prg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
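The two loops above implement a bump allocation: sizes are summed per device with 16-byte alignment, one buffer is allocated, then the same walk hands each ExecItem its slice. A standalone sketch of the arithmetic (hypothetical sizes; round_up mirrors tinygrad's helper):

    def round_up(n: int, align: int) -> int: return (n + align - 1) // align * align

    sizes = [100, 40, 16]                         # kernargs_alloc_size per kernel
    total = sum(round_up(s, 16) for s in sizes)   # size of the one big buffer
    ptr, offsets = 0, []
    for s in sizes:
      offsets.append(ptr)                         # each kernel's aligned slice
      ptr += round_up(s, 16)
    assert (total, offsets) == (176, [0, 112, 160])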
@@ -41,7 +41,7 @@ class HCQGraph(MultiGraphRunner):
     self.signals: Dict[Any, HCQSignal] = {**{dev: dev.signal_t(value=0) for dev in self.devices}, **{"CPU": self.devices[0].signal_t(value=0)}}
     self.kickoff_value: int = 0

-    self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(self.jit_cache) * 2)] if PROFILE else []
+    self.prof_signals: List[HCQSignal] = [self.devices[0].signal_t() for i in range(len(jit_cache) * 2)] if PROFILE else []
     self.prof_records: List[Tuple[Tuple[int, bool], Tuple[int, bool], HCQCompiled, str, bool, List[int], Optional[Dict]]] = []

     last_j: Dict[HWQueue, Optional[int]] = collections.defaultdict(lambda: None)
@@ -50,7 +50,7 @@ class HCQGraph(MultiGraphRunner):

     for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
       enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
       out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
@@ -107,7 +107,7 @@ class HCQGraph(MultiGraphRunner):
       self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
                           .wait(self.signals['CPU'], self.kickoff_value).signal(self.signals[dev], self.kickoff_value)

-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       enqueue_dev, enqueue_queue, sync_signals, deps, signal, signal_val = self.ji_schedule[j]

       for i in range(len(sync_signals)): self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) + i)
@@ -29,7 +29,7 @@ class MetalGraph(GraphRunner):
     msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)

     self.icb = msg(self.dev.sysdevice, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
-                   icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
+                   icb_descriptor, len(jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
     if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
     icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
     self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
@@ -37,7 +37,7 @@ class MetalGraph(GraphRunner):
     if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
     all_resources = [self.int_buf.buf] if len(self.vars) else []
     all_pipelines = []
-    for j,ji in enumerate(self.jit_cache):
+    for j,ji in enumerate(jit_cache):
       prg: CompiledRunner = cast(CompiledRunner, ji.prg)
       icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
       all_pipelines.append(prg._prg.pipeline_state)
@@ -56,7 +56,7 @@ class MetalGraph(GraphRunner):
     self.all_pipelines = dedup(all_pipelines)
     self.command_buffer: Any = None
     if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
-    self.range = to_struct(0, len(self.jit_cache))
+    self.range = to_struct(0, len(jit_cache))

   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
@@ -69,11 +69,8 @@ class MetalGraph(GraphRunner):
           input_rawbuffers[input_idx]._buf.offset, i)

     for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
-      prg = cast(CompiledRunner, self.jit_cache[j].prg)
-      global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
       computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
-      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
-          to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
+      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_dims), to_struct(*local_dims))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

     command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
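Same simplification as the CUDA backend: with concrete dims guaranteed, the Metal path splats them straight into the dispatch-size structs. A sketch of the packing (hypothetical Dims3/to_struct_sketch stand-ins, not tinygrad's to_struct):

    import ctypes

    # hypothetical three-uint struct, roughly what the Metal dispatch consumes
    class Dims3(ctypes.Structure):
      _fields_ = [("x", ctypes.c_uint32), ("y", ctypes.c_uint32), ("z", ctypes.c_uint32)]

    def to_struct_sketch(*dims: int) -> Dims3: return Dims3(*dims)

    global_dims, local_dims = (64, 1, 1), (32, 1, 1)  # always concrete tuples now
    tg, tptg = to_struct_sketch(*global_dims), to_struct_sketch(*local_dims)
    assert (tg.x, tptg.x) == (64, 32)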