mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 09:05:40 -05:00
hsa multigpu graph (#3403)
* init hsa multigraph * better handling of accesses to buffers * revert sdma0 only when copies from fd
This commit is contained in:
@@ -22,7 +22,7 @@ BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_
|
||||
BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
|
||||
BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
|
||||
|
||||
class HWQueue:
|
||||
class AQLQueue:
|
||||
def __init__(self, device, sz=-1):
|
||||
self.device = device
|
||||
self.wait_signals = []
|
||||
@@ -48,7 +48,7 @@ class HWQueue:
|
||||
|
||||
def submit_kernel(self, prg, global_size, local_size, kernargs, need_signal=False):
|
||||
if self.available_packet_slots == 0: self._wait_queue()
|
||||
signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
|
||||
signal = self._alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
|
||||
|
||||
packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
|
||||
packet.workgroup_size_x = local_size[0]
|
||||
@@ -70,10 +70,10 @@ class HWQueue:
|
||||
|
||||
return signal
|
||||
|
||||
def submit_barrier(self, wait_signals=None, need_signal=False):
|
||||
assert wait_signals is None or len(wait_signals) < 5
|
||||
def submit_barrier(self, wait_signals=None, need_signal=False, completion_signal=None):
|
||||
assert wait_signals is None or len(wait_signals) <= 5
|
||||
if self.available_packet_slots == 0: self._wait_queue()
|
||||
signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
|
||||
signal = (completion_signal or self._alloc_signal(reusable=True)) if need_signal else EMPTY_SIGNAL
|
||||
|
||||
packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
|
||||
packet.reserved0 = 0
|
||||
@@ -87,13 +87,29 @@ class HWQueue:
|
||||
|
||||
return signal
|
||||
|
||||
def blit_packets(self, packet_addr, packet_cnt):
|
||||
if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
|
||||
|
||||
tail_blit_packets = min(((self.write_addr_end + 1) - self.write_addr) // 64, packet_cnt)
|
||||
rem_packet_cnt = packet_cnt - tail_blit_packets
|
||||
ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
|
||||
self.write_addr += AQL_PACKET_SIZE * tail_blit_packets
|
||||
if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address
|
||||
if tail_blit_packets > 0:
|
||||
ctypes.memmove(self.write_addr, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
|
||||
self.write_addr += AQL_PACKET_SIZE * rem_packet_cnt
|
||||
|
||||
self.next_doorbell_index += packet_cnt
|
||||
hsa.hsa_queue_store_write_index_screlease(self.hw_queue, self.next_doorbell_index + 1)
|
||||
hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)
|
||||
|
||||
def wait(self):
|
||||
signal = self.submit_barrier(need_signal=True)
|
||||
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
self.available_packet_slots = self.queue_size
|
||||
|
||||
def _wait_queue(self):
|
||||
while self.available_packet_slots == 0:
|
||||
def _wait_queue(self, need_packets=1):
|
||||
while self.available_packet_slots < need_packets:
|
||||
rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
|
||||
self.available_packet_slots = self.queue_size - (self.next_doorbell_index - rindex)
|
||||
|
||||
@@ -106,6 +122,7 @@ class HWQueue:
|
||||
self.next_doorbell_index += 1
|
||||
self.available_packet_slots -= 1
|
||||
|
||||
def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable)
|
||||
|
||||
def find_agent(typ, device_id):
|
||||
@ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
|
||||
|
||||
179
tinygrad/runtime/graph/hsa.py
Normal file
179
tinygrad/runtime/graph/hsa.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import ctypes, collections, time
|
||||
from typing import List, Any, Dict, cast, Optional, Union
|
||||
from tinygrad.helpers import GraphException, init_c_var
|
||||
from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_hsa import HSADevice
|
||||
from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
|
||||
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
|
||||
|
||||
def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
|
||||
|
||||
class VirtAQLQueue(AQLQueue):
|
||||
def __init__(self, device, sz):
|
||||
self.device = device
|
||||
self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
|
||||
self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
|
||||
self.packets_count = 0
|
||||
self.available_packet_slots = sz
|
||||
def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
|
||||
def _submit_packet(self):
|
||||
self.write_addr += AQL_PACKET_SIZE
|
||||
self.packets_count += 1
|
||||
self.available_packet_slots -= 1
|
||||
def _alloc_signal(self, reusable=False): return init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x))))
|
||||
|
||||
class HSAGraph(MultiDeviceJITGraph):
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
self.jit_cache = jit_cache
|
||||
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
|
||||
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
|
||||
self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
|
||||
self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
|
||||
|
||||
# Check all jit items are compatible.
|
||||
compiled_devices = set()
|
||||
for ji in self.jit_cache:
|
||||
if isinstance(ji.prg, CompiledASTRunner): compiled_devices.add(ji.prg.device)
|
||||
elif isinstance(ji.prg, BufferXfer):
|
||||
for x in ji.rawbufs[0:2]: compiled_devices.add(cast(Buffer, x).d)
|
||||
else: raise GraphException
|
||||
if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
|
||||
|
||||
self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
|
||||
|
||||
# Allocate kernel args.
|
||||
kernargs_size: Dict[HSADevice, int] = collections.defaultdict(int)
|
||||
for ji in self.jit_cache:
|
||||
if isinstance(ji.prg, CompiledASTRunner): kernargs_size[cast(HSADevice, ji.prg.device)] += (ctypes.sizeof(ji.prg.clprg.args_struct_t)+15) & ~15
|
||||
|
||||
kernargs_ptrs: Dict[Compiled, int] = {}
|
||||
for dev,sz in kernargs_size.items():
|
||||
kernargs_ptrs[dev] = init_c_var(ctypes.c_void_p(),
|
||||
lambda x: check(hsa.hsa_amd_memory_pool_allocate(dev.kernargs_pool, sz, 0, ctypes.byref(x)))).value
|
||||
check(hsa.hsa_amd_agents_allow_access(1, ctypes.byref(dev.agent), None, kernargs_ptrs[dev]))
|
||||
|
||||
# Fill initial arguments.
|
||||
self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
if not isinstance(ji.prg, CompiledASTRunner): continue
|
||||
self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
|
||||
kernargs_ptrs[ji.prg.device] += (ctypes.sizeof(ji.prg.clprg.args_struct_t) + 15) & ~15
|
||||
for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf)
|
||||
for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]])
|
||||
|
||||
# Build queues.
|
||||
self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
|
||||
self.packets = {}
|
||||
self.transfers = []
|
||||
self.signals_to_reset: List[hsa.hsa_signal_t] = []
|
||||
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
|
||||
self.r_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
|
||||
signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
|
||||
|
||||
# Special packet to wait for the world.
|
||||
self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {}
|
||||
for dev in self.devices: self.kickoff_signals[dev] = self.virt_aql_queues[dev].submit_barrier(need_signal=True)
|
||||
self.signals_to_reset += list(self.kickoff_signals.values())
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
if isinstance(ji.prg, CompiledASTRunner):
|
||||
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
|
||||
wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=self.packets[j], sync_with_aql_packets=False)
|
||||
for i in range(0, len(wait_signals), 5):
|
||||
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5])
|
||||
self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), ctypes.addressof(self.ji_kargs_structs[j])) #type:ignore
|
||||
elif isinstance(ji.prg, BufferXfer):
|
||||
dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
|
||||
dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d)
|
||||
sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
self.signals_to_reset.append(sync_signal)
|
||||
signals_to_devices[sync_signal.handle] = [dest_dev, src_dev]
|
||||
|
||||
wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
|
||||
self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
|
||||
(hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
|
||||
|
||||
# Make sure the src buffer can be by other devices.
|
||||
c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices])
|
||||
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, src._buf))
|
||||
|
||||
# Wait for all active signals to finish the graph
|
||||
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
|
||||
for v in dedup_signals([s for s in list(self.w_dependency_map.values())+list(self.r_dependency_map.values()) if isinstance(s, hsa.hsa_signal_t)]):
|
||||
for dev in signals_to_devices[v.handle]:
|
||||
wait_signals_to_finish[dev].append(v)
|
||||
|
||||
self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
for dev in self.devices:
|
||||
wait_signals = wait_signals_to_finish[dev]
|
||||
for i in range(0, max(1, len(wait_signals)), 5):
|
||||
self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], need_signal=(i+5>=len(wait_signals)), completion_signal=self.finish_signal)
|
||||
|
||||
# Zero signals to allow graph to start and execute.
|
||||
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
|
||||
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
# Wait and restore signals
|
||||
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
|
||||
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
|
||||
|
||||
# Update rawbuffers
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
|
||||
|
||||
# Update var_vals
|
||||
for j in self.jc_idxs_with_updatable_var_vals:
|
||||
for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
|
||||
self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
|
||||
|
||||
# Update launch dims
|
||||
for j in self.jc_idxs_with_updatable_launch_dims:
|
||||
gl, lc = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
|
||||
self.packets[j].workgroup_size_x = lc[0]
|
||||
self.packets[j].workgroup_size_y = lc[1]
|
||||
self.packets[j].workgroup_size_z = lc[2]
|
||||
self.packets[j].grid_size_x = gl[0] * lc[0]
|
||||
self.packets[j].grid_size_y = gl[1] * lc[1]
|
||||
self.packets[j].grid_size_z = gl[2] * lc[2]
|
||||
|
||||
for dev in self.devices:
|
||||
dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
|
||||
|
||||
for transfer_data in self.transfers:
|
||||
check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
|
||||
|
||||
et = None
|
||||
if wait:
|
||||
st = time.perf_counter()
|
||||
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
et = time.perf_counter() - st
|
||||
|
||||
update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
|
||||
jit=jit, num_kernels=len(self.jit_cache), device="HSA")
|
||||
return et
|
||||
|
||||
def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
|
||||
if isinstance(dep, hsa.hsa_signal_t): return dep
|
||||
elif sync_with_aql_packets and isinstance(dep, hsa.hsa_kernel_dispatch_packet_t):
|
||||
if dep.completion_signal.handle == EMPTY_SIGNAL.handle:
|
||||
dep.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
self.signals_to_reset.append(dep.completion_signal)
|
||||
return dep.completion_signal
|
||||
return None
|
||||
|
||||
def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):
|
||||
wait_signals = []
|
||||
for rawbuf in read:
|
||||
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
|
||||
if new_dependency: self.r_dependency_map[rawbuf._buf] = new_dependency
|
||||
for rawbuf in write:
|
||||
wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
|
||||
wait_signals.append(self.dependency_as_signal(self.r_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
|
||||
if new_dependency: self.w_dependency_map[rawbuf._buf] = new_dependency
|
||||
if sync_with_aql_packets: wait_signals += [self.kickoff_signals[rawbuf.d] for rawbuf in read+write]
|
||||
return dedup_signals(wait_signals)
|
||||
@@ -5,7 +5,7 @@ import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t
|
||||
from tinygrad.device import Compiled, LRUAllocator
|
||||
from tinygrad.runtime.ops_hip import HIPCompiler
|
||||
from tinygrad.runtime.driver.hsa import check, find_agent, find_memory_pool, HWQueue
|
||||
from tinygrad.runtime.driver.hsa import check, find_agent, find_memory_pool, AQLQueue
|
||||
|
||||
HSACompiler = HIPCompiler
|
||||
|
||||
@@ -67,7 +67,7 @@ class HSAAllocator(LRUAllocator):
|
||||
return buf.value
|
||||
|
||||
def _free(self, opaque:T):
|
||||
self.device.synchronize()
|
||||
HSADevice.synchronize_system()
|
||||
check(hsa.hsa_amd_memory_pool_free(opaque))
|
||||
|
||||
def copyin(self, dest:T, src: memoryview):
|
||||
@@ -126,7 +126,7 @@ class HSAAllocator(LRUAllocator):
|
||||
self.device.hw_queue.submit_barrier(wait_signals=wait_signals)
|
||||
|
||||
def copyout(self, dest:memoryview, src:T):
|
||||
self.device.synchronize()
|
||||
HSADevice.synchronize_system()
|
||||
copy_signal = self.device.alloc_signal(reusable=True)
|
||||
c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
|
||||
check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
|
||||
@@ -158,7 +158,7 @@ class HSADevice(Compiled):
|
||||
self.agent = find_agent(hsa.HSA_DEVICE_TYPE_GPU, device_id=self.device_id)
|
||||
self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
|
||||
self.kernargs_pool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, flags=hsa.HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT)
|
||||
self.hw_queue = HWQueue(self)
|
||||
self.hw_queue = AQLQueue(self)
|
||||
HSADevice.devices.append(self)
|
||||
|
||||
check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
|
||||
@@ -178,7 +178,8 @@ class HSADevice(Compiled):
|
||||
check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
|
||||
self.signal_pool.append(signal)
|
||||
|
||||
super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), None)
|
||||
from tinygrad.runtime.graph.hsa import HSAGraph
|
||||
super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)
|
||||
|
||||
def synchronize(self):
|
||||
self.hw_queue.wait()
|
||||
@@ -192,6 +193,10 @@ class HSADevice(Compiled):
|
||||
|
||||
self.kernarg_next_addr = self.kernarg_start_addr
|
||||
|
||||
@staticmethod
|
||||
def synchronize_system():
|
||||
for d in HSADevice.devices: d.synchronize()
|
||||
|
||||
def alloc_signal(self, reusable=False):
|
||||
if len(self.signal_pool): signal = self.signal_pool.pop()
|
||||
else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
|
||||
|
||||
Reference in New Issue
Block a user