fix sync in hsa graph (#3582)

This commit is contained in:
nimlgen
2024-03-02 18:37:51 +03:00
committed by GitHub
parent 6c36264790
commit 3b7e3fa2e4

View File

@@ -1,4 +1,4 @@
import ctypes, collections, time
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Union
from tinygrad.helpers import GraphException, init_c_var
from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
@@ -69,8 +69,8 @@ class HSAGraph(MultiDeviceJITGraph):
self.packets = {}
self.transfers = []
self.signals_to_reset: List[hsa.hsa_signal_t] = []
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
self.r_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
# Special packet to wait for the world.
@@ -98,7 +98,7 @@ class HSAGraph(MultiDeviceJITGraph):
# Wait for all active signals to finish the graph
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
for v in dedup_signals([s for s in list(self.w_dependency_map.values())+list(self.r_dependency_map.values()) if isinstance(s, hsa.hsa_signal_t)]):
for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
for dev in signals_to_devices[v.handle]:
wait_signals_to_finish[dev].append(v)
@@ -163,13 +163,26 @@ class HSAGraph(MultiDeviceJITGraph):
return None
def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):
wait_signals = []
# To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
# whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
# The tracked dependencies are either hsa signals or ints that reference a specific aql packet.
wait_signals: List[Optional[hsa.hsa_signal_t]] = []
if sync_with_aql_packets: wait_signals += [self.kickoff_signals[rawbuf.d] for rawbuf in read+write]
for rawbuf in read:
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
if new_dependency: self.r_dependency_map[rawbuf._buf] = new_dependency
for rawbuf in write:
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
wait_signals.append(self.dependency_as_signal(self.r_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
if new_dependency: self.w_dependency_map[rawbuf._buf] = new_dependency
if sync_with_aql_packets: wait_signals += [self.kickoff_signals[rawbuf.d] for rawbuf in read+write]
if rawbuf._buf in self.r_dependency_map:
rdeps = self.r_dependency_map.pop(rawbuf._buf)
# When synchronizing to aql packets, we only need to sync to the latest one, as they are executed in order.
signal_deps, aql_deps = [x for x in rdeps if isinstance(x, hsa.hsa_signal_t)], [x for x in rdeps if isinstance(x, int)]
deps = signal_deps + [max(aql_deps)] if aql_deps else []
for dep in deps: wait_signals.append(self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets))
if new_dependency is not None:
for rawbuf in read: self.r_dependency_map[rawbuf._buf].append(new_dependency)
for rawbuf in write: self.w_dependency_map[rawbuf._buf] = new_dependency
return dedup_signals(wait_signals)