From 2d54e4d747cab150e3dd93e7332861cd088363f3 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 20 Mar 2024 00:17:41 +0300 Subject: [PATCH] clean up hsa driver (#3818) * clean up driver * remove returns --- tinygrad/runtime/driver/hsa.py | 20 ++++++------------ tinygrad/runtime/graph/hsa.py | 38 +++++++++++++++++----------------- tinygrad/runtime/ops_hsa.py | 28 ++++++++++++------------- 3 files changed, 39 insertions(+), 47 deletions(-) diff --git a/tinygrad/runtime/driver/hsa.py b/tinygrad/runtime/driver/hsa.py index 9c61096d8e..3091e0c0e2 100644 --- a/tinygrad/runtime/driver/hsa.py +++ b/tinygrad/runtime/driver/hsa.py @@ -46,9 +46,8 @@ class AQLQueue: def __del__(self): if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue)) - def submit_kernel(self, prg, global_size, local_size, kernargs, need_signal=False): + def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None): if self.available_packet_slots == 0: self._wait_queue() - signal = self._alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr) packet.workgroup_size_x = local_size[0] @@ -63,17 +62,14 @@ class AQLQueue: packet.kernel_object = prg.handle packet.kernarg_address = kernargs packet.reserved2 = 0 - packet.completion_signal = signal + packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL packet.setup = DISPATCH_KERNEL_SETUP packet.header = DISPATCH_KERNEL_HEADER self._submit_packet() - return signal - - def submit_barrier(self, wait_signals=None, need_signal=False, completion_signal=None): + def submit_barrier(self, wait_signals=None, completion_signal=None): assert wait_signals is None or len(wait_signals) <= 5 if self.available_packet_slots == 0: self._wait_queue() - signal = (completion_signal or self._alloc_signal(reusable=True)) if need_signal else EMPTY_SIGNAL packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr) packet.reserved0 = 0 @@ -81,12 +77,10 @@ class AQLQueue: for i in range(5): packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL packet.reserved2 = 0 - packet.completion_signal = signal + packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL packet.header = BARRIER_HEADER self._submit_packet() - return signal - def blit_packets(self, packet_addr, packet_cnt): if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt) @@ -98,8 +92,8 @@ class AQLQueue: self._submit_packet(packet_cnt) def wait(self): - signal = self.submit_barrier(need_signal=True) - hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) + self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True)) + hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE def _wait_queue(self, need_packets=1): @@ -117,8 +111,6 @@ class AQLQueue: if self.write_addr > self.write_addr_end: self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size - def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable) - def scan_agents(): agents = collections.defaultdict(list) diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index fc2b39cb76..876767fe5c 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -23,7 +23,6 @@ class VirtAQLQueue(AQLQueue): self.write_addr += AQL_PACKET_SIZE self.packets_count += 1 self.available_packet_slots -= 1 - def _alloc_signal(self, reusable=False): return init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x)))) class HSAGraph(MultiDeviceJITGraph): def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): @@ -67,31 +66,28 @@ class HSAGraph(MultiDeviceJITGraph): self.signals_to_reset: List[hsa.hsa_signal_t] = [] self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {} self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list) - signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {} + self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {} self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list) # Special packet to wait for the world. - self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {} - for dev in self.devices: self.kickoff_signals[dev] = self.virt_aql_queues[dev].submit_barrier(need_signal=True) - self.signals_to_reset += list(self.kickoff_signals.values()) + self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices} + for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev]) for j,ji in enumerate(self.jit_cache): if isinstance(ji.prg, CompiledASTRunner): wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False) for i in range(0, len(wait_signals), 5): - self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5]) + self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5]) self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr) - sync_signal = self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore - ctypes.addressof(self.ji_kargs_structs[j]), need_signal=PROFILE) - if PROFILE: - self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False)) - self.signals_to_reset.append(sync_signal) + + sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None + self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore + ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal) + if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False)) elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d) - sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) - self.signals_to_reset.append(sync_signal) - signals_to_devices[sync_signal.handle] = [dest_dev, src_dev] + sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev]) wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True) self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals), @@ -102,14 +98,14 @@ class HSAGraph(MultiDeviceJITGraph): # Wait for all active signals to finish the graph wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list) for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))): - for dev in signals_to_devices[v.handle]: + for dev in self.signals_to_devices[v.handle]: wait_signals_to_finish[dev].append(v) self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) for dev in self.devices: wait_signals = wait_signals_to_finish[dev] for i in range(0, max(1, len(wait_signals)), 5): - self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], need_signal=(i+5>=len(wait_signals)), completion_signal=self.finish_signal) + self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None) # Zero signals to allow graph to start and execute. for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0) @@ -162,12 +158,16 @@ class HSAGraph(MultiDeviceJITGraph): jit=jit, num_kernels=len(self.jit_cache), device="HSA") return et + def alloc_signal(self, reset_on_start=False, wait_on=None): + sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) + if reset_on_start: self.signals_to_reset.append(sync_signal) + if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on + return sync_signal + def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]: if isinstance(dep, hsa.hsa_signal_t): return dep elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t): - if packet.completion_signal.handle == EMPTY_SIGNAL.handle: - packet.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) - self.signals_to_reset.append(packet.completion_signal) + if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True) return packet.completion_signal return None diff --git a/tinygrad/runtime/ops_hsa.py b/tinygrad/runtime/ops_hsa.py index dcef69a1ee..a8f3bdf355 100644 --- a/tinygrad/runtime/ops_hsa.py +++ b/tinygrad/runtime/ops_hsa.py @@ -89,7 +89,8 @@ class HSAProgram: for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i]) self.device.flush_hdp() - signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=(wait or PROFILE)) + signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None + self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal) if PROFILE: Profiler.track(signal, self.device, self.name) if wait: hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) @@ -122,18 +123,17 @@ class HSAAllocator(LRUAllocator): def copyin(self, dest:T, src: memoryview): # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets. - copy_signal = self.device.alloc_signal(reusable=True) - sync_signal = self.device.hw_queue.submit_barrier(need_signal=True) + self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True)) mem = self._alloc_with_options(src.nbytes, BufferOptions(host=True)) ctypes.memmove(mem, from_mv(src), src.nbytes) - check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, - 1, ctypes.byref(sync_signal), copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) - self.device.hw_queue.submit_barrier(wait_signals=[copy_signal]) + check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal), + copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True)) + self.device.hw_queue.submit_barrier([copy_signal]) self.device.delayed_free.append(mem) if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True) def copy_from_fd(self, dest, fd, offset, size): - sync_signal = self.device.hw_queue.submit_barrier(need_signal=True) + self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True)) if not hasattr(self, 'hb'): self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)] @@ -167,7 +167,7 @@ class HSAAllocator(LRUAllocator): wait_signals = [self.hb_signals[self.hb_polarity - 1]] if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity]) - self.device.hw_queue.submit_barrier(wait_signals=wait_signals) + self.device.hw_queue.submit_barrier(wait_signals) def copyout(self, dest:memoryview, src:T): HSADevice.synchronize_system() @@ -180,13 +180,13 @@ class HSAAllocator(LRUAllocator): if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True) def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None): - copy_signal = dest_dev.alloc_signal(reusable=False) - sync_signal_1 = src_dev.hw_queue.submit_barrier(need_signal=True) - sync_signal_2 = dest_dev.hw_queue.submit_barrier(need_signal=True) + src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True)) + dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True)) c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2) - check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) # noqa: E501 - src_dev.hw_queue.submit_barrier(wait_signals=[copy_signal]) - dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal]) + check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, + copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True)) + src_dev.hw_queue.submit_barrier([copy_signal]) + dest_dev.hw_queue.submit_barrier([copy_signal]) if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True) class HSADevice(Compiled):