Save some lines in the HSA driver + fixes (#3752)

* fix write to ring buffer + save some lines

* hsa driver test
This commit is contained in:
nimlgen
2024-03-15 18:12:18 +03:00
committed by GitHub
parent ca19eb3e82
commit ba79a3c09a
2 changed files with 114 additions and 22 deletions

View File

@@ -34,10 +34,11 @@ class AQLQueue:
hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
self.next_doorbell_index = 0
self.queue_size = self.hw_queue.contents.size
self.write_addr = self.hw_queue.contents.base_address
self.write_addr_end = self.hw_queue.contents.base_address + (AQL_PACKET_SIZE * self.queue_size) - 1
self.available_packet_slots = self.queue_size
self.queue_base = self.hw_queue.contents.base_address
self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
self.write_addr = self.queue_base
self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
self.available_packet_slots = self.hw_queue.contents.size
check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
@@ -89,37 +90,32 @@ class AQLQueue:
# blit_packets: copy packet_cnt packets from packet_addr into the queue memory, splitting the
# copy in two when it runs past the end of the queue, then publish the new packets.
# NOTE(review): this span is a diff hunk rendered without +/- markers or indentation, so two
# forms of several statements appear back to back — confirm against the real file which is current.
def blit_packets(self, packet_addr, packet_cnt):
# Wait for slots to free up when fewer than packet_cnt are currently available.
if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
# Number of packets that fit between write_addr and the end of the queue, capped at packet_cnt.
# Two spellings follow: one via write_addr_end + 1, one via queue_base + queue_size.
tail_blit_packets = min(((self.write_addr_end + 1) - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
# Packets that do not fit in the tail and must go to the start of the queue.
rem_packet_cnt = packet_cnt - tail_blit_packets
# First copy: the tail portion, written at the current write address.
ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
# Variant A (presumably the pre-change form): advance write_addr inline, reset it to the base
# address once it passes write_addr_end, and do the second copy guarded by tail_blit_packets > 0.
self.write_addr += AQL_PACKET_SIZE * tail_blit_packets
if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address
if tail_blit_packets > 0:
ctypes.memmove(self.write_addr, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
self.write_addr += AQL_PACKET_SIZE * rem_packet_cnt
# Variant B (presumably the post-change form): second copy guarded by rem_packet_cnt > 0 and
# written directly at queue_base; write_addr is not touched here.
if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
# Variant A publishes inline: bump next_doorbell_index, store the write index, store the doorbell signal.
self.next_doorbell_index += packet_cnt
hsa.hsa_queue_store_write_index_screlease(self.hw_queue, self.next_doorbell_index + 1)
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)
# Variant B hands the slot, index, doorbell and write-address bookkeeping to _submit_packet.
self._submit_packet(packet_cnt)
# wait: submit a barrier that carries a signal, block on that signal, then mark all slots as free.
# NOTE(review): diff hunk without +/- markers — the two trailing assignments are alternative
# forms of the same statement; confirm which one the real file contains.
def wait(self):
signal = self.submit_barrier(need_signal=True)
# Waits with condition HSA_SIGNAL_CONDITION_LT against 1 (i.e. until the value is below 1),
# passing (1 << 64) - 1 as the timeout argument and HSA_WAIT_STATE_ACTIVE as the wait state.
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
# Reset the free-slot counter to the full capacity. The second form divides by AQL_PACKET_SIZE,
# which matches queue_size being kept in bytes (presumably the post-change line).
self.available_packet_slots = self.queue_size
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
# _wait_queue: loop until at least need_packets slots are free, re-reading the queue's read
# index on every iteration to recompute the free-slot count.
# NOTE(review): diff hunk without +/- markers — the two assignments below are alternative
# forms of the same statement; confirm which one the real file contains.
def _wait_queue(self, need_packets=1):
while self.available_packet_slots < need_packets:
rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
# Free slots = capacity minus (next_doorbell_index - rindex). The second form converts
# queue_size to a packet count first (presumably the post-change line, with queue_size in bytes).
self.available_packet_slots = self.queue_size - (self.next_doorbell_index - rindex)
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
# _submit_packet: publish newly written packets (write index + doorbell signal) and advance
# the queue bookkeeping (free slots, next_doorbell_index, write_addr).
# NOTE(review): diff hunk without +/- markers — two signatures and two bookkeeping sequences
# are interleaved below; confirm against the real file which one is current.
# Variant A (presumably pre-change): no count parameter; stores write index
# next_doorbell_index + 1 and doorbell value next_doorbell_index before bumping the counters.
def _submit_packet(self):
hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index + 1)
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)
# Variant B (presumably post-change): cnt packets per call (default 1); counters are bumped
# first, so the write index is next_doorbell_index and the doorbell value is next_doorbell_index-1.
def _submit_packet(self, cnt=1):
self.available_packet_slots -= cnt
self.next_doorbell_index += cnt
hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
# Variant A tail: advance by exactly one packet, reset to the base address once past
# write_addr_end, then bump the index and free-slot counters by one.
self.write_addr += AQL_PACKET_SIZE
if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address
self.next_doorbell_index += 1
self.available_packet_slots -= 1
# Variant B tail: advance by cnt packets and, once past write_addr_end, fold the address back
# into the queue by taking the offset from queue_base modulo queue_size.
self.write_addr += AQL_PACKET_SIZE * cnt
if self.write_addr > self.write_addr_end:
self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
# _alloc_signal: forward to the owning device's alloc_signal, passing the reusable flag through unchanged.
def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable)