diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index ca369456d5..252eb98b97 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -63,6 +63,7 @@ class HSAGraph(MultiDeviceJITGraph): self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices} self.packets = {} self.transfers = [] + self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table. self.signals_to_reset: List[hsa.hsa_signal_t] = [] self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {} self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list) @@ -93,8 +94,9 @@ class HSAGraph(MultiDeviceJITGraph): signals_to_devices[sync_signal.handle] = [dest_dev, src_dev] wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True) - self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals), - (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) + self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals), + (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True]) + self.ji_to_transfer[j] = len(self.transfers) if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True)) # Wait for all active signals to finish the graph @@ -121,7 +123,11 @@ class HSAGraph(MultiDeviceJITGraph): # Update rawbuffers for (j,i),input_idx in self.input_replace.items(): - self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf) + if j in self.ji_kargs_structs: + self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf) + else: + if i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest + elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src # Update var_vals for j in self.jc_idxs_with_updatable_var_vals: