update inputs for transfers in hsagraph (#3560)

This commit is contained in:
nimlgen
2024-03-18 18:01:04 +03:00
committed by GitHub
parent 086291e8c6
commit e78df485c7

View File

@@ -63,6 +63,7 @@ class HSAGraph(MultiDeviceJITGraph):
self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
self.packets = {}
self.transfers = []
self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
self.signals_to_reset: List[hsa.hsa_signal_t] = []
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
@@ -93,8 +94,9 @@ class HSAGraph(MultiDeviceJITGraph):
signals_to_devices[sync_signal.handle] = [dest_dev, src_dev]
wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
(hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
# Each transfer is stored as a mutable list (not a tuple) so its dest/src buffer slots can be overwritten later.
self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
(hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
# Record the index of the entry just appended: after append() the list length is one past that index.
self.ji_to_transfer[j] = len(self.transfers) - 1
if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))
# Wait for all active signals to finish the graph
@@ -121,7 +123,11 @@ class HSAGraph(MultiDeviceJITGraph):
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items():
self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
if j in self.ji_kargs_structs:
self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
else:
if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src
# Update var_vals
for j in self.jc_idxs_with_updatable_var_vals: