mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
some hsa lines saving + fixes (#3752)
* fix write to ring + some lines
* hsa driver test
This commit is contained in:
96
test/external/external_test_hsa_driver.py
vendored
Normal file
96
test/external/external_test_hsa_driver.py
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
import ctypes, unittest
|
||||
from tinygrad.helpers import init_c_struct_t
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.runtime.driver.hsa import AQLQueue
|
||||
from tinygrad.runtime.graph.hsa import VirtAQLQueue
|
||||
|
||||
def get_hsa_inc_prog(dev, inc=1):
  """Build the `test_inc` kernel, which adds `inc` to the first int of its argument.

  The kernel source is compiled with `dev.compiler.compile` and the result is
  handed to `dev.runtime` under the entry-point name "test_inc".

  Args:
    dev: device object exposing `compiler.compile(src)` and `runtime(name, lib)`.
    inc: value formatted into the kernel source as the increment (default 1).

  Returns:
    Whatever `dev.runtime("test_inc", <compiled source>)` returns.
  """
  # Doubled braces are literal braces in the generated source; {inc} is substituted.
  kernel_src = f"""
extern "C" __attribute__((global)) void test_inc(int* data0) {{
data0[0] = (data0[0]+{inc});
}}
"""
  compiled = dev.compiler.compile(kernel_src)
  return dev.runtime("test_inc", compiled)
|
||||
|
||||
def get_hsa_buffer_and_kernargs(dev):
  """Create a zeroed one-int buffer plus a kernarg block whose field `f0` refers to it.

  Args:
    dev: device object exposing `alloc_kernargs(size)` and `flush_hdp()`.

  Returns:
    A `(test_buf, kernargs)` tuple: the Buffer and the value returned by
    `dev.alloc_kernargs(8)`.
  """
  # NOTE(review): the buffer is created on Device.DEFAULT rather than on `dev`;
  # the callers visible here always pass Device[Device.DEFAULT].
  test_buf = Buffer(Device.DEFAULT, 1, dtypes.int)
  zeros = memoryview(bytearray(4))
  test_buf.copyin(zeros) # zero mem
  assert test_buf.as_buffer().cast('I')[0] == 0 # check mem is visible + sync to exec

  # Single-field struct: one pointer-typed member named f0.
  args_struct_t = init_c_struct_t((('f0', ctypes.c_void_p),))
  kernargs = dev.alloc_kernargs(8)
  # View the kernarg allocation as the struct and store the buffer handle in it.
  args_view = args_struct_t.from_address(kernargs)
  args_view.f0 = test_buf._buf
  dev.flush_hdp()
  return test_buf, kernargs
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "HSA", "only run on HSA")
class TestHSADriver(unittest.TestCase):
  """AQLQueue submission tests; skipped unless the default device is HSA."""

  def test_hsa_simple_enqueue(self):
    """Submit one increment-by-1 kernel and expect the buffer to read back 1."""
    dev = Device[Device.DEFAULT]
    queue = AQLQueue(dev, sz=256)

    inc_prog = get_hsa_inc_prog(dev, inc=1)
    test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)

    queue.submit_kernel(inc_prog, [1,1,1], [1,1,1], kernargs)
    queue.wait()

    assert test_buf.as_buffer().cast('I')[0] == 1, f"{test_buf.as_buffer().cast('I')[0]} != 1, all packets executed?"
    del queue

  def test_hsa_ring_enqueue(self):
    """Submit more kernels than the `sz` given to the queue and check the total.

    Each program is submitted int(queue_size * 1.5) times, i.e. more than queue_size.
    """
    dev = Device[Device.DEFAULT]

    queue_size = 256
    exec_cnt = int(queue_size * 1.5)
    queue = AQLQueue(dev, sz=queue_size)

    # Compile both programs before any submission: all +1 kernels go first, then all +2.
    progs = (get_hsa_inc_prog(dev, inc=1), get_hsa_inc_prog(dev, inc=2))
    test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)

    for prog in progs:
      for _ in range(exec_cnt):
        queue.submit_kernel(prog, [1,1,1], [1,1,1], kernargs)
    queue.wait()

    expected = exec_cnt * (1 + 2)
    assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
    del queue

  def test_hsa_blit_enqueue(self):
    """Record 31 kernels in a VirtAQLQueue, blit them 178 times, and check the total."""
    dev = Device[Device.DEFAULT]

    queue_size = 256
    exec_cnt = 178
    queue = AQLQueue(dev, sz=queue_size)

    test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)

    # Using VirtAQLQueue to blit them
    virt_queue_packets_cnt = 31
    virt_queue = VirtAQLQueue(dev, sz=virt_queue_packets_cnt)

    # One program per packet, incrementing by 1..virt_queue_packets_cnt.
    increments = range(1, virt_queue_packets_cnt + 1)
    clprogs = [get_hsa_inc_prog(dev, inc=step) for step in increments]
    sum_per_blit = sum(increments)

    for prog in clprogs:
      virt_queue.submit_kernel(prog, [1,1,1], [1,1,1], kernargs)

    for _ in range(exec_cnt):
      queue.blit_packets(virt_queue.queue_base, virt_queue.packets_count)
    queue.wait()

    expected = exec_cnt * sum_per_blit
    assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
    del queue, clprogs
|
||||
|
||||
|
||||
# Run the test cases when this file is executed directly.
if __name__ == "__main__":
  unittest.main()
|
||||
@@ -34,10 +34,11 @@ class AQLQueue:
|
||||
hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))
|
||||
|
||||
self.next_doorbell_index = 0
|
||||
self.queue_size = self.hw_queue.contents.size
|
||||
self.write_addr = self.hw_queue.contents.base_address
|
||||
self.write_addr_end = self.hw_queue.contents.base_address + (AQL_PACKET_SIZE * self.queue_size) - 1
|
||||
self.available_packet_slots = self.queue_size
|
||||
self.queue_base = self.hw_queue.contents.base_address
|
||||
self.queue_size = self.hw_queue.contents.size * AQL_PACKET_SIZE # in bytes
|
||||
self.write_addr = self.queue_base
|
||||
self.write_addr_end = self.queue_base + self.queue_size - 1 # precalc saves some time
|
||||
self.available_packet_slots = self.hw_queue.contents.size
|
||||
|
||||
check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
|
||||
check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))
|
||||
@@ -89,37 +90,32 @@ class AQLQueue:
|
||||
def blit_packets(self, packet_addr, packet_cnt):
|
||||
if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
|
||||
|
||||
tail_blit_packets = min(((self.write_addr_end + 1) - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
|
||||
tail_blit_packets = min((self.queue_base + self.queue_size - self.write_addr) // AQL_PACKET_SIZE, packet_cnt)
|
||||
rem_packet_cnt = packet_cnt - tail_blit_packets
|
||||
ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
|
||||
self.write_addr += AQL_PACKET_SIZE * tail_blit_packets
|
||||
if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address
|
||||
if tail_blit_packets > 0:
|
||||
ctypes.memmove(self.write_addr, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
|
||||
self.write_addr += AQL_PACKET_SIZE * rem_packet_cnt
|
||||
if rem_packet_cnt > 0: ctypes.memmove(self.queue_base, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
|
||||
|
||||
self.next_doorbell_index += packet_cnt
|
||||
hsa.hsa_queue_store_write_index_screlease(self.hw_queue, self.next_doorbell_index + 1)
|
||||
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)
|
||||
self._submit_packet(packet_cnt)
|
||||
|
||||
def wait(self):
|
||||
signal = self.submit_barrier(need_signal=True)
|
||||
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
self.available_packet_slots = self.queue_size
|
||||
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
|
||||
|
||||
def _wait_queue(self, need_packets=1):
|
||||
while self.available_packet_slots < need_packets:
|
||||
rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
|
||||
self.available_packet_slots = self.queue_size - (self.next_doorbell_index - rindex)
|
||||
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE - (self.next_doorbell_index - rindex)
|
||||
|
||||
def _submit_packet(self):
|
||||
hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index + 1)
|
||||
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)
|
||||
def _submit_packet(self, cnt=1):
|
||||
self.available_packet_slots -= cnt
|
||||
self.next_doorbell_index += cnt
|
||||
hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index)
|
||||
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index-1)
|
||||
|
||||
self.write_addr += AQL_PACKET_SIZE
|
||||
if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address
|
||||
self.next_doorbell_index += 1
|
||||
self.available_packet_slots -= 1
|
||||
self.write_addr += AQL_PACKET_SIZE * cnt
|
||||
if self.write_addr > self.write_addr_end:
|
||||
self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
|
||||
|
||||
def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user