mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 17:15:48 -05:00
hsa runtime (#3382)
* hsa init * handles transfer * linter * clean up hwqueue * fix sync freezes * print errors
This commit is contained in:
4229
tinygrad/runtime/autogen/hsa.py
Normal file
4229
tinygrad/runtime/autogen/hsa.py
Normal file
File diff suppressed because it is too large
Load Diff
145
tinygrad/runtime/driver/hsa.py
Normal file
145
tinygrad/runtime/driver/hsa.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import ctypes
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.helpers import init_c_var
|
||||
|
||||
def check(status):
  """Raise RuntimeError with the HSA error description if `status` is nonzero."""
  if status != 0:
    # hsa_status_string stores a pointer to a static description string; for an
    # unrecognized status code it may leave the pointer NULL, and string_at(NULL)
    # would crash inside the error path — fall back to a generic message instead.
    hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
    desc = ctypes.string_at(status_str).decode() if status_str else "unknown error"
    raise RuntimeError(f"HSA Error {status}: {desc}")
|
||||
|
||||
# Precalculated AQL packet info.
AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
EMPTY_SIGNAL = hsa.hsa_signal_t()

# Kernel dispatch packets: 3D grid, barrier bit set, system-scope acquire/release fences.
DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
DISPATCH_KERNEL_HEADER = (1 << hsa.HSA_PACKET_HEADER_BARRIER) | \
                         (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE) | \
                         (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE) | \
                         (hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE)

# Barrier-AND packets: same barrier bit and fencing, typed as HSA_PACKET_TYPE_BARRIER_AND.
BARRIER_HEADER = (1 << hsa.HSA_PACKET_HEADER_BARRIER) | \
                 (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE) | \
                 (hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE) | \
                 (hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE)
|
||||
|
||||
class HWQueue:
  """A single AQL hardware queue on one HSA agent.

  Packets are written directly into the queue's ring buffer and published via the
  doorbell signal. Free slots are tracked in software and reclaimed either by
  polling the queue's read index (_wait_queue) or by a full barrier wait (wait).
  """
  def __init__(self, device, sz=-1):
    self.device = device
    self.wait_signals = []  # NOTE(review): not used by this class; kept for compatibility

    # Queue size is capped by the agent; sz=-1 means "as large as allowed".
    check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
    queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value

    # No error callback: a NULL function pointer of the expected CFUNCTYPE signature.
    null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()
    self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
      hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))

    # Software-side view of the ring buffer: next packet index, write pointer, free slots.
    self.next_doorbell_index = 0
    self.queue_size = self.hw_queue.contents.size
    self.write_addr = self.hw_queue.contents.base_address
    self.write_addr_end = self.hw_queue.contents.base_address + (AQL_PACKET_SIZE * self.queue_size) - 1
    self.available_packet_slots = self.queue_size

    check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
    check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))  # required for dispatch timing

  def __del__(self):
    if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))

  def submit_kernel(self, prg, global_size, local_size, kernargs, need_signal=False):
    """Enqueue a kernel dispatch packet; returns its completion signal (EMPTY_SIGNAL if not requested)."""
    if self.available_packet_slots == 0: self._wait_queue()
    signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL

    packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
    packet.workgroup_size_x = local_size[0]
    packet.workgroup_size_y = local_size[1]
    packet.workgroup_size_z = local_size[2]
    packet.reserved0 = 0
    # AQL grid sizes are in work-items, not workgroups.
    packet.grid_size_x = global_size[0] * local_size[0]
    packet.grid_size_y = global_size[1] * local_size[1]
    packet.grid_size_z = global_size[2] * local_size[2]
    packet.private_segment_size = prg.private_segment_size
    packet.group_segment_size = prg.group_segment_size
    packet.kernel_object = prg.handle
    packet.kernarg_address = kernargs
    packet.reserved2 = 0
    packet.completion_signal = signal
    packet.setup = DISPATCH_KERNEL_SETUP
    packet.header = DISPATCH_KERNEL_HEADER  # header is written last, making the packet valid
    self._submit_packet()

    return signal

  def submit_barrier(self, wait_signals=None, need_signal=False):
    """Enqueue a barrier-AND packet depending on up to 5 signals; returns its completion signal."""
    # FIX: a barrier-AND packet has 5 dep_signal slots (and the loop below fills 5),
    # so up to 5 wait signals are valid — the original assert allowed only 4.
    assert wait_signals is None or len(wait_signals) <= 5
    if self.available_packet_slots == 0: self._wait_queue()
    signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL

    packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
    packet.reserved0 = 0
    packet.reserved1 = 0
    for i in range(5):
      packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
    packet.reserved2 = 0
    packet.completion_signal = signal
    packet.header = BARRIER_HEADER  # header is written last, making the packet valid
    self._submit_packet()

    return signal

  def wait(self):
    """Block until every packet queued so far has completed."""
    signal = self.submit_barrier(need_signal=True)
    hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    self.available_packet_slots = self.queue_size  # queue fully drained: the whole ring is free

  def _wait_queue(self):
    # Spin until the packet processor consumes packets and at least one slot frees up.
    while self.available_packet_slots == 0:
      rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
      self.available_packet_slots = self.queue_size - (self.next_doorbell_index - rindex)

  def _submit_packet(self):
    # Publish the packet: bump the write index, then ring the doorbell with the
    # index of the packet just written.
    hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index + 1)
    hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)

    # Advance the software write pointer, wrapping around the ring buffer.
    self.write_addr += AQL_PACKET_SIZE
    if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address
    self.next_doorbell_index += 1
    self.available_packet_slots -= 1
|
||||
|
||||
|
||||
def find_agent(typ, device_id):
  """Return the `device_id`-th HSA agent of device type `typ`.

  The iteration callback repurposes the output agent's `handle` field as a
  running counter of matching agents seen so far (it starts at 0 from the
  zero-initialized hsa_agent_t below), overwriting it with the real agent
  handle once the requested index is reached.
  """
  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
  def __filter_agents(agent, data):
    status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
    if status == 0 and device_type.value == typ:
      ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_agent_t))
      # handle holds the match count until the target index is found.
      if ret[0].handle == device_id:
        ret[0] = agent
        return hsa.HSA_STATUS_INFO_BREAK  # stop iterating: agent found
      ret[0].handle = ret[0].handle + 1
    return hsa.HSA_STATUS_SUCCESS  # keep iterating

  # NOTE(review): if fewer than device_id+1 agents of this type exist, the returned
  # agent's handle is just the (small integer) counter — callers get an invalid agent.
  hsa.hsa_iterate_agents(__filter_agents, ctypes.byref(agent := hsa.hsa_agent_t()))
  return agent
|
||||
|
||||
def find_memory_pool(agent, segtyp=-1, flags=-1, location=-1):
  """Return the first non-empty memory pool of `agent` matching all requested filters.

  Each filter is skipped when left at -1:
    segtyp   -- required hsa_amd_segment_t value.
    flags    -- global-flag bits that must ALL be present in the pool's flags.
    location -- required hsa_amd_memory_pool_location_t value.
  """
  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
  def __filter_amd_memory_pools(mem_pool, data):
    # Returning HSA_STATUS_SUCCESS continues iteration (pool rejected);
    # HSA_STATUS_INFO_BREAK stops it (pool accepted).
    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
    if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS

    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, ctypes.byref(fgs := hsa.hsa_amd_memory_pool_global_flag_t()))) # noqa: E501
    # BUG FIX: skip pools that LACK the requested flag bits. The original condition
    # (`== flags`) rejected exactly the pools that satisfied the filter, so e.g. a
    # KERNARG_INIT lookup could never select the kernarg pool.
    if flags >= 0 and (fgs.value & flags) != flags: return hsa.HSA_STATUS_SUCCESS

    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
    if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS

    # Reject zero-sized pools.
    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
    if sz.value == 0: return hsa.HSA_STATUS_SUCCESS

    ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
    ret[0] = mem_pool
    return hsa.HSA_STATUS_INFO_BREAK

  hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
  return region
|
||||
@@ -126,7 +126,7 @@ class HIPAllocator(LRUAllocator):
|
||||
self.full_synchronize()
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
|
||||
def transfer(self, dest:T, src:T, sz:int):
|
||||
def transfer(self, dest:T, src:T, sz:int, **kwargs):
|
||||
hip_set_device(self.device.device)
|
||||
check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))
|
||||
|
||||
|
||||
211
tinygrad/runtime/ops_hsa.py
Normal file
211
tinygrad/runtime/ops_hsa.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, functools, subprocess, io, atexit
|
||||
from typing import Tuple, TypeVar, List
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t
|
||||
from tinygrad.device import Compiled, LRUAllocator
|
||||
from tinygrad.runtime.ops_hip import HIPCompiler
|
||||
from tinygrad.runtime.driver.hsa import check, find_agent, find_memory_pool, HWQueue
|
||||
|
||||
HSACompiler = HIPCompiler  # HSA reuses the HIP compiler: both target the same ROCm code objects
|
||||
|
||||
class HSAProgram:
  """A kernel loaded from a compiled code object, runnable on one HSA agent."""
  def __init__(self, device:HSADevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib

    if DEBUG >= 6:
      # Dump the disassembly, dropping the s_code_end padding lines.
      asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
      print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))

    # Load the code object into a frozen executable for this agent.
    self.exec = init_c_var(hsa.hsa_executable_t(), lambda x: check(hsa.hsa_executable_create_alt(hsa.HSA_PROFILE_FULL, hsa.HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, None, ctypes.byref(x)))) # noqa: E501
    check(hsa.hsa_code_object_reader_create_from_memory(lib, len(lib), ctypes.byref(code_reader := hsa.hsa_code_object_reader_t())))
    check(hsa.hsa_executable_load_agent_code_object(self.exec, self.device.agent, code_reader, None, None))
    check(hsa.hsa_executable_freeze(self.exec, None))

    # Resolve the kernel descriptor symbol ("<name>.kd") and cache its dispatch parameters.
    self.kernel = init_c_var(hsa.hsa_executable_symbol_t(), lambda x: check(hsa.hsa_executable_get_symbol_by_name(self.exec, (name+".kd").encode("utf-8"), ctypes.byref(self.device.agent), ctypes.byref(x)))) # noqa: E501
    self.handle = init_c_var(ctypes.c_uint64(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, ctypes.byref(x)))) # noqa: E501
    self.kernargs_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.group_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.private_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501

    # The reader is only needed during load.
    check(hsa.hsa_code_object_reader_destroy(code_reader))

  def __del__(self):
    # hasattr guard: __init__ may have failed before self.exec was set.
    if hasattr(self, 'exec'): check(hsa.hsa_executable_destroy(self.exec))

  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Launch the kernel. When wait=True, blocks and returns GPU execution time in seconds."""
    if not hasattr(self, "args_struct_t"):
      # Build (once) a ctypes struct matching the kernarg layout: buffer pointers, then int values.
      self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
                                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
      # Sanity check our layout against the size the code object declares.
      assert ctypes.sizeof(self.args_struct_t) == self.kernargs_segment_size, f"{ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}"

    kernargs = None
    if self.kernargs_segment_size > 0:
      # Bump-allocate kernarg memory and fill it through the struct view.
      kernargs = self.device.alloc_kernargs(self.kernargs_segment_size)
      args_st = self.args_struct_t.from_address(kernargs)
      for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i])
      for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])

    signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=wait)
    if wait:
      hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      # Profiling was enabled on the queue, so dispatch timestamps are available.
      check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
      return (timings.end - timings.start) * self.device.clocks_to_time
|
||||
|
||||
T = TypeVar("T")
CHUNK_SIZE = 256 * 1024 * 1024  # size of each host staging buffer for file reads
PAGE_SIZE = 0x1000              # file reads are page-aligned
|
||||
class HSAAllocator(LRUAllocator):
  """LRU allocator backed by one HSA GPU device's memory pool."""
  def __init__(self, device:HSADevice):
    self.device = device
    super().__init__()

  def _alloc(self, size:int):
    # Allocate from this device's pool, then allow every known agent access so
    # cross-device transfers and CPU-driven copies can touch the buffer.
    c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices])
    check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p())))
    check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, buf))
    return buf.value

  def _free(self, opaque:T):
    # Drain the queue first so no in-flight packet still references this buffer.
    self.device.synchronize()
    check(hsa.hsa_amd_memory_pool_free(opaque))

  def copyin(self, dest:T, src: memoryview):
    # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
    copy_signal = self.device.alloc_signal(reusable=True)
    sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)
    c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
    # Stage src into CPU-pool memory; freed on the device's next synchronize (delayed_free).
    check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, src.nbytes, 0, ctypes.byref(mem := ctypes.c_void_p())))
    check(hsa.hsa_amd_agents_allow_access(2, c_agents, None, mem))
    ctypes.memmove(mem, from_mv(src), src.nbytes)
    # DMA starts after sync_signal fires (queue drained up to the barrier above).
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes,
                                                  1, ctypes.byref(sync_signal), copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
    self.device.hw_queue.submit_barrier(wait_signals=[copy_signal])
    self.device.delayed_free.append(mem)

  def copy_from_fd(self, dest, fd, offset, size):
    """Stream `size` bytes from `fd` at `offset` into device memory, double-buffering
    through two CHUNK_SIZE host buffers on two SDMA engines."""
    sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)

    if not hasattr(self, 'hb'):
      # Lazily create the two host staging buffers, their in-flight signals, and engine mapping.
      c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
      self.hb = []
      for _ in range(2):
        check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, CHUNK_SIZE, 0, ctypes.byref(mem := ctypes.c_void_p())))
        check(hsa.hsa_amd_agents_allow_access(2, c_agents, None, mem))
        self.hb.append(mem.value)
      self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
      self.hb_polarity = 0
      self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
      # Store 0 so the first wait below returns immediately (buffers start free).
      for sig in self.hb_signals: hsa.hsa_signal_store_relaxed(sig, 0)

    fo = io.FileIO(fd, "a+b", closefd=False)
    # Read page-aligned: back up to the page boundary; minor_offset bytes are skipped
    # from the first chunk's copy below.
    fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))

    copies_called = 0
    copied_in = 0
    for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
      local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
      copy_size = min(local_size-minor_offset, size-copied_in)
      if copy_size == 0: break

      # Wait for this buffer's previous DMA to finish before overwriting its contents.
      hsa.hsa_signal_wait_scacquire(self.hb_signals[self.hb_polarity], hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      self.device.reusable_signals.append(self.hb_signals[self.hb_polarity]) # it's free now and can be reused
      self.hb_signals[self.hb_polarity] = self.device.alloc_signal(reusable=False)

      fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
      check(hsa.hsa_amd_memory_async_copy_on_engine(dest+copied_in, self.device.agent, self.hb[self.hb_polarity]+minor_offset, HSADevice.cpu_agent,
                                                    copy_size, 1, ctypes.byref(sync_signal), self.hb_signals[self.hb_polarity],
                                                    self.sdma[self.hb_polarity], True))
      copied_in += copy_size
      self.hb_polarity = (self.hb_polarity + 1) % len(self.hb)
      minor_offset = 0 # only on the first
      copies_called += 1

    # Fence the hw queue on the last copy (and the second-to-last, if it could still be in flight).
    wait_signals = [self.hb_signals[self.hb_polarity - 1]]
    if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
    self.device.hw_queue.submit_barrier(wait_signals=wait_signals)

  def copyout(self, dest:memoryview, src:T):
    # Synchronous copyout: drain the queue, lock the host buffer for DMA, copy, unlock.
    self.device.synchronize()
    copy_signal = self.device.alloc_signal(reusable=True)
    c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
    check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
    check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
    hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    check(hsa.hsa_amd_memory_unlock(from_mv(dest)))

  def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
    # Device-to-device copy: wait for both queues to drain, DMA, then fence both queues on it.
    copy_signal = dest_dev.alloc_signal(reusable=False)
    sync_signal_1 = src_dev.hw_queue.submit_barrier(need_signal=True)
    sync_signal_2 = dest_dev.hw_queue.submit_barrier(need_signal=True)
    c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) # noqa: E501
    src_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
    dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
|
||||
|
||||
class HSADevice(Compiled):
  # Process-wide HSA state, initialized by the first HSADevice constructed.
  cpu_agent = None
  cpu_mempool = None
  devices: List[HSADevice] = []
  def __init__(self, device:str=""):
    if not HSADevice.cpu_agent:
      # One-time runtime init for the whole process; torn down at interpreter exit.
      check(hsa.hsa_init())
      atexit.register(lambda: hsa.hsa_shut_down())
      HSADevice.cpu_agent = find_agent(hsa.HSA_DEVICE_TYPE_CPU, device_id=0)
      HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)

    # "HSA:<n>" selects GPU n; a bare device string means GPU 0.
    self.device_id = int(device.split(":")[1]) if ":" in device else 0
    self.agent = find_agent(hsa.HSA_DEVICE_TYPE_GPU, device_id=self.device_id)
    self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
    self.kernargs_pool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, flags=hsa.HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT)
    self.hw_queue = HWQueue(self)
    HSADevice.devices.append(self)

    # Agent name (presumably the gfx arch string — passed to the compiler as arch).
    check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
    self.arch = ctypes.string_at(agent_name_buf).decode()

    # Factor converting GPU timestamp ticks into seconds for profiling.
    check(hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(gpu_freq := ctypes.c_uint64())))
    self.clocks_to_time: float = 1 / gpu_freq.value

    # Bump allocator for kernel arguments; reset on every synchronize().
    self.kernarg_pool_sz = 16 << 20
    self.kernarg_start_addr = init_c_var(ctypes.c_void_p(), lambda x: check(hsa.hsa_amd_memory_pool_allocate(self.kernargs_pool, self.kernarg_pool_sz, 0, ctypes.byref(x)))).value # noqa: E501
    self.kernarg_next_addr = self.kernarg_start_addr

    # delayed_free: buffers reclaimed at the next synchronize (may still be in flight now).
    # signal_pool: free signals; reusable_signals: handed out, recycled at synchronize.
    self.delayed_free: List[ctypes.c_void_p] = []
    self.signal_pool: List[hsa.hsa_signal_t] = []
    self.reusable_signals: List[hsa.hsa_signal_t] = []
    for _ in range(4096):
      check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
      self.signal_pool.append(signal)

    super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), None)

  def synchronize(self):
    """Drain the hw queue, then recycle signals, delayed buffers and the kernarg arena."""
    self.hw_queue.wait()

    # Reset recycled signals back to their initial value of 1 before pooling them.
    for sig in self.reusable_signals: hsa.hsa_signal_silent_store_relaxed(sig, 1)
    self.signal_pool.extend(self.reusable_signals)
    self.reusable_signals.clear()

    # Safe to free now: nothing queued can still reference these buffers.
    for opaque_to_free in self.delayed_free: check(hsa.hsa_amd_memory_pool_free(opaque_to_free))
    self.delayed_free.clear()

    self.kernarg_next_addr = self.kernarg_start_addr

  def alloc_signal(self, reusable=False):
    """Hand out a signal from the pool (creating one if the pool is empty)."""
    if len(self.signal_pool): signal = self.signal_pool.pop()
    else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))

    # reusable means a signal could be reused after synchronize for the device it's allocated from is called.
    if reusable: self.reusable_signals.append(signal)
    return signal

  def alloc_kernargs(self, sz):
    """Bump-allocate `sz` bytes of kernarg memory; doubles the arena when exhausted."""
    if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz:
      # Out of space: retire the old arena on the next synchronize and allocate a doubled one.
      self.delayed_free.append(self.kernarg_start_addr)
      self.kernarg_pool_sz = int(self.kernarg_pool_sz * 2)
      self.kernarg_start_addr = init_c_var(ctypes.c_void_p(), lambda x: check(hsa.hsa_amd_memory_pool_allocate(self.kernargs_pool, self.kernarg_pool_sz, 0, ctypes.byref(x)))).value # noqa: E501
      self.kernarg_next_addr = self.kernarg_start_addr

    result = self.kernarg_next_addr
    self.kernarg_next_addr = (self.kernarg_next_addr + sz + 15) & (~15) # align to 16 bytes
    return result
|
||||
@@ -60,7 +60,7 @@ class MetalAllocator(LRUAllocator):
|
||||
ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
|
||||
if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
|
||||
return ret
|
||||
def transfer(self, dest:Any, src:Any, sz:int):
|
||||
def transfer(self, dest:Any, src:Any, sz:int, **kwargs):
|
||||
command_buffer = self.device.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.blitCommandEncoder()
|
||||
encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)
|
||||
|
||||
Reference in New Issue
Block a user