mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
fix sync in hsa graph (#3582)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import ctypes, collections, time
|
||||
import ctypes, collections, time, itertools
|
||||
from typing import List, Any, Dict, cast, Optional, Union
|
||||
from tinygrad.helpers import GraphException, init_c_var
|
||||
from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
|
||||
@@ -69,8 +69,8 @@ class HSAGraph(MultiDeviceJITGraph):
|
||||
self.packets = {}
|
||||
self.transfers = []
|
||||
self.signals_to_reset: List[hsa.hsa_signal_t] = []
|
||||
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
|
||||
self.r_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
|
||||
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
|
||||
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
|
||||
signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
|
||||
|
||||
# Special packet to wait for the world.
|
||||
@@ -98,7 +98,7 @@ class HSAGraph(MultiDeviceJITGraph):
|
||||
|
||||
# Wait for all active signals to finish the graph
|
||||
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
|
||||
for v in dedup_signals([s for s in list(self.w_dependency_map.values())+list(self.r_dependency_map.values()) if isinstance(s, hsa.hsa_signal_t)]):
|
||||
for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
|
||||
for dev in signals_to_devices[v.handle]:
|
||||
wait_signals_to_finish[dev].append(v)
|
||||
|
||||
@@ -163,13 +163,26 @@ class HSAGraph(MultiDeviceJITGraph):
|
||||
return None
|
||||
|
||||
def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):
  """Collect the signals to wait on before accessing `read`/`write` buffers, then record the new access.

  To synchronize access to resources, we monitor the necessary prerequisites for accessing each resource,
  whether for write or read operations. A resource can be accessed by either a single writer or multiple readers.
  The tracked dependencies are either hsa signals or ints that reference a specific aql packet.

  Args:
    read: buffers that will be read. Each one waits on the dependency stored for it in `self.w_dependency_map`.
    write: buffers that will be written. Each one waits on its `self.w_dependency_map` entry and on the
      dependencies accumulated for it in `self.r_dependency_map` (that entry is popped here).
    new_dependency: dependency (hsa signal or aql packet int) to record for this access. When it is not None,
      it is appended to the reader list of every `read` buffer and becomes the writer entry of every `write` buffer.
    sync_with_aql_packets: forwarded to `self.dependency_as_signal`. When truthy, `self.kickoff_signals[rawbuf.d]`
      of every buffer in `read + write` is added to the wait list as well.

  Returns:
    The result of `dedup_signals` applied to the collected wait signals.
  """
  wait_signals: List[Optional[hsa.hsa_signal_t]] = []

  if sync_with_aql_packets: wait_signals += [self.kickoff_signals[rawbuf.d] for rawbuf in read+write]

  # Readers only have to wait for the last writer of the buffer.
  for rawbuf in read:
    wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))

  # Writers have to wait for the last writer and for every reader recorded since.
  for rawbuf in write:
    wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
    if rawbuf._buf in self.r_dependency_map:
      rdeps = self.r_dependency_map.pop(rawbuf._buf)

      # When synchronizing to aql packets, we only need to sync to the latest one, as they are executed in order.
      signal_deps = [x for x in rdeps if isinstance(x, hsa.hsa_signal_t)]
      aql_deps = [x for x in rdeps if isinstance(x, int)]
      # The conditional must be parenthesised: without the parentheses it applies to the whole sum, and the
      # signal dependencies are discarded whenever there is no aql dependency.
      deps = signal_deps + ([max(aql_deps)] if aql_deps else [])
      for dep in deps: wait_signals.append(self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets))

  # `is not None` rather than truthiness, since the dependency may be an int.
  if new_dependency is not None:
    for rawbuf in read: self.r_dependency_map[rawbuf._buf].append(new_dependency)
    for rawbuf in write: self.w_dependency_map[rawbuf._buf] = new_dependency

  return dedup_signals(wait_signals)
|
||||
|
||||
Reference in New Issue
Block a user