hsa runtime (#3382)

* hsa init

* handles transfer

* linter

* clean up hwqueue

* fix sync freezes

* print errors
This commit is contained in:
nimlgen
2024-02-15 16:14:34 +03:00
committed by GitHub
parent 93eceef727
commit 002bf380b0
8 changed files with 4604 additions and 4 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,145 @@
import ctypes
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import init_c_var
def check(status):
  """Validate an HSA status code: no-op on HSA_STATUS_SUCCESS (0), raise otherwise."""
  if status == 0: return
  # Ask the runtime for a human-readable description of the failure before raising.
  hsa.hsa_status_string(status, ctypes.byref(status_str := ctypes.POINTER(ctypes.c_char)()))
  raise RuntimeError(f"HSA Error {status}: {ctypes.string_at(status_str).decode()}")
# Precalculated AQL info
AQL_PACKET_SIZE = ctypes.sizeof(hsa.hsa_kernel_dispatch_packet_t)
EMPTY_SIGNAL = hsa.hsa_signal_t()  # default (zero-handle) signal, placed where no completion/dependency signal is wanted

# Dispatches are always submitted as 3-dimensional grids.
DISPATCH_KERNEL_SETUP = 3 << hsa.HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS
# Kernel dispatch header: barrier bit set, system-scope acquire/release fences, KERNEL_DISPATCH packet type.
DISPATCH_KERNEL_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
DISPATCH_KERNEL_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
DISPATCH_KERNEL_HEADER |= hsa.HSA_PACKET_TYPE_KERNEL_DISPATCH << hsa.HSA_PACKET_HEADER_TYPE
# Barrier header: same barrier bit and fence scopes, but a BARRIER_AND packet type.
BARRIER_HEADER = 1 << hsa.HSA_PACKET_HEADER_BARRIER
BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_FENCE_SCOPE
BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
class HWQueue:
  """Thin wrapper around one HSA user-mode AQL queue.

  Packets (kernel dispatches and barrier-ANDs) are written straight into the
  queue's ring buffer, then published to the GPU by bumping the write index
  and ringing the doorbell signal.
  """
  def __init__(self, device, sz=-1):
    self.device = device
    self.wait_signals = []

    # Clamp the requested size to the agent's maximum (sz == -1 means "use the max").
    check(hsa.hsa_agent_get_info(self.device.agent, hsa.HSA_AGENT_INFO_QUEUE_MAX_SIZE, ctypes.byref(max_queue_size := ctypes.c_uint32())))
    queue_size = min(max_queue_size.value, sz) if sz != -1 else max_queue_size.value
    null_func = ctypes.CFUNCTYPE(None, hsa.hsa_status_t, ctypes.POINTER(hsa.struct_hsa_queue_s), ctypes.c_void_p)()  # no error callback
    self.hw_queue = init_c_var(ctypes.POINTER(hsa.hsa_queue_t)(), lambda x: check(
      hsa.hsa_queue_create(self.device.agent, queue_size, hsa.HSA_QUEUE_TYPE_SINGLE, null_func, None, (1<<32)-1, (1<<32)-1, ctypes.byref(x))))

    self.next_doorbell_index = 0
    self.queue_size = self.hw_queue.contents.size
    self.write_addr = self.hw_queue.contents.base_address
    self.write_addr_end = self.hw_queue.contents.base_address + (AQL_PACKET_SIZE * self.queue_size) - 1  # last valid byte of the ring
    self.available_packet_slots = self.queue_size

    check(hsa.hsa_amd_queue_set_priority(self.hw_queue, hsa.HSA_AMD_QUEUE_PRIORITY_HIGH))
    check(hsa.hsa_amd_profiling_set_profiler_enabled(self.hw_queue, 1))  # needed so dispatch timestamps can be read back

  def __del__(self):
    if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))

  def submit_kernel(self, prg, global_size, local_size, kernargs, need_signal=False):
    """Enqueue a kernel dispatch. Returns its completion signal when need_signal, else EMPTY_SIGNAL."""
    if self.available_packet_slots == 0: self._wait_queue()
    signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL

    packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
    packet.workgroup_size_x = local_size[0]
    packet.workgroup_size_y = local_size[1]
    packet.workgroup_size_z = local_size[2]
    packet.reserved0 = 0
    # AQL grid sizes are total work-item counts, so scale the workgroup counts by the local size.
    packet.grid_size_x = global_size[0] * local_size[0]
    packet.grid_size_y = global_size[1] * local_size[1]
    packet.grid_size_z = global_size[2] * local_size[2]
    packet.private_segment_size = prg.private_segment_size
    packet.group_segment_size = prg.group_segment_size
    packet.kernel_object = prg.handle
    packet.kernarg_address = kernargs
    packet.reserved2 = 0
    packet.completion_signal = signal
    packet.setup = DISPATCH_KERNEL_SETUP
    packet.header = DISPATCH_KERNEL_HEADER  # header is stored last, after all other packet fields
    self._submit_packet()
    return signal

  def submit_barrier(self, wait_signals=None, need_signal=False):
    """Enqueue a barrier-AND packet that waits on up to 5 signals.

    Returns its completion signal when need_signal, else EMPTY_SIGNAL.
    """
    # BUGFIX: a barrier-AND packet has exactly 5 dep_signal slots (and the fill loop
    # below supports all 5), so allow len == 5; the previous `< 5` rejected the 5th.
    assert wait_signals is None or len(wait_signals) <= 5
    if self.available_packet_slots == 0: self._wait_queue()
    signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL

    packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
    packet.reserved0 = 0
    packet.reserved1 = 0
    # Unused dependency slots must hold the empty signal.
    for i in range(5):
      packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
    packet.reserved2 = 0
    packet.completion_signal = signal
    packet.header = BARRIER_HEADER
    self._submit_packet()
    return signal

  def wait(self):
    """Block (actively spinning) until everything submitted so far has completed."""
    signal = self.submit_barrier(need_signal=True)
    hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    self.available_packet_slots = self.queue_size  # queue fully drained

  def _wait_queue(self):
    # Busy-poll the queue's read index until at least one ring slot frees up.
    while self.available_packet_slots == 0:
      rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
      self.available_packet_slots = self.queue_size - (self.next_doorbell_index - rindex)

  def _submit_packet(self):
    # Publish the packet: bump the write index, then ring the doorbell with the
    # index of the packet just written.
    hsa.hsa_queue_store_write_index_relaxed(self.hw_queue, self.next_doorbell_index + 1)
    hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)

    self.write_addr += AQL_PACKET_SIZE
    if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address  # wrap the ring buffer
    self.next_doorbell_index += 1
    self.available_packet_slots -= 1
def find_agent(typ, device_id):
  """Return the device_id-th HSA agent of device type `typ`.

  The out-parameter's `handle` field is used as a running counter of matching
  agents until the requested index is reached, at which point it is overwritten
  with the actual agent and iteration stops.
  NOTE(review): if fewer than device_id+1 agents of this type exist, the
  returned struct still holds the bare counter, not a valid agent handle —
  callers are expected to pass a valid device_id.
  """
  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
  def __filter_agents(agent, data):
    status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t()))
    if status == 0 and device_type.value == typ:
      ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_agent_t))
      if ret[0].handle == device_id:
        ret[0] = agent
        return hsa.HSA_STATUS_INFO_BREAK  # found the requested agent: stop iterating
      ret[0].handle = ret[0].handle + 1  # not this one: bump the matching-agent count
    return hsa.HSA_STATUS_SUCCESS  # keep iterating
  hsa.hsa_iterate_agents(__filter_agents, ctypes.byref(agent := hsa.hsa_agent_t()))
  return agent
def find_memory_pool(agent, segtyp=-1, flags=-1, location=-1):
  """Return the first non-empty memory pool of `agent` matching all requested filters.

  Each filter defaults to -1, meaning "don't care". `segtyp` and `location`
  must match exactly; `flags` requires all requested global-flag bits to be set.
  Zero-sized pools are always skipped.
  """
  @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p)
  def __filter_amd_memory_pools(mem_pool, data):
    # Returning HSA_STATUS_SUCCESS skips this pool and continues iterating.
    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SEGMENT, ctypes.byref(segment := hsa.hsa_amd_segment_t())))
    if segtyp >= 0 and segment.value != segtyp: return hsa.HSA_STATUS_SUCCESS

    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_GLOBAL_FLAGS, ctypes.byref(fgs := hsa.hsa_amd_memory_pool_global_flag_t()))) # noqa: E501
    # BUGFIX: skip pools that do NOT carry all the requested flag bits. The previous
    # `== flags` test skipped exactly the matching pools, inverting the filter
    # (and contradicting the `flags=...KERNARG_INIT` lookup used for kernarg pools).
    if flags >= 0 and (fgs.value & flags) != flags: return hsa.HSA_STATUS_SUCCESS

    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_LOCATION, ctypes.byref(loc:=hsa.hsa_amd_memory_pool_location_t())))
    if location >= 0 and loc.value != location: return hsa.HSA_STATUS_SUCCESS

    check(hsa.hsa_amd_memory_pool_get_info(mem_pool, hsa.HSA_AMD_MEMORY_POOL_INFO_SIZE, ctypes.byref(sz := ctypes.c_size_t())))
    if sz.value == 0: return hsa.HSA_STATUS_SUCCESS

    ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_amd_memory_pool_t))
    ret[0] = mem_pool
    return hsa.HSA_STATUS_INFO_BREAK  # keep this pool and stop iterating
  hsa.hsa_amd_agent_iterate_memory_pools(agent, __filter_amd_memory_pools, ctypes.byref(region := hsa.hsa_amd_memory_pool_t()))
  return region

View File

@@ -126,7 +126,7 @@ class HIPAllocator(LRUAllocator):
self.full_synchronize()
hip_set_device(self.device.device)
check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost))
def transfer(self, dest:T, src:T, sz:int):
def transfer(self, dest:T, src:T, sz:int, **kwargs):
hip_set_device(self.device.device)
check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))

211
tinygrad/runtime/ops_hsa.py Normal file
View File

@@ -0,0 +1,211 @@
from __future__ import annotations
import ctypes, functools, subprocess, io, atexit
from typing import Tuple, TypeVar, List
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t
from tinygrad.device import Compiled, LRUAllocator
from tinygrad.runtime.ops_hip import HIPCompiler
from tinygrad.runtime.driver.hsa import check, find_agent, find_memory_pool, HWQueue
HSACompiler = HIPCompiler  # HSA consumes the same GPU code objects as HIP, so the HIP compiler is reused unchanged
class HSAProgram:
  """A compiled kernel loaded into a frozen HSA executable on one device."""
  def __init__(self, device:HSADevice, name:str, lib:bytes):
    self.device, self.name, self.lib = device, name, lib

    if DEBUG >= 6:
      # Disassemble the code object for inspection, dropping the s_code_end padding lines.
      asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib)
      print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x]))

    # Load the code object into an executable, freeze it, then look up the
    # kernel descriptor symbol ("<name>.kd") and its dispatch parameters.
    self.exec = init_c_var(hsa.hsa_executable_t(), lambda x: check(hsa.hsa_executable_create_alt(hsa.HSA_PROFILE_FULL, hsa.HSA_DEFAULT_FLOAT_ROUNDING_MODE_DEFAULT, None, ctypes.byref(x)))) # noqa: E501
    check(hsa.hsa_code_object_reader_create_from_memory(lib, len(lib), ctypes.byref(code_reader := hsa.hsa_code_object_reader_t())))
    check(hsa.hsa_executable_load_agent_code_object(self.exec, self.device.agent, code_reader, None, None))
    check(hsa.hsa_executable_freeze(self.exec, None))

    self.kernel = init_c_var(hsa.hsa_executable_symbol_t(), lambda x: check(hsa.hsa_executable_get_symbol_by_name(self.exec, (name+".kd").encode("utf-8"), ctypes.byref(self.device.agent), ctypes.byref(x)))) # noqa: E501
    self.handle = init_c_var(ctypes.c_uint64(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_OBJECT, ctypes.byref(x)))) # noqa: E501
    self.kernargs_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.group_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501
    self.private_segment_size = init_c_var(ctypes.c_uint32(), lambda x: check(hsa.hsa_executable_symbol_get_info(self.kernel, hsa.HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_PRIVATE_SEGMENT_SIZE, ctypes.byref(x)))).value # noqa: E501

    # The reader is only needed while loading; the executable keeps the code alive.
    check(hsa.hsa_code_object_reader_destroy(code_reader))

  def __del__(self):
    if hasattr(self, 'exec'): check(hsa.hsa_executable_destroy(self.exec))

  def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
    """Dispatch the kernel with buffer pointers `args` and int values `vals`.

    When wait=True, blocks until completion and returns the dispatch duration
    (timestamp delta scaled by 1/timestamp-frequency, i.e. seconds); otherwise
    the dispatch is asynchronous and None is returned.
    """
    # Kernarg layout is cached on first call: one void* per buffer, then one c_int per value.
    if not hasattr(self, "args_struct_t"):
      self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
                                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
      assert ctypes.sizeof(self.args_struct_t) == self.kernargs_segment_size, f"{ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}"

    kernargs = None
    if self.kernargs_segment_size > 0:
      # Write the arguments into a slice of the device's kernarg bump allocator.
      kernargs = self.device.alloc_kernargs(self.kernargs_segment_size)
      args_st = self.args_struct_t.from_address(kernargs)
      for i in range(len(args)): args_st.__setattr__(f'f{i}', args[i])
      for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])

    signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=wait)
    if wait:
      hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      # Profiling was enabled on the queue, so per-dispatch timestamps are available.
      check(hsa.hsa_amd_profiling_get_dispatch_time(self.device.agent, signal, ctypes.byref(timings := hsa.hsa_amd_profiling_dispatch_time_t())))
      return (timings.end - timings.start) * self.device.clocks_to_time
T = TypeVar("T")  # opaque buffer handle (here: the raw device pointer as an int, see HSAAllocator._alloc)
CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000  # staging-buffer chunk size and page alignment used by copy_from_fd
class HSAAllocator(LRUAllocator):
  """LRU-cached allocator backed by the device's global GPU memory pool."""
  def __init__(self, device:HSADevice):
    self.device = device
    super().__init__()

  def _alloc(self, size:int):
    # Grant every known HSA device access to the new buffer so cross-device
    # transfers work without further setup.
    c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices])
    check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p())))
    check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, buf))
    return buf.value

  def _free(self, opaque:T):
    # Full device sync before freeing: queued async copies/kernels may still reference the buffer.
    self.device.synchronize()
    check(hsa.hsa_amd_memory_pool_free(opaque))

  def copyin(self, dest:T, src: memoryview):
    # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
    copy_signal = self.device.alloc_signal(reusable=True)
    sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)
    # Stage the host data in a CPU-pool bounce buffer visible to both agents,
    # then SDMA-copy it to the device once the barrier (all prior work) completes.
    c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
    check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, src.nbytes, 0, ctypes.byref(mem := ctypes.c_void_p())))
    check(hsa.hsa_amd_agents_allow_access(2, c_agents, None, mem))
    ctypes.memmove(mem, from_mv(src), src.nbytes)
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes,
                                                  1, ctypes.byref(sync_signal), copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
    # Order all later queue packets after the copy; free the bounce buffer at the next synchronize.
    self.device.hw_queue.submit_barrier(wait_signals=[copy_signal])
    self.device.delayed_free.append(mem)

  def copy_from_fd(self, dest, fd, offset, size):
    """Stream `size` bytes from file descriptor `fd` at `offset` into device memory,
    double-buffered through two pinned CHUNK_SIZE staging buffers on two SDMA engines."""
    sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)

    # Lazily create the two staging buffers, their in-flight signals, and the engine assignment.
    if not hasattr(self, 'hb'):
      c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
      self.hb = []
      for _ in range(2):
        check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, CHUNK_SIZE, 0, ctypes.byref(mem := ctypes.c_void_p())))
        check(hsa.hsa_amd_agents_allow_access(2, c_agents, None, mem))
        self.hb.append(mem.value)
      self.hb_signals = [self.device.alloc_signal(reusable=False) for _ in range(2)]
      self.hb_polarity = 0
      self.sdma = [hsa.HSA_AMD_SDMA_ENGINE_0, hsa.HSA_AMD_SDMA_ENGINE_1]
      for sig in self.hb_signals: hsa.hsa_signal_store_relaxed(sig, 0)  # mark both buffers as initially free

    fo = io.FileIO(fd, "a+b", closefd=False)
    # Reads are page-aligned; minor_offset is the in-page offset of the requested start.
    fo.seek(offset - (minor_offset:=offset % PAGE_SIZE))

    copies_called = 0
    copied_in = 0
    for local_offset in range(0, size+minor_offset, CHUNK_SIZE):
      local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE)
      copy_size = min(local_size-minor_offset, size-copied_in)
      if copy_size == 0: break

      # Wait until the staging buffer's previous copy has finished before overwriting it.
      hsa.hsa_signal_wait_scacquire(self.hb_signals[self.hb_polarity], hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      self.device.reusable_signals.append(self.hb_signals[self.hb_polarity]) # it's free now and can be reused
      self.hb_signals[self.hb_polarity] = self.device.alloc_signal(reusable=False)

      fo.readinto(to_mv(self.hb[self.hb_polarity], local_size))
      check(hsa.hsa_amd_memory_async_copy_on_engine(dest+copied_in, self.device.agent, self.hb[self.hb_polarity]+minor_offset, HSADevice.cpu_agent,
                                                    copy_size, 1, ctypes.byref(sync_signal), self.hb_signals[self.hb_polarity],
                                                    self.sdma[self.hb_polarity], True))
      copied_in += copy_size
      self.hb_polarity = (self.hb_polarity + 1) % len(self.hb)
      minor_offset = 0 # only on the first
      copies_called += 1

    # Order later queue packets after the last one or two outstanding copies.
    wait_signals = [self.hb_signals[self.hb_polarity - 1]]
    if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
    self.device.hw_queue.submit_barrier(wait_signals=wait_signals)

  def copyout(self, dest:memoryview, src:T):
    """Synchronous device-to-host copy into the caller's buffer."""
    self.device.synchronize()
    copy_signal = self.device.alloc_signal(reusable=True)
    c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
    # Pin the destination host memory so the DMA engine can write it directly.
    check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
    check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal))
    hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    check(hsa.hsa_amd_memory_unlock(from_mv(dest)))

  def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
    """Async device-to-device copy, ordered after all prior work on both queues."""
    # NOTE(review): this signal is non-reusable and never returned to a pool here —
    # looks like it is never recycled; confirm lifetime.
    copy_signal = dest_dev.alloc_signal(reusable=False)
    sync_signal_1 = src_dev.hw_queue.submit_barrier(need_signal=True)
    sync_signal_2 = dest_dev.hw_queue.submit_barrier(need_signal=True)
    c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
    check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) # noqa: E501
    # Both devices' subsequent packets wait for the copy to finish.
    src_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
    dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
class HSADevice(Compiled):
  # Process-wide state, initialized lazily by the first HSADevice.
  cpu_agent = None
  cpu_mempool = None
  devices: List[HSADevice] = []

  def __init__(self, device:str=""):
    """`device` is e.g. "HSA" or "HSA:1"; the part after ':' selects the GPU agent index."""
    if not HSADevice.cpu_agent:
      # First device: bring up the HSA runtime and locate the CPU agent/pool for staging memory.
      check(hsa.hsa_init())
      atexit.register(lambda: hsa.hsa_shut_down())
      HSADevice.cpu_agent = find_agent(hsa.HSA_DEVICE_TYPE_CPU, device_id=0)
      HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU)

    self.device_id = int(device.split(":")[1]) if ":" in device else 0
    self.agent = find_agent(hsa.HSA_DEVICE_TYPE_GPU, device_id=self.device_id)
    self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
    self.kernargs_pool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, flags=hsa.HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT)
    self.hw_queue = HWQueue(self)
    HSADevice.devices.append(self)

    # Agent name (e.g. gfx arch) doubles as the compiler target.
    check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
    self.arch = ctypes.string_at(agent_name_buf).decode()

    # Timestamp frequency converts profiling tick deltas into seconds.
    check(hsa.hsa_system_get_info(hsa.HSA_SYSTEM_INFO_TIMESTAMP_FREQUENCY, ctypes.byref(gpu_freq := ctypes.c_uint64())))
    self.clocks_to_time: float = 1 / gpu_freq.value

    # Kernarg memory is a bump allocator over one pool allocation, reset at every synchronize.
    self.kernarg_pool_sz = 16 << 20
    self.kernarg_start_addr = init_c_var(ctypes.c_void_p(), lambda x: check(hsa.hsa_amd_memory_pool_allocate(self.kernargs_pool, self.kernarg_pool_sz, 0, ctypes.byref(x)))).value # noqa: E501
    self.kernarg_next_addr = self.kernarg_start_addr

    # delayed_free: buffers still referenced by queued work, freed at the next synchronize.
    self.delayed_free: List[ctypes.c_void_p] = []
    # Pre-create a pool of signals; alloc_signal falls back to creating more on demand.
    self.signal_pool: List[hsa.hsa_signal_t] = []
    self.reusable_signals: List[hsa.hsa_signal_t] = []
    for _ in range(4096):
      check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
      self.signal_pool.append(signal)

    super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), None)

  def synchronize(self):
    """Drain the hardware queue, then recycle signals/kernargs and release delayed buffers."""
    self.hw_queue.wait()

    # Reusable signals are safe now: reset their value to 1 and return them to the pool.
    for sig in self.reusable_signals: hsa.hsa_signal_silent_store_relaxed(sig, 1)
    self.signal_pool.extend(self.reusable_signals)
    self.reusable_signals.clear()

    for opaque_to_free in self.delayed_free: check(hsa.hsa_amd_memory_pool_free(opaque_to_free))
    self.delayed_free.clear()

    self.kernarg_next_addr = self.kernarg_start_addr  # reset the kernarg bump allocator

  def alloc_signal(self, reusable=False):
    """Get a signal from the pool (or create one). Non-reusable signals must be
    recycled by the caller (see copy_from_fd appending to reusable_signals)."""
    if len(self.signal_pool): signal = self.signal_pool.pop()
    else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))

    # reusable means a signal could be reused after synchronize for the device it's allocated from is called.
    if reusable: self.reusable_signals.append(signal)
    return signal

  def alloc_kernargs(self, sz):
    """Bump-allocate `sz` bytes of kernarg memory; grows the pool (doubling) when exhausted."""
    if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz:
      # Old pool may still back queued dispatches, so free it only at the next synchronize.
      self.delayed_free.append(self.kernarg_start_addr)
      # NOTE(review): assumes sz fits in the doubled pool — presumably kernarg segments are small; confirm.
      self.kernarg_pool_sz = int(self.kernarg_pool_sz * 2)
      self.kernarg_start_addr = init_c_var(ctypes.c_void_p(), lambda x: check(hsa.hsa_amd_memory_pool_allocate(self.kernargs_pool, self.kernarg_pool_sz, 0, ctypes.byref(x)))).value # noqa: E501
      self.kernarg_next_addr = self.kernarg_start_addr

    result = self.kernarg_next_addr
    self.kernarg_next_addr = (self.kernarg_next_addr + sz + 15) & (~15) # align to 16 bytes
    return result

View File

@@ -60,7 +60,7 @@ class MetalAllocator(LRUAllocator):
ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
return ret
def transfer(self, dest:Any, src:Any, sz:int):
def transfer(self, dest:Any, src:Any, sz:int, **kwargs):
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.blitCommandEncoder()
encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src, 0, dest, 0, sz)