From 76ade20b89698a6230e62b83a72ac544fb674f08 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 11 Mar 2024 22:32:43 +0300 Subject: [PATCH] hsa driver tiny cleanups (#3684) --- tinygrad/runtime/driver/hsa.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tinygrad/runtime/driver/hsa.py b/tinygrad/runtime/driver/hsa.py index 14a3d87489..8d89f8f6f9 100644 --- a/tinygrad/runtime/driver/hsa.py +++ b/tinygrad/runtime/driver/hsa.py @@ -25,7 +25,6 @@ BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE class AQLQueue: def __init__(self, device, sz=-1): self.device = device - self.wait_signals = [] check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32()))) queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value @@ -90,7 +89,7 @@ class AQLQueue: def blit_packets(self, packet_addr, packet_cnt): if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt) - tail_blit_packets = min(((self.write_addr_end + 1) - self.write_addr) // 64, packet_cnt) + tail_blit_packets = min(((self.write_addr_end + 1) - self.write_addr) // AQL_PACKET_SIZE, packet_cnt) rem_packet_cnt = packet_cnt - tail_blit_packets ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets) self.write_addr += AQL_PACKET_SIZE * tail_blit_packets