From 08ef77c72160280d74f7aa005f03a41b6c9ef382 Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Wed, 28 Feb 2024 20:40:53 +0300
Subject: [PATCH] hsa multigpu graph (#3403)

* init hsa multigraph

* better handling of accesses to buffers

* revert sdma0 only when copies from fd
---
 tinygrad/device.py             |   4 +
 tinygrad/features/jit.py       |   6 +-
 tinygrad/runtime/driver/hsa.py |  31 ++++--
 tinygrad/runtime/graph/hsa.py  | 179 +++++++++++++++++++++++++++++++++
 tinygrad/runtime/ops_hsa.py    |  15 ++-
 5 files changed, 221 insertions(+), 14 deletions(-)
 create mode 100644 tinygrad/runtime/graph/hsa.py

diff --git a/tinygrad/device.py b/tinygrad/device.py
index e92029a746..9ab3885986 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -223,6 +223,10 @@ class CompiledASTRunner(JITRunner):
     self.first_run = False
     return et
 
+class MultiDeviceJITGraph(JITRunner):
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+    raise NotImplementedError("override this")
+
 class Compiled:
   def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
diff --git a/tinygrad/features/jit.py b/tinygrad/features/jit.py
index e303ca6766..6dbc260f8f 100644
--- a/tinygrad/features/jit.py
+++ b/tinygrad/features/jit.py
@@ -4,7 +4,7 @@ import functools, itertools, operator
 from tinygrad.nn.state import get_parameters
 from tinygrad.dtype import DType
 from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
-from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer
+from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
@@ -55,9 +55,11 @@ def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer],
   for ji in jit_cache:
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
     if isinstance(ji.prg, CompiledASTRunner): ji_graph_dev = ji.prg.device
+    elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].d.dname.startswith("HSA"): ji_graph_dev = ji.rawbufs[0].d
 
     can_be_graphed = ji_graph_dev and ji_graph_dev.graph
-    can_extend_graph_batch = can_be_graphed and len(current_batch) < getenv("JIT_BATCH_SIZE", 64) and ji_graph_dev == current_device
+    can_extend_graph_batch = can_be_graphed and len(current_batch) < getenv("JIT_BATCH_SIZE", 64) and (ji_graph_dev == current_device or
+      (isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiDeviceJITGraph) and type(ji_graph_dev) == type(current_device))) #type:ignore
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
     if can_be_graphed: current_batch.append(ji)
diff --git a/tinygrad/runtime/driver/hsa.py b/tinygrad/runtime/driver/hsa.py
index 4b3fe11e83..5aaeaf483f 100644
--- a/tinygrad/runtime/driver/hsa.py
+++ b/tinygrad/runtime/driver/hsa.py
@@ -22,7 +22,7 @@ BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCACQUIRE_
 BARRIER_HEADER |= hsa.HSA_FENCE_SCOPE_SYSTEM << hsa.HSA_PACKET_HEADER_SCRELEASE_FENCE_SCOPE
 BARRIER_HEADER |= hsa.HSA_PACKET_TYPE_BARRIER_AND << hsa.HSA_PACKET_HEADER_TYPE
 
-class HWQueue:
+class AQLQueue:
   def __init__(self, device, sz=-1):
     self.device = device
     self.wait_signals = []
@@ -48,7 +48,7 @@ class HWQueue:
   def submit_kernel(self, prg, global_size, local_size, kernargs, need_signal=False):
     if self.available_packet_slots == 0: self._wait_queue()
-    signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
+    signal = self._alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
 
     packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
     packet.workgroup_size_x = local_size[0]
@@ -70,10 +70,10 @@ class HWQueue:
     return signal
 
-  def submit_barrier(self, wait_signals=None, need_signal=False):
-    assert wait_signals is None or len(wait_signals) < 5
+  def submit_barrier(self, wait_signals=None, need_signal=False, completion_signal=None):
+    assert wait_signals is None or len(wait_signals) <= 5
     if self.available_packet_slots == 0: self._wait_queue()
-    signal = self.device.alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
+    signal = (completion_signal or self._alloc_signal(reusable=True)) if need_signal else EMPTY_SIGNAL
 
     packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
     packet.reserved0 = 0
@@ -87,13 +87,29 @@ class HWQueue:
 
+  def blit_packets(self, packet_addr, packet_cnt):
+    if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
+
+    tail_blit_packets = min(((self.write_addr_end + 1) - self.write_addr) // 64, packet_cnt)
+    rem_packet_cnt = packet_cnt - tail_blit_packets
+    ctypes.memmove(self.write_addr, packet_addr, AQL_PACKET_SIZE * tail_blit_packets)
+    self.write_addr += AQL_PACKET_SIZE * tail_blit_packets
+    if self.write_addr > self.write_addr_end: self.write_addr = self.hw_queue.contents.base_address
+    if tail_blit_packets > 0:
+      ctypes.memmove(self.write_addr, packet_addr + AQL_PACKET_SIZE * tail_blit_packets, AQL_PACKET_SIZE * rem_packet_cnt)
+      self.write_addr += AQL_PACKET_SIZE * rem_packet_cnt
+
+    self.next_doorbell_index += packet_cnt
+    hsa.hsa_queue_store_write_index_screlease(self.hw_queue, self.next_doorbell_index + 1)
+    hsa.hsa_signal_store_screlease(self.hw_queue.contents.doorbell_signal, self.next_doorbell_index)
+
   def wait(self):
     signal = self.submit_barrier(need_signal=True)
     hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
     self.available_packet_slots = self.queue_size
 
-  def _wait_queue(self):
-    while self.available_packet_slots == 0:
+  def _wait_queue(self, need_packets=1):
+    while self.available_packet_slots < need_packets:
       rindex = hsa.hsa_queue_load_read_index_relaxed(self.hw_queue)
       self.available_packet_slots = self.queue_size - (self.next_doorbell_index - rindex)
@@ -106,6 +122,7 @@ class HWQueue:
     self.next_doorbell_index += 1
     self.available_packet_slots -= 1
 
+  def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable)
 
 def find_agent(typ, device_id):
   @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p)
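
The new AQLQueue.blit_packets above copies a whole batch of pre-built AQL packets into the hardware ring buffer with at most two memmoves (the run up to the end of the ring, then the wrapped remainder) before bumping the write index and ringing the doorbell once. Below is a minimal standalone sketch of that wraparound arithmetic, in pure Python with no HSA calls; blit_into_ring, ring, and write_off are illustrative names, not part of the patch.

# Standalone sketch of the wraparound copy done by AQLQueue.blit_packets (no HSA involved).
# Only the splitting arithmetic mirrors the patch; every name here is illustrative.
AQL_PACKET_SIZE = 64

def blit_into_ring(ring: bytearray, write_off: int, packets: bytes) -> int:
  """Copy pre-built 64-byte packets into a ring buffer, wrapping once if needed.
  Returns the new write offset. Assumes len(packets) fits in the ring."""
  assert len(ring) % AQL_PACKET_SIZE == 0 and len(packets) % AQL_PACKET_SIZE == 0
  tail = min(len(ring) - write_off, len(packets))  # bytes that fit before the end of the ring
  ring[write_off:write_off+tail] = packets[:tail]
  rem = len(packets) - tail                        # bytes that wrap around to the start
  ring[0:rem] = packets[tail:]
  return (write_off + len(packets)) % len(ring)

if __name__ == "__main__":
  ring = bytearray(4 * AQL_PACKET_SIZE)            # a tiny 4-slot queue
  off = blit_into_ring(ring, 3 * AQL_PACKET_SIZE, b"\xaa" * (2 * AQL_PACKET_SIZE))
  assert off == AQL_PACKET_SIZE and ring[-1] == 0xaa and ring[0] == 0xaa

In the real method the destination is the memory-mapped AQL queue, so the two copies are ctypes.memmove calls, and the GPU command processor only picks up the new packets once the write index and doorbell signal are stored.
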
diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py
new file mode 100644
index 0000000000..1924efadd5
--- /dev/null
+++ b/tinygrad/runtime/graph/hsa.py
@@ -0,0 +1,179 @@
+import ctypes, collections, time
+from typing import List, Any, Dict, cast, Optional, Union
+from tinygrad.helpers import GraphException, init_c_var
+from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
+from tinygrad.shape.symbolic import Variable
+from tinygrad.runtime.ops_hsa import HSADevice
+from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
+                                  get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
+import tinygrad.runtime.autogen.hsa as hsa
+from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
+
+def dedup_signals(signals): return [hsa.hsa_signal_t(hndl) for hndl in set([x.handle for x in signals if isinstance(x, hsa.hsa_signal_t)])]
+
+class VirtAQLQueue(AQLQueue):
+  def __init__(self, device, sz):
+    self.device = device
+    self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
+    self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
+    self.packets_count = 0
+    self.available_packet_slots = sz
+  def _wait_queue(self, need_packets=1): assert False, f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!"
+  def _submit_packet(self):
+    self.write_addr += AQL_PACKET_SIZE
+    self.packets_count += 1
+    self.available_packet_slots -= 1
+  def _alloc_signal(self, reusable=False): return init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x))))
+
+class HSAGraph(MultiDeviceJITGraph):
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+    self.jit_cache = jit_cache
+    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
+    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
+    self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
+    self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
+
+    # Check all jit items are compatible.
+    compiled_devices = set()
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledASTRunner): compiled_devices.add(ji.prg.device)
+      elif isinstance(ji.prg, BufferXfer):
+        for x in ji.rawbufs[0:2]: compiled_devices.add(cast(Buffer, x).d)
+      else: raise GraphException
+    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
+
+    self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
+
+    # Allocate kernel args.
+    kernargs_size: Dict[HSADevice, int] = collections.defaultdict(int)
+    for ji in self.jit_cache:
+      if isinstance(ji.prg, CompiledASTRunner): kernargs_size[cast(HSADevice, ji.prg.device)] += (ctypes.sizeof(ji.prg.clprg.args_struct_t)+15) & ~15
+
+    kernargs_ptrs: Dict[Compiled, int] = {}
+    for dev,sz in kernargs_size.items():
+      kernargs_ptrs[dev] = init_c_var(ctypes.c_void_p(),
+        lambda x: check(hsa.hsa_amd_memory_pool_allocate(dev.kernargs_pool, sz, 0, ctypes.byref(x)))).value
+      check(hsa.hsa_amd_agents_allow_access(1, ctypes.byref(dev.agent), None, kernargs_ptrs[dev]))
+
+    # Fill initial arguments.
+    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+    for j,ji in enumerate(self.jit_cache):
+      if not isinstance(ji.prg, CompiledASTRunner): continue
+      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
+      kernargs_ptrs[ji.prg.device] += (ctypes.sizeof(ji.prg.clprg.args_struct_t) + 15) & ~15
+      for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf)
+      for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]])
+
+    # Build queues.
+    self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
+    self.packets = {}
+    self.transfers = []
+    self.signals_to_reset: List[hsa.hsa_signal_t] = []
+    self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
+    self.r_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, hsa.hsa_agent_dispatch_packet_t]] = {}
+    signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
+
+    # Special packet to wait for the world.
+    self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {}
+    for dev in self.devices: self.kickoff_signals[dev] = self.virt_aql_queues[dev].submit_barrier(need_signal=True)
+    self.signals_to_reset += list(self.kickoff_signals.values())
+
+    for j,ji in enumerate(self.jit_cache):
+      if isinstance(ji.prg, CompiledASTRunner):
+        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
+        wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=self.packets[j], sync_with_aql_packets=False)
+        for i in range(0, len(wait_signals), 5):
+          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5])
+        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), ctypes.addressof(self.ji_kargs_structs[j])) #type:ignore
+      elif isinstance(ji.prg, BufferXfer):
+        dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
+        dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d)
+        sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+        self.signals_to_reset.append(sync_signal)
+        signals_to_devices[sync_signal.handle] = [dest_dev, src_dev]
+
+        wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
+        self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
+                               (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
+
+        # Make sure the src buffer can be used by other devices.
+        c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices])
+        check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, src._buf))
+
+    # Wait for all active signals to finish the graph
+    wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
+    for v in dedup_signals([s for s in list(self.w_dependency_map.values())+list(self.r_dependency_map.values()) if isinstance(s, hsa.hsa_signal_t)]):
+      for dev in signals_to_devices[v.handle]:
+        wait_signals_to_finish[dev].append(v)
+
+    self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+    for dev in self.devices:
+      wait_signals = wait_signals_to_finish[dev]
+      for i in range(0, max(1, len(wait_signals)), 5):
+        self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], need_signal=(i+5>=len(wait_signals)), completion_signal=self.finish_signal)
+
+    # Zero signals to allow graph to start and execute.
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
+
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+    # Wait and restore signals
+    hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
+    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))
+
+    # Update rawbuffers
+    for (j,i),input_idx in self.input_replace.items():
+      self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf)
+
+    # Update var_vals
+    for j in self.jc_idxs_with_updatable_var_vals:
+      for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
+        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+
+    # Update launch dims
+    for j in self.jc_idxs_with_updatable_launch_dims:
+      gl, lc = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
+      self.packets[j].workgroup_size_x = lc[0]
+      self.packets[j].workgroup_size_y = lc[1]
+      self.packets[j].workgroup_size_z = lc[2]
+      self.packets[j].grid_size_x = gl[0] * lc[0]
+      self.packets[j].grid_size_y = gl[1] * lc[1]
+      self.packets[j].grid_size_z = gl[2] * lc[2]
+
+    for dev in self.devices:
+      dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)
+
+    for transfer_data in self.transfers:
+      check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))
+
+    et = None
+    if wait:
+      st = time.perf_counter()
+      hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
+      et = time.perf_counter() - st
+
+    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
+                 jit=jit, num_kernels=len(self.jit_cache), device="HSA")
+    return et
+
+  def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
+    if isinstance(dep, hsa.hsa_signal_t): return dep
+    elif sync_with_aql_packets and isinstance(dep, hsa.hsa_kernel_dispatch_packet_t):
+      if dep.completion_signal.handle == EMPTY_SIGNAL.handle:
+        dep.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
+        self.signals_to_reset.append(dep.completion_signal)
+      return dep.completion_signal
+    return None
+
+  def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):
+    wait_signals = []
+    for rawbuf in read:
+      wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
+      if new_dependency: self.r_dependency_map[rawbuf._buf] = new_dependency
+    for rawbuf in write:
+      wait_signals.append(self.dependency_as_signal(self.w_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
+      wait_signals.append(self.dependency_as_signal(self.r_dependency_map.get(rawbuf._buf), sync_with_aql_packets=sync_with_aql_packets))
+      if new_dependency: self.w_dependency_map[rawbuf._buf] = new_dependency
+    if sync_with_aql_packets: wait_signals += [self.kickoff_signals[rawbuf.d] for rawbuf in read+write]
+    return dedup_signals(wait_signals)
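
HSAGraph.access_resources above keeps a last-writer map (w_dependency_map) and a last-reader map (r_dependency_map) per buffer, so each kernel or copy only waits on the operations it actually conflicts with: read-after-write, write-after-write, and write-after-read. A toy model of that bookkeeping follows, with plain strings standing in for HSA signals and packets; DepTracker and the op names are illustrative only.

# Toy model of the hazard tracking behind HSAGraph.access_resources (no HSA involved).
from typing import Any, Dict, List

class DepTracker:
  def __init__(self):
    self.last_writer: Dict[Any, Any] = {}  # buffer -> op that last wrote it
    self.last_reader: Dict[Any, Any] = {}  # buffer -> op that last read it

  def access(self, read: List[Any], write: List[Any], new_op: Any) -> List[Any]:
    waits = []
    for buf in read:    # read-after-write: wait for the buffer's last writer
      if buf in self.last_writer: waits.append(self.last_writer[buf])
      self.last_reader[buf] = new_op
    for buf in write:   # write-after-write and write-after-read hazards
      if buf in self.last_writer: waits.append(self.last_writer[buf])
      if buf in self.last_reader: waits.append(self.last_reader[buf])
      self.last_writer[buf] = new_op
    return sorted(set(waits) - {new_op})  # dedup, and never wait on yourself

deps = DepTracker()
assert deps.access(read=["a"], write=["b"], new_op="k0") == []
assert deps.access(read=["b"], write=["c"], new_op="k1") == ["k0"]        # k1 reads what k0 wrote
assert deps.access(read=["a"], write=["b"], new_op="k2") == ["k0", "k1"]  # overwriting b waits on its writer and its reader

In the patch the stored dependency is either an hsa_signal_t (from an async SDMA copy) or the kernel's AQL dispatch packet; dependency_as_signal lazily attaches a completion signal to a packet only when a copy actually needs to wait on it, and copies additionally wait on the per-device kickoff signals.
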
diff --git a/tinygrad/runtime/ops_hsa.py b/tinygrad/runtime/ops_hsa.py
index 95adc10d98..d5e17bd56c 100644
--- a/tinygrad/runtime/ops_hsa.py
+++ b/tinygrad/runtime/ops_hsa.py
@@ -5,7 +5,7 @@ import tinygrad.runtime.autogen.hsa as hsa
 from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t
 from tinygrad.device import Compiled, LRUAllocator
 from tinygrad.runtime.ops_hip import HIPCompiler
-from tinygrad.runtime.driver.hsa import check, find_agent, find_memory_pool, HWQueue
+from tinygrad.runtime.driver.hsa import check, find_agent, find_memory_pool, AQLQueue
 
 HSACompiler = HIPCompiler
 
@@ -67,7 +67,7 @@ class HSAAllocator(LRUAllocator):
     return buf.value
 
   def _free(self, opaque:T):
-    self.device.synchronize()
+    HSADevice.synchronize_system()
     check(hsa.hsa_amd_memory_pool_free(opaque))
 
   def copyin(self, dest:T, src: memoryview):
@@ -126,7 +126,7 @@ class HSAAllocator(LRUAllocator):
     self.device.hw_queue.submit_barrier(wait_signals=wait_signals)
 
   def copyout(self, dest:memoryview, src:T):
-    self.device.synchronize()
+    HSADevice.synchronize_system()
     copy_signal = self.device.alloc_signal(reusable=True)
     c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent])
     check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p())))
@@ -158,7 +158,7 @@ class HSADevice(Compiled):
     self.agent = find_agent(hsa.HSA_DEVICE_TYPE_GPU, device_id=self.device_id)
     self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU)
     self.kernargs_pool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, flags=hsa.HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT)
-    self.hw_queue = HWQueue(self)
+    self.hw_queue = AQLQueue(self)
     HSADevice.devices.append(self)
 
     check(hsa.hsa_agent_get_info(self.agent, hsa.HSA_AGENT_INFO_NAME, ctypes.byref(agent_name_buf := ctypes.create_string_buffer(256))))
@@ -178,7 +178,8 @@ class HSADevice(Compiled):
       check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
       self.signal_pool.append(signal)
 
-    super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), None)
+    from tinygrad.runtime.graph.hsa import HSAGraph
+    super().__init__(device, HSAAllocator(self), HSACompiler(self.arch), functools.partial(HSAProgram, self), HSAGraph)
 
   def synchronize(self):
     self.hw_queue.wait()
@@ -192,6 +193,10 @@ class HSADevice(Compiled):
 
     self.kernarg_next_addr = self.kernarg_start_addr
 
+  @staticmethod
+  def synchronize_system():
+    for d in HSADevice.devices: d.synchronize()
+
   def alloc_signal(self, reusable=False):
     if len(self.signal_pool): signal = self.signal_pool.pop()
     else: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(signal := hsa.hsa_signal_t())))
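
With HSAGraph registered as the graph class for HSADevice, a TinyJit'd function whose kernels and cross-device copies all live on HSA devices can now be captured into one batch instead of being flushed per device. A rough usage sketch of that path, untested here; it assumes a machine with two GPUs visible as "HSA:0"/"HSA:1" and the Tensor.shard / TinyJit APIs as of this tinygrad revision, and the shapes and step function are arbitrary.

# Usage sketch: driving the multi-device HSA graph through the JIT (illustrative only).
from tinygrad import Tensor
from tinygrad.features.jit import TinyJit

DEVICES = ("HSA:0", "HSA:1")  # assumes two AMD GPUs visible through HSA

@TinyJit
def step(x: Tensor, w: Tensor) -> Tensor:
  # per-device matmuls plus a reduce over the sharded axis -> kernels and BufferXfers in one jit_cache
  return (x @ w).relu().sum(axis=0).realize()

if __name__ == "__main__":
  w = Tensor.rand(128, 128).shard(DEVICES).realize()           # replicated weights
  for _ in range(5):                                           # early calls warm up the JIT, later calls replay the graph
    x = Tensor.rand(32, 128).shard(DEVICES, axis=0).realize()  # batch split across the two GPUs
    print(step(x, w).shape)

After the warm-up calls, apply_graph_to_jit batches the cached kernels and BufferXfers into HSAGraph instances, so each later call is replayed by blitting the prerecorded AQL packets and re-issuing the async copies.
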