From b05776ef3e746f9ca17735d55ac4919c27788097 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:43:55 +0300 Subject: [PATCH] fix addresses of dispatch packets (#3534) --- tinygrad/runtime/graph/hsa.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index 1924efadd5..9ec8ef4cb5 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -80,10 +80,10 @@ class HSAGraph(MultiDeviceJITGraph): for j,ji in enumerate(self.jit_cache): if isinstance(ji.prg, CompiledASTRunner): - self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr) - wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=self.packets[j], sync_with_aql_packets=False) + wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=j, sync_with_aql_packets=False) for i in range(0, len(wait_signals), 5): self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5]) + self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr) self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), ctypes.addressof(self.ji_kargs_structs[j])) #type:ignore elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] @@ -99,6 +99,7 @@ class HSAGraph(MultiDeviceJITGraph): # Make sure the src buffer can be by other devices. c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices]) check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, src._buf)) + check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, dest._buf)) # Wait for all active signals to finish the graph wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list) @@ -159,11 +160,11 @@ class HSAGraph(MultiDeviceJITGraph): def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]: if isinstance(dep, hsa.hsa_signal_t): return dep - elif sync_with_aql_packets and isinstance(dep, hsa.hsa_kernel_dispatch_packet_t): - if dep.completion_signal.handle == EMPTY_SIGNAL.handle: - dep.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) - self.signals_to_reset.append(dep.completion_signal) - return dep.completion_signal + elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t): + if packet.completion_signal.handle == EMPTY_SIGNAL.handle: + packet.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) + self.signals_to_reset.append(packet.completion_signal) + return packet.completion_signal return None def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):