From 6fc7013463fb343fb356cee103eb5263c4572229 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:22:59 +0800 Subject: [PATCH] put all DSP in dsp file [pr] (#7833) --- extra/backends/ops_hip.py | 197 ------------------------------------ extra/backends/ops_rhip.py | 18 ---- tinygrad/renderer/cstyle.py | 30 ------ tinygrad/runtime/ops_dsp.py | 36 ++++++- 4 files changed, 34 insertions(+), 247 deletions(-) delete mode 100644 extra/backends/ops_hip.py delete mode 100644 extra/backends/ops_rhip.py diff --git a/extra/backends/ops_hip.py b/extra/backends/ops_hip.py deleted file mode 100644 index d0c4851027..0000000000 --- a/extra/backends/ops_hip.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import annotations -import ctypes, functools, subprocess, io -from typing import Tuple, TypeVar, List, Any, cast, Set -import tinygrad.runtime.autogen.hip as hip -from tinygrad.helpers import DEBUG, getenv, init_c_var -from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t -from tinygrad.device import Compiled, LRUAllocator, BufferSpec, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions -from tinygrad.renderer.cstyle import HIPRenderer -from tinygrad.runtime.support.hip_comgr import compile_hip -from tinygrad.renderer.rdna import uops_to_rdna - -class RDNACompiler(Compiler): - linearizer_opts = LinearizerOptions("HIP", has_tensor_cores=True) - def __init__(self, arch:str): - self.arch = arch - super().__init__(f"compile_rdna_{self.arch}") - def render(self, name:str, uops) -> str: return uops_to_rdna(name, uops) - def compile(self, src:str) -> bytes: - ret = compile_hip(src, self.arch, True) - #with open("/tmp/out.so", "wb") as f: f.write(ret) - return ret - -class HIPCompiler(Compiler): - compiler_opts = CompilerOptions("HIP", has_tensor_cores=True, shared_max=65536) - def __init__(self, arch:str): - self.arch = arch - super().__init__(f"compile_hip_{self.arch}") - def render(self, name:str, uops) -> str: return HIPRenderer(name, uops) - def compile(self, src:str) -> bytes: return compile_hip(src, self.arch) - -hip_current_device = None -def hip_set_device(d:int): - global hip_current_device - if d == hip_current_device: return - check(hip.hipSetDevice(d)) - hip_current_device = d - -def check(status): - if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}") - -class HIPProgram: - def __init__(self, device:int, name:str, lib:bytes): - self.device, self.name, self.lib = device, name, lib - - if DEBUG >= 6: - asm = subprocess.check_output(["/opt/rocm/llvm/bin/llvm-objdump", '-d', '-'], input=lib) - print('\n'.join([x for x in asm.decode('utf-8').split("\n") if 's_code_end' not in x])) - - hip_set_device(self.device) - self.module = init_c_var(hip.hipModule_t(), lambda x: check(hip.hipModuleLoadData(ctypes.byref(x), lib))) - self.prg = init_c_var(hip.hipFunction_t(), lambda x: check(hip.hipModuleGetFunction(ctypes.byref(x), self.module, name.encode("utf-8")))) - - def __del__(self): - if hasattr(self, 'module'): check(hip.hipModuleUnload(self.module)) - - def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): - hip_set_device(self.device) - if not hasattr(self, "vargs"): - self.c_args = init_c_struct_t(tuple([(f'f{i}', hip.hipDeviceptr_t) for i in range(len(args))] + - [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))(*args, *vals) - self.vargs = (ctypes.c_void_p * 5)(ctypes.c_void_p(1), ctypes.cast(ctypes.byref(self.c_args), ctypes.c_void_p), - ctypes.c_void_p(2), ctypes.cast(ctypes.byref(ctypes.c_size_t(ctypes.sizeof(self.c_args))), ctypes.c_void_p), - ctypes.c_void_p(3)) - else: - for i in range(len(args)): self.c_args.__setattr__(f'f{i}', args[i]) - for i in range(len(vals)): self.c_args.__setattr__(f'v{i}', vals[i]) - if wait: - evs = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)] - check(hip.hipEventRecord(evs[0], None)) - check(hip.hipModuleLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, self.vargs)) - if wait: - check(hip.hipEventRecord(evs[1], None)) - check(hip.hipEventSynchronize(evs[1])) - check(hip.hipEventElapsedTime(ctypes.byref(ret := ctypes.c_float()), evs[0], evs[1])) - for ev in evs: check(hip.hipEventDestroy(ev)) - return ret.value * 1e-3 - return None - -T = TypeVar("T") -CHUNK_SIZE, PAGE_SIZE = 256*1024*1024, 0x1000 -class HIPAllocator(LRUAllocator): - def __init__(self, device:HIPDevice): - self.device = device - self.track_cross_device: Set[HIPDevice] = set() - super().__init__() - def full_synchronize(self): - self.device.synchronize() - for x in self.track_cross_device: x.synchronize() - self.track_cross_device.clear() - def free_cache(self): - self.full_synchronize() - return super().free_cache() - def _alloc(self, size:int): - hip_set_device(self.device.device) - return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipMalloc(ctypes.byref(x), size))) - def _alloc_with_options(self, size:int, options:BufferSpec): - hip_set_device(self.device.device) - if options.uncached: - return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipExtMallocWithFlags(ctypes.byref(x), size, 3))) # hipDeviceMallocUncached = 3 - elif options.host: - return init_c_var(hip.hipDeviceptr_t(), lambda x: check(hip.hipHostMalloc(ctypes.byref(x), size, 2 if options.signal else 0))) - else: - raise Exception("no options") - def _free(self, opaque:T): check(hip.hipFree(opaque)) - def copy_from_fd(self, dest, fd, offset, size): - hip_set_device(self.device.device) - if not hasattr(self, 'hb'): - self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferSpec(host=True)) for _ in range(2)] - self.hb_events = [None, None] - self.hb_polarity = 0 - fo = io.FileIO(fd, "a+b", closefd=False) - fo.seek(offset - (minor_offset:=offset % PAGE_SIZE)) - copied_in = 0 - for local_offset in range(0, size+minor_offset, CHUNK_SIZE): - local_size = min(round_up(size+minor_offset, PAGE_SIZE)-local_offset, CHUNK_SIZE) - if self.hb_events[self.hb_polarity] is not None: - # NOTE: block doesn't work here because we modify the CPU memory - check(hip.hipEventSynchronize(self.hb_events[self.hb_polarity])) - check(hip.hipEventDestroy(self.hb_events[self.hb_polarity])) - self.hb_events[self.hb_polarity] = None - fo.readinto(to_mv(self.hb[self.hb_polarity], local_size)) - check(hip.hipMemcpyAsync(ctypes.c_void_p(dest.value + copied_in), ctypes.c_void_p(self.hb[self.hb_polarity].value + minor_offset), - copy_size:=min(local_size-minor_offset, size-copied_in), hip.hipMemcpyHostToDevice, None)) - self.hb_events[self.hb_polarity] = init_c_var(hip.hipEvent_t(), lambda x: check(hip.hipEventCreate(ctypes.byref(x)))) - check(hip.hipEventRecord(self.hb_events[self.hb_polarity], None)) - copied_in += copy_size - self.hb_polarity = (self.hb_polarity+1) % len(self.hb) - minor_offset = 0 # only on the first - def _copyin(self, dest:T, src: memoryview): - hip_set_device(self.device.device) - host_mem = self._alloc_with_options(len(src), BufferSpec(host=True)) - self.device.pending_copyin.append(host_mem) - ctypes.memmove(host_mem, from_mv(src), len(src)) - check(hip.hipMemcpyAsync(dest, host_mem, len(src), hip.hipMemcpyHostToDevice, None)) - def _copyout(self, dest:memoryview, src:T): - self.full_synchronize() - hip_set_device(self.device.device) - check(hip.hipMemcpy(from_mv(dest), src, len(dest), hip.hipMemcpyDeviceToHost)) - def transfer(self, dest:T, src:T, sz:int, **kwargs): - hip_set_device(self.device.device) - check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None)) - -class HIPSyncEvent(Runner): - def __init__(self, lb): - self.lb, self.device, self.device = lb, cast(HIPDevice, Device[lb.device]), lb.device - super().__init__() - def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False): - to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0 - hip_set_device(self.device.device) - check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0)) - update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.device) - -class HIPWaitEvent(Runner): - def __init__(self, device): - self.device, self.device = cast(HIPDevice, Device[device]), device - super().__init__() - def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False): - hip_set_device(self.device.device) - check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF)) - update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.device) - -if getenv("HIPCPU"): - rhip = ctypes.CDLL("/usr/local/lib/libremu.so") - class RHIPProgram: - def __init__(self, name:str, lib:bytes): - self.name, self.lib = name, lib - def __call__(self, *args, global_size, local_size, vals=(), wait=False): - args = (*args, *vals) - rhip.hipModuleLaunchKernel(self.lib, len(self.lib), *global_size, *local_size, 0, None, None, - len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args])) - -class HIPDevice(Compiled): - def __init__(self, device:str=""): - self.device = int(device.split(":")[1]) if ":" in device else 0 - self.pending_copyin: List[ctypes.c_void_p] = [] - self.track_cross_buffer: List[Any] = [] - self.peers: Set[int] = set() - - if getenv("HIPCPU"): - super().__init__(device, MallocAllocator, HIPCompiler("gfx1100"), RHIPProgram) - else: - self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode() - from tinygrad.runtime.graph.hip import HIPGraph - super().__init__(device, HIPAllocator(self), RDNACompiler(self.arch) if getenv("RDNA") else HIPCompiler(self.arch), - functools.partial(HIPProgram, self.device), HIPGraph) - def synchronize(self): - if getenv("HIPCPU"): return - hip_set_device(self.device) - check(hip.hipDeviceSynchronize()) - for opaque in self.pending_copyin: check(hip.hipFree(opaque)) - self.track_cross_buffer.clear() - self.pending_copyin.clear() - def enable_peer(self, dnum): - if self.device == dnum or dnum in self.peers: return - hip_set_device(self.device) - check(hip.hipDeviceEnablePeerAccess(dnum, 0)) - self.peers.add(dnum) diff --git a/extra/backends/ops_rhip.py b/extra/backends/ops_rhip.py deleted file mode 100644 index d127c98446..0000000000 --- a/extra/backends/ops_rhip.py +++ /dev/null @@ -1,18 +0,0 @@ -import ctypes -from tinygrad.device import Compiled, MallocAllocator -from tinygrad.renderer.cstyle import HIPRenderer -from tinygrad.runtime.ops_hsa import HSACompiler - -rhip = ctypes.CDLL("/usr/local/lib/libremu.so") -class RHIPProgram: - def __init__(self, name:str, lib:bytes): - self.name, self.lib = name, lib - def __call__(self, *args, global_size, local_size, vals=(), wait=False): - args = (*args, *vals) - rhip.hipModuleLaunchKernel(self.lib, len(self.lib), *global_size, *local_size, 0, None, None, - len(args), (ctypes.c_void_p * len(args))(*[ctypes.cast(x, ctypes.c_void_p) for x in args])) - -class RHIPDevice(Compiled): - def __init__(self, device:str=""): - self.device = int(device.split(":")[1]) if ":" in device else 0 - super().__init__(device, MallocAllocator, HIPRenderer(), HSACompiler("gfx1100"), RHIPProgram) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 3510f11446..007190a6b4 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -429,36 +429,6 @@ class AMDRenderer(CStyleLanguage): # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))" -class DSPRenderer(ClangRenderer): - device = "DSP" - supports_float4 = False - buffer_suffix = " restrict __attribute__((align_value(128)))" - kernel_prefix = "__attribute__((noinline)) " - type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" } - code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})", - Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})", - Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"} - - def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: - ret = super().render_kernel(function_name, kernel, bufs, uops, prefix) - msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params; - short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);', - 'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;', - 'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);', - 'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {', - 'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};', - 'HAP_power_set((void*)handle, (void*)&req);'] - msrc += ['if ((sc>>24) != 2) return 0;'] - msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)] - msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)] - msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)] - msrc += ["unsigned long long start = HAP_perf_get_time_us();"] - msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"] - msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"] - msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)] - msrc += ["return 0; }"] - return ret + '\n' + '\n'.join(msrc) - class NVRenderer(CUDARenderer): device = "NV" class HIPRenderer(AMDRenderer): device = "HIP" class QCOMRenderer(OpenCLRenderer): device = "QCOM" diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index b56bdb3660..7cac17c6c1 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -1,14 +1,46 @@ from __future__ import annotations -from typing import Tuple, Any +from typing import Tuple, Any, List import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys assert sys.platform != 'win32' from tinygrad.device import BufferSpec, Compiled, Allocator +from tinygrad.dtype import dtypes, DType, PtrDType +from tinygrad.ops import Ops, UOp from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv from tinygrad.runtime.ops_clang import ClangCompiler -from tinygrad.renderer.cstyle import DSPRenderer +from tinygrad.renderer.cstyle import ClangRenderer from tinygrad.runtime.autogen import libc, qcom_dsp if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import +class DSPRenderer(ClangRenderer): + device = "DSP" + supports_float4 = False + buffer_suffix = " restrict __attribute__((align_value(128)))" + kernel_prefix = "__attribute__((noinline)) " + type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" } + code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})", + Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})", + Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"} + + def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: + ret = super().render_kernel(function_name, kernel, bufs, uops, prefix) + msrc = ['''struct dcvs_v2_req { int type; int _pad; _Bool dcvs_enable; char dcvs_option; _Bool set_latency; int latency; _Bool set_dcvs_params; + short _pad2; char target_corner; char min_corner; char max_corner; int _pad3[3]; };''', 'int HAP_power_set(void*, void*);', + 'typedef union { struct { void *pv; unsigned int len; } buf; struct { int fd; unsigned int offset; } dma; } remote_arg;', + 'void* HAP_mmap(void *addr, int len, int prot, int flags, int fd, long offset);', 'int HAP_munmap(void *addr, int len);', + 'unsigned long long HAP_perf_get_time_us(void);', 'int entry(unsigned long long handle, unsigned int sc, remote_arg* pra) {', + 'struct dcvs_v2_req req = {.type=7, .dcvs_enable=0, .set_latency=1, .latency=100, .set_dcvs_params=1, .target_corner = 6 /* TURBO */};', + 'HAP_power_set((void*)handle, (void*)&req);'] + msrc += ['if ((sc>>24) != 2) return 0;'] + msrc += [f'int sz_or_val_{i} = ((int*)pra[0].buf.pv)[{i}];' for i,b in enumerate(bufs)] + msrc += [f'int off{i} = ((int*)pra[1].buf.pv)[{i}];' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)] + msrc += [f'void *buf_{i} = HAP_mmap(0,sz_or_val_{i},3,0,pra[{i+3}].dma.fd,0)+off{i};' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)] + msrc += ["unsigned long long start = HAP_perf_get_time_us();"] + msrc += [f"{function_name}({', '.join([(f'buf_{i}' if isinstance(b[1][0], PtrDType) else f'sz_or_val_{i}') for i,b in enumerate(bufs)])});"] + msrc += ["*(unsigned long long *)(pra[2].buf.pv) = HAP_perf_get_time_us() - start;"] + msrc += [f'HAP_munmap(buf_{i}, sz_or_val_{i});' for i,b in enumerate(bufs) if isinstance(b[1][0], PtrDType)] + msrc += ["return 0; }"] + return ret + '\n' + '\n'.join(msrc) + def rpc_sc(method=0, ins=0, outs=0, fds=0): return (method << 24) | (ins << 16) | (outs << 8) | fds def rpc_prep_args(ins=None, outs=None, in_fds=None): ins, outs, in_fds = ins or list(), outs or list(), in_fds or list()