From aeeb917b6ea7ee7f1e3672e72acce0ac4643be08 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:13:50 +0300 Subject: [PATCH] mask out writable bufs in runtime access_resources (#7234) --- extra/backends/hsa_graph.py | 10 +++++----- test/test_graph.py | 3 +-- tinygrad/engine/jit.py | 12 ++++++------ tinygrad/renderer/__init__.py | 3 --- tinygrad/runtime/graph/cuda.py | 5 ++--- tinygrad/runtime/graph/hcq.py | 2 +- 6 files changed, 15 insertions(+), 20 deletions(-) diff --git a/extra/backends/hsa_graph.py b/extra/backends/hsa_graph.py index 22476e9ac0..3603bb9421 100644 --- a/extra/backends/hsa_graph.py +++ b/extra/backends/hsa_graph.py @@ -70,7 +70,7 @@ class HSAGraph(MultiGraphRunner): for j,ji in enumerate(self.jit_cache): if isinstance(ji.prg, CompiledRunner): - wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False) + wait_signals = self.access_resources(ji.bufs, ji.prg.p.outs, new_dependency=j, sync_with_aql_packets=False) for i in range(0, len(wait_signals), 5): self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5]) self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr) @@ -84,7 +84,7 @@ class HSAGraph(MultiGraphRunner): dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device]) sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev]) - wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True) + wait_signals = self.access_resources([dest, src], write=[0], new_dependency=sync_signal, sync_with_aql_packets=True) self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals), (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True]) self.ji_to_transfer[j] = len(self.transfers) - 1 @@ -164,8 +164,8 @@ class HSAGraph(MultiGraphRunner): return packet.completion_signal return None - def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False): - rdeps = self._access_resources(read, write, new_dependency) + def access_resources(self, rawbufs, write, new_dependency, sync_with_aql_packets=False): + rdeps = self._access_resources(rawbufs, write, new_dependency) wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps] - if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write] + if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in rawbufs] return dedup_signals(wait_signals) diff --git a/test/test_graph.py b/test/test_graph.py index 64e553c0ef..7be34ba0f1 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -54,8 +54,7 @@ def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT): out_buffers = set() for graph in graphs: for ji in graph: - writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1 - out_buffers.update(ji.bufs[:writable_buffers]) + out_buffers.update([ji.bufs[i] for i in (ji.prg.p.outs if isinstance(ji.prg, CompiledRunner) else [0])]) bufs += ji.bufs reg_ji.append(ji) bufs = dedup(bufs) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 528ff7f88d..49a4c093ef 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -113,18 +113,18 @@ class GraphRunner(Runner): # pylint: disable=abstract-method dims = [tuple(sym_infer(s, var_vals) for s in dim) for dim in self.symbolic_dims] for j, (gl, lc) in self.launch_dims_replace.items(): yield j, (dims[gl] if gl is not None else None), (dims[lc] if lc is not None else None) - def _access_resources(self, read:List[Buffer], write:List[Buffer], new_dependency:Any): + def _access_resources(self, rawbufs:List[Buffer], write:List[int], new_dependency:Any): # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource, # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers. wait_nodes = [] - for rawbuf in read + write: + for i,rawbuf in enumerate(rawbufs): if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)]) - for rawbuf in write: - if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf))) + if i in write: + if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf))) + self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency + else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency) - for rawbuf in read: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency) - for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency return list({id(x):x for x in wait_nodes}.values()) # a marker for your graph supporting multiple devices of the same type diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 5c1f976f5e..a109167c8b 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -62,9 +62,6 @@ class Program: @functools.cached_property def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True) - @property - def outcount(self) -> int: return len(self.outs) - @functools.cached_property def function_name(self) -> str: return to_function_name(self.name) diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 03d14325d9..42d1973ddc 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -25,8 +25,7 @@ class CUDAGraph(MultiGraphRunner): global_size, local_size = ji.prg.p.launch_dims(var_vals) new_node = cuda.CUgraphNode() - deps = self._access_resources([x.base for x in ji.bufs[ji.prg.p.outcount:] if x is not None], - [x.base for x in ji.bufs[:ji.prg.p.outcount] if x is not None], new_dependency=new_node) + deps = self._access_resources([x.base for x in ji.bufs if x is not None], ji.prg.p.outs, new_dependency=new_node) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.p.vars]) @@ -39,7 +38,7 @@ class CUDAGraph(MultiGraphRunner): dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] src_dev = cast(CUDADevice, Device[src.device]) node_from = cuda.CUgraphNode() - deps = self._access_resources(read=[src.base], write=[dest.base], new_dependency=node_from) + deps = self._access_resources(rawbufs=[dest.base, src.base], write=[0], new_dependency=node_from) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1, dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1, diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 61f27072c7..5322a40fdf 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -56,7 +56,7 @@ class HCQGraph(MultiGraphRunner): out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0)) # Get dependencies based on input and output buffers. - rdeps = self._access_resources(ji.bufs[(wb:=ji.prg.p.outcount if is_exec_prg else 1):], ji.bufs[:wb], (enqueue_queue, j + 1)) #type:ignore + rdeps = self._access_resources(ji.bufs, ji.prg.p.outs if is_exec_prg else [0], (enqueue_queue, j + 1)) #type:ignore # Update dependencies to include previous kernel in queue. This is required for timeline signals. opt_deps, deps = [], rdeps + ([(enqueue_queue, prev_ji + 1)] if (prev_ji:=last_j[enqueue_queue]) is not None else [])