Files
tinygrad/test/external/external_test_hsa_driver.py
George Hotz 744af193f0 remove ScheduleItem and merge it with ExecItem (#13759)
* remove ExecItem and merge it with ScheduleItem

* less diff

* fix issues

* min diff

* don't change bufs in _lower

* min diff

* update

* revert

* fixes

* diff
2025-12-19 17:04:24 -04:00

120 lines
4.2 KiB
Python

import ctypes, unittest
from tinygrad.helpers import init_c_struct_t
from tinygrad.device import Device, Buffer
from tinygrad.dtype import dtypes
from tinygrad.runtime.support.hsa import AQLQueue
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
from tinygrad.engine.schedule import ExecItem
from tinygrad.engine.realize import BufferXfer
from tinygrad.uop.ops import UOp, Ops
def get_hsa_inc_prog(dev, inc=1):
prg = f"""
extern "C" __attribute__((global)) void test_inc(int* data0) {{
data0[0] = (data0[0]+{inc});
}}
"""
return dev.runtime("test_inc", dev.compiler.compile(prg))
def get_hsa_buffer_and_kernargs(dev):
test_buf = Buffer(Device.DEFAULT, 1, dtypes.int)
test_buf.copyin(memoryview(bytearray(4))) # zero mem
assert test_buf.as_buffer().cast('I')[0] == 0 # check mem is visible + sync to exec
args_struct_t = init_c_struct_t(tuple([('f0', ctypes.c_void_p)]))
kernargs = dev.alloc_kernargs(8)
args_st = args_struct_t.from_address(kernargs)
args_st.__setattr__('f0', test_buf._buf)
dev.flush_hdp()
return test_buf, kernargs
@unittest.skipUnless(Device.DEFAULT == "HSA", "only run on HSA")
class TestHSADriver(unittest.TestCase):
def test_hsa_simple_enqueue(self):
dev = Device[Device.DEFAULT]
queue = AQLQueue(dev, sz=256)
clprg = get_hsa_inc_prog(dev, inc=1)
test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
queue.submit_kernel(clprg, [1,1,1], [1,1,1], kernargs)
queue.wait()
assert test_buf.as_buffer().cast('I')[0] == 1, f"{test_buf.as_buffer().cast('I')[0]} != 1, all packets executed?"
del queue
def test_hsa_ring_enqueue(self):
dev = Device[Device.DEFAULT]
queue_size = 256
exec_cnt = int(queue_size * 1.5)
queue = AQLQueue(dev, sz=queue_size)
clprg_inc1 = get_hsa_inc_prog(dev, inc=1)
clprg_inc2 = get_hsa_inc_prog(dev, inc=2)
test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
for _ in range(exec_cnt):
queue.submit_kernel(clprg_inc1, [1,1,1], [1,1,1], kernargs)
for _ in range(exec_cnt):
queue.submit_kernel(clprg_inc2, [1,1,1], [1,1,1], kernargs)
queue.wait()
expected = exec_cnt + exec_cnt * 2
assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
del queue
def test_hsa_blit_enqueue(self):
dev = Device[Device.DEFAULT]
queue_size = 256
exec_cnt = 178
queue = AQLQueue(dev, sz=queue_size)
test_buf, kernargs = get_hsa_buffer_and_kernargs(dev)
# Using VirtAQLQueue to blit them
virt_queue_packets_cnt = 31
virt_queue = VirtAQLQueue(dev, sz=virt_queue_packets_cnt)
clprogs = []
sum_per_blit = 0
for i in range(virt_queue_packets_cnt):
sum_per_blit += i+1
clprogs.append(get_hsa_inc_prog(dev, inc=i+1))
for i in range(virt_queue_packets_cnt):
virt_queue.submit_kernel(clprogs[i], [1,1,1], [1,1,1], kernargs)
for _ in range(exec_cnt):
queue.blit_packets(virt_queue.queue_base, virt_queue.packets_count)
queue.wait()
expected = exec_cnt * sum_per_blit
assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
del queue, clprogs
def test_hsa_copies_sync(self):
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
test_buf0 = Buffer(d0, 1, dtypes.int)
test_buf1 = Buffer(d0, 1, dtypes.int)
test_buf2 = Buffer(d1, 1, dtypes.int)
test_buf0.copyin(memoryview(bytearray(1*4)))
test_buf1.copyin(memoryview(bytearray(1*4)))
test_buf2.copyin(memoryview(bytearray(1*4)))
jit_cache = [ExecItem(UOp(Ops.NOOP), [test_buf0, test_buf2], prg=BufferXfer(test_buf0.nbytes, test_buf0.device, test_buf2.device)),
ExecItem(UOp(Ops.NOOP), [test_buf2, test_buf1], prg=BufferXfer(test_buf2.nbytes, test_buf2.device, test_buf1.device))]
graph = HSAGraph(jit_cache, [], {})
for i in range(10000):
test_buf0.copyin(memoryview(bytearray(1*4)))
test_buf2.copyin(memoryview(bytearray(int.to_bytes(4, length=1*4, byteorder='little'))))
graph([], {})
assert test_buf0.as_buffer().cast('I')[0] == 4
assert test_buf2.as_buffer().cast('I')[0] == 0
if __name__ == '__main__':
unittest.main()