From ba79a3c09ae09d66b98b28e577ee9f1c2498b570 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Fri, 15 Mar 2024 18:12:18 +0300 Subject: [PATCH] some hsa lines saving + fixes (#3752) * fix write to ring + some lines * hsa driver test --- test/external/external_test_hsa_driver.py | 96 +++++++++++++++++++++++ tinygrad/runtime/driver/hsa.py | 40 +++++----- 2 files changed, 114 insertions(+), 22 deletions(-) create mode 100644 test/external/external_test_hsa_driver.py diff --git a/test/external/external_test_hsa_driver.py b/test/external/external_test_hsa_driver.py new file mode 100644 index 0000000000..d917ec466b --- /dev/null +++ b/test/external/external_test_hsa_driver.py @@ -0,0 +1,96 @@ +import ctypes, unittest +from tinygrad.helpers import init_c_struct_t +from tinygrad.device import Device, Buffer +from tinygrad.dtype import dtypes +from tinygrad.runtime.driver.hsa import AQLQueue +from tinygrad.runtime.graph.hsa import VirtAQLQueue + +def get_hsa_inc_prog(dev, inc=1): + prg = f""" +extern "C" __attribute__((global)) void test_inc(int* data0) {{ + data0[0] = (data0[0]+{inc}); +}} +""" + return dev.runtime("test_inc", dev.compiler.compile(prg)) + +def get_hsa_buffer_and_kernargs(dev): + test_buf = Buffer(Device.DEFAULT, 1, dtypes.int) + test_buf.copyin(memoryview(bytearray(4))) # zero mem + assert test_buf.as_buffer().cast('I')[0] == 0 # check mem is visible + sync to exec + + args_struct_t = init_c_struct_t(tuple([('f0', ctypes.c_void_p)])) + kernargs = dev.alloc_kernargs(8) + args_st = args_struct_t.from_address(kernargs) + args_st.__setattr__('f0', test_buf._buf) + dev.flush_hdp() + return test_buf, kernargs + +@unittest.skipUnless(Device.DEFAULT == "HSA", "only run on HSA") +class TestHSADriver(unittest.TestCase): + def test_hsa_simple_enqueue(self): + dev = Device[Device.DEFAULT] + queue = AQLQueue(dev, sz=256) + + clprg = get_hsa_inc_prog(dev, inc=1) + test_buf, kernargs = get_hsa_buffer_and_kernargs(dev) + + queue.submit_kernel(clprg, [1,1,1], [1,1,1], kernargs) + queue.wait() + + assert test_buf.as_buffer().cast('I')[0] == 1, f"{test_buf.as_buffer().cast('I')[0]} != 1, all packets executed?" + del queue + + def test_hsa_ring_enqueue(self): + dev = Device[Device.DEFAULT] + + queue_size = 256 + exec_cnt = int(queue_size * 1.5) + queue = AQLQueue(dev, sz=queue_size) + + clprg_inc1 = get_hsa_inc_prog(dev, inc=1) + clprg_inc2 = get_hsa_inc_prog(dev, inc=2) + test_buf, kernargs = get_hsa_buffer_and_kernargs(dev) + + for _ in range(exec_cnt): + queue.submit_kernel(clprg_inc1, [1,1,1], [1,1,1], kernargs) + for _ in range(exec_cnt): + queue.submit_kernel(clprg_inc2, [1,1,1], [1,1,1], kernargs) + queue.wait() + + expected = exec_cnt + exec_cnt * 2 + assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?" + del queue + + def test_hsa_blit_enqueue(self): + dev = Device[Device.DEFAULT] + + queue_size = 256 + exec_cnt = 178 + queue = AQLQueue(dev, sz=queue_size) + + test_buf, kernargs = get_hsa_buffer_and_kernargs(dev) + + # Using VirtAQLQueue to blit them + virt_queue_packets_cnt = 31 + virt_queue = VirtAQLQueue(dev, sz=virt_queue_packets_cnt) + + clprogs = [] + sum_per_blit = 0 + for i in range(virt_queue_packets_cnt): + sum_per_blit += i+1 + clprogs.append(get_hsa_inc_prog(dev, inc=i+1)) + + for i in range(virt_queue_packets_cnt): + virt_queue.submit_kernel(clprogs[i], [1,1,1], [1,1,1], kernargs) + + for _ in range(exec_cnt): + queue.blit_packets(virt_queue.queue_base, virt_queue.packets_count) + queue.wait() + + expected = exec_cnt * sum_per_blit + assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?" + del queue, clprogs + + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/runtime/driver/hsa.py b/tinygrad/runtime/driver/hsa.py index 8d89f8f6f9..9c61096d8e 100644 --- a/tinygrad/runtime/driver/hsa.py +++ b/tinygrad/runtime/driver/hsa.py @@ -34,10 +34,11 @@ class AQLQueue: hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x)))) self.next_doorbell_index = 0 - self.queue_size = self.hw_queue.contents.size - self.write_addr = self.hw_queue.contents.base_address - self.write_addr_end = self.hw_queue.contents.base_address + (AQL_PACKET_SIZE * self.queue_size) - 1 - self.available_packet_slots = self.queue_size + self.queue_base = self.hw_queue.contents.base_address + self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes + self.write_addr = self.queue_base + self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time + self.available_packet_slots = self.hw_queue.contents.size check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH)) check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1)) @@ -89,37 +90,32 @@ class AQLQueue: def blit_packets(self, packet_addr, packet_cnt): if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt) - tail_blit_packets = min(((self.write_addr_end + 1) - self.write_addr) // AQL_PACKET_SIZE, packet_cnt) + tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt) rem_packet_cnt = packet_cnt - tail_blit_packets ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets) - self.write_addr += AQL_PACKET_SIZE * tail_blit_packets - if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address - if tail_blit_packets > 0: - ctypes.memmove(self.write_addr, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt) - self.write_addr += AQL_PACKET_SIZE * rem_packet_cnt + if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt) - self.next_doorbell_index += packet_cnt - hsa.hsa_queue_store_write_index_screlease(self.hw_queue, self.next_doorbell_index + 1) - hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index) + self._submit_packet(packet_cnt) def wait(self): signal = self.submit_barrier(need_signal=True) hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) - self.available_packet_slots = self.queue_size + self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE def _wait_queue(self, need_packets=1): while self.available_packet_slots < need_packets: rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue) - self.available_packet_slots = self.queue_size - (self.next_doorbell_index - rindex) + self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex) - def _submit_packet(self): - hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index + 1) - hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index) + def _submit_packet(self, cnt=1): + self.available_packet_slots -= cnt + self.next_doorbell_index += cnt + hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index) + hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1) - self.write_addr += AQL_PACKET_SIZE - if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address - self.next_doorbell_index += 1 - self.available_packet_slots -= 1 + self.write_addr += AQL_PACKET_SIZE * cnt + if self.write_addr > self.write_addr_end: + self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable)