mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-14 16:44:59 -05:00
update inputs for transfers in hsagraph (#3560)
This commit is contained in:
@@ -63,6 +63,7 @@ class HSAGraph(MultiDeviceJITGraph):
|
||||
self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
|
||||
self.packets = {}
|
||||
self.transfers = []
|
||||
self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
|
||||
self.signals_to_reset: List[hsa.hsa_signal_t] = []
|
||||
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
|
||||
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
|
||||
@@ -93,8 +94,9 @@ class HSAGraph(MultiDeviceJITGraph):
|
||||
signals_to_devices[sync_signal.handle] = [dest_dev, src_dev]
|
||||
|
||||
wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
|
||||
self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
|
||||
(hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
|
||||
self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
|
||||
(hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
|
||||
self.ji_to_transfer[j] = len(self.transfers)
|
||||
if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
|
||||
|
||||
# Wait for all active signals to finish the graph
|
||||
@@ -121,7 +123,11 @@ class HSAGraph(MultiDeviceJITGraph):
|
||||
|
||||
# Update rawbuffers
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
|
||||
if j in self.ji_kargs_structs:
|
||||
self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
|
||||
else:
|
||||
if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
|
||||
elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
|
||||
|
||||
# Update var_vals
|
||||
for j in self.jc_idxs_with_updatable_var_vals:
|
||||
|
||||
Reference in New Issue
Block a user