From 3b7e3fa2e44482bae233a05ad4cb5e3240c8cb81 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 2 Mar 2024 18:37:51 +0300 Subject: [PATCH] fix sync in hsa graph (#3582) --- tinygrad/runtime/graph/hsa.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index e6189b3e80..b97b2e43e7 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -1,4 +1,4 @@ -import ctypes, collections, time +import ctypes, collections, time, itertools from typing import List, Any, Dict, cast, Optional, Union from tinygrad.helpers import GraphException, init_c_var from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats @@ -69,8 +69,8 @@ class HSAGraph(MultiDeviceJITGraph): self.packets = {} self.transfers = [] self.signals_to_reset: List[hsa.hsa_signal_t] = [] - self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {} - self.r_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {} + self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {} + self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list) signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {} # Special packet to wait for the world. 
@@ -98,7 +98,7 @@ class HSAGraph(MultiDeviceJITGraph): # Wait for all active signals to finish the graph wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list) - for v in dedup_signals([s for s in list(self.w_dependency_map.values())+list(self.r_dependency_map.values()) if isinstance(s, hsa.hsa_signal_t)]): + for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))): for dev in signals_to_devices[v.handle]: wait_signals_to_finish[dev].append(v) @@ -163,13 +163,26 @@ class HSAGraph(MultiDeviceJITGraph): return None def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False): - wait_signals = [] + # To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource, + # whether for write or read operations. A resource can be accessed by either a single writer or multiple readers. + # The tracked dependencies are either hsa signals or ints that reference a specific aql packet. 
+ wait_signals: List[Optional[hsa.hsa_signal_t]] = [] + + if sync_with_aql_packets: wait_signals += [self.kickoff_signals[rawbuf.d] for rawbuf in read+write] for rawbuf in read: wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets)) - if new_dependency: self.r_dependency_map[rawbuf._buf] = new_dependency for rawbuf in write: wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets)) - wait_signals.append(self.dependency_as_signal(self.r_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets)) - if new_dependency: self.w_dependency_map[rawbuf._buf] = new_dependency - if sync_with_aql_packets: wait_signals += [self.kickoff_signals[rawbuf.d] for rawbuf in read+write] + if rawbuf._buf in self.r_dependency_map: + rdeps = self.r_dependency_map.pop(rawbuf._buf) + + # When synchronizing to aql packets, we only need to sync to the latest one, as they are executed in order. Signal dependencies are always kept, even when there are no aql packet dependencies. + signal_deps, aql_deps = [x for x in rdeps if isinstance(x, hsa.hsa_signal_t)], [x for x in rdeps if isinstance(x, int)] + deps = signal_deps + ([max(aql_deps)] if aql_deps else []) + for dep in deps: wait_signals.append(self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets)) + + if new_dependency is not None: + for rawbuf in read: self.r_dependency_map[rawbuf._buf].append(new_dependency) + for rawbuf in write: self.w_dependency_map[rawbuf._buf] = new_dependency + return dedup_signals(wait_signals)