From 654a8b9ef7711317d5ebc6671235899654296dbc Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sun, 9 Jun 2024 11:33:03 +0300 Subject: [PATCH] retire hsa (#4885) * retire hsa * EMULATE_AMD --- .github/workflows/test.yml | 12 ++++++------ README.md | 3 ++- docs/env_vars.md | 3 ++- .../driver/hsa.py => extra/backends/hsa_driver.py | 0 .../graph/hsa.py => extra/backends/hsa_graph.py | 0 {tinygrad/runtime => extra/backends}/ops_hsa.py | 0 extra/gemm/hip_matmul.py | 10 +++++----- test/helpers.py | 2 +- test/test_dtype_alu.py | 2 +- test/test_jit.py | 2 +- tinygrad/device.py | 2 +- tinygrad/engine/jit.py | 2 +- tinygrad/engine/search.py | 2 +- tinygrad/renderer/cstyle.py | 5 ++--- tinygrad/runtime/ops_amd.py | 14 +++++++++++--- tinygrad/runtime/ops_python.py | 6 +++--- 16 files changed, 37 insertions(+), 28 deletions(-) rename tinygrad/runtime/driver/hsa.py => extra/backends/hsa_driver.py (100%) rename tinygrad/runtime/graph/hsa.py => extra/backends/hsa_graph.py (100%) rename {tinygrad/runtime => extra/backends}/ops_hsa.py (100%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ec277c7dd4..4161d9dddf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,18 +41,18 @@ jobs: IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d - name: Test emulated METAL tensor cores run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm - - name: Test emulated HSA tensor cores + - name: Test emulated AMD tensor cores run: | - PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py - PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py - PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py - PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py + PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py - name: Test emulated CUDA tensor cores run: DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm - name: Full test tensor cores run: | PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores - PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores + PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores - name: Test dtype with Python emulator run: DEBUG=1 PYTHONPATH=. 
PYTHON=1 python3 test/test_dtype.py diff --git a/README.md b/README.md index 999c717b4e..1f7f8a4335 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,8 @@ tinygrad already supports numerous accelerators, including: - [x] [LLVM](tinygrad/runtime/ops_llvm.py) - [x] [METAL](tinygrad/runtime/ops_metal.py) - [x] [CUDA](tinygrad/runtime/ops_cuda.py) -- [x] [HSA](tinygrad/runtime/ops_hsa.py) +- [x] [AMD](tinygrad/runtime/ops_amd.py) +- [x] [NV](tinygrad/runtime/ops_nv.py) And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops. diff --git a/docs/env_vars.md b/docs/env_vars.md index 798db564e7..db5f295032 100644 --- a/docs/env_vars.md +++ b/docs/env_vars.md @@ -33,7 +33,8 @@ Variable | Possible Value(s) | Description DEBUG | [1-6] | enable debugging output, with 4 you get operations, timings, speed, generated code and more GPU | [1] | enable the GPU backend CUDA | [1] | enable CUDA backend -HSA | [1] | enable HSA backend +AMD | [1] | enable AMD backend +NV | [1] | enable NV backend METAL | [1] | enable Metal backend (for Mac M1 and after) METAL_XCODE | [1] | enable Metal using macOS Xcode SDK CLANG | [1] | enable Clang backend diff --git a/tinygrad/runtime/driver/hsa.py b/extra/backends/hsa_driver.py similarity index 100% rename from tinygrad/runtime/driver/hsa.py rename to extra/backends/hsa_driver.py diff --git a/tinygrad/runtime/graph/hsa.py b/extra/backends/hsa_graph.py similarity index 100% rename from tinygrad/runtime/graph/hsa.py rename to extra/backends/hsa_graph.py diff --git a/tinygrad/runtime/ops_hsa.py b/extra/backends/ops_hsa.py similarity index 100% rename from tinygrad/runtime/ops_hsa.py rename to extra/backends/ops_hsa.py diff --git a/extra/gemm/hip_matmul.py b/extra/gemm/hip_matmul.py index c9b881c01c..7fc84809b9 100644 --- a/extra/gemm/hip_matmul.py +++ b/extra/gemm/hip_matmul.py @@ -1,7 +1,7 @@ import time import numpy as np from tinygrad.helpers import getenv, prod, flat_mv -from tinygrad.runtime.ops_hsa import HSAAllocator, HSADevice, HSAProgram +from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram # AMD_LOG_LEVEL=3 ./MIOpenDriver gemm --iter 1000 --time 1 --a_w 2048 --a_h 2048 --b_w 2048 # 5.5: Cijk_Ailk_Bljk_HHS_BH_MT128x128x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_DTL0_DTVA0_DVO0_ETSP_EPS1_FL0_GRVW8_GSU1_GSUASB_GLS0_ISA1100_IU1_K1_KLA_LBSPP128_LPA0_LPB8_LDL1_LRVW16_LWPMn1_LDW0_FMA_MIAV1_MDA2_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA1_NLCB1_ONLL1_OPLV0_PK0_PAP0_PGR1_PLR1_RK0_SIA1_SS1_SU32_SUM0_SUS128_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TT4_64_TLDS1_USFGROn1_VAW2_VSn1_VW4_WSGRA1_WSGRB1_WS32_WG32_4_1_WGM4 @@ -31,9 +31,9 @@ local_size = [32, 1, 1] global_size = [N//(KX*16), N//(KY*16), 1] num_threads = prod(local_size) -# Can HSAAllocator initialized as device=0 by default? -device = HSADevice() -hipallocator = HSAAllocator(device) +# Can AMDAllocator initialized as device=0 by default? 
+device = AMDDevice() +hipallocator = AMDAllocator(device) a = hipallocator.alloc(N*N*4) b = hipallocator.alloc(N*N*2) c = hipallocator.alloc(N*N*2) @@ -115,7 +115,7 @@ extern "C" __attribute__((global))void __attribute__((amdgpu_flat_work_group_siz if DEBUG > 1: print(prog_str) lib = device.compiler.compile(prog_str) -prog = HSAProgram(device, "test", lib) +prog = AMDProgram(device, "test", lib) def timeit(fxn): st = time.perf_counter() diff --git a/test/helpers.py b/test/helpers.py index ca7727c9a1..e2e395b1bc 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -26,7 +26,7 @@ def assert_jit_cache_len(fxn, expected_len): def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if dtype == dtypes.bfloat16: # NOTE: this requires bf16 buffer support - return device in {"HSA", "AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) + return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] # for CI GPU and OSX, cl_khr_fp16 isn't supported # for CI LLVM, it segfaults because it can't link to the casting function diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index b00e659005..802f40281a 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -145,7 +145,7 @@ class TestDTypeALU(unittest.TestCase): def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32) # Metal and CUDACPU and HIP behave differently than numpy in CI for overflows - skip_overflow = CI and (Device.DEFAULT in {"HSA", "AMD", "NV"} or getenv("CUDACPU")) + skip_overflow = CI and (Device.DEFAULT in {"AMD", "NV"} or getenv("CUDACPU")) @given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32, strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32, ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations)) diff --git a/test/test_jit.py b/test/test_jit.py index 359019212b..f23c3d9a80 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -288,7 +288,7 @@ class TestJit(unittest.TestCase): for i in range(5): np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3)) - @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "HSA", "NV", "AMD"}, "no GPU CI") + @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI") def test_jitted_transfers(self): d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1" diff --git a/tinygrad/device.py b/tinygrad/device.py index 1c7e15358f..f28a324abb 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -29,7 +29,7 @@ class _Device: def DEFAULT(self) -> str: device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore if device_from_env: return device_from_env - for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]: + for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]: try: if self[device]: os.environ[device] = "1" # we set this in environment for spawned children diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index eb4f030e1f..58f95a5d34 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -41,7 +41,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer] if ji.prg.__class__ in {EmptyOp, ViewOp}: continue ji_graph_dev: 
Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None. if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device - elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}: + elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}: ji_graph_dev = Device[ji.bufs[0].device] graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index b18cd6b61c..2f9641a5c9 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -126,7 +126,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))] seen_libs = set() - default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0 + default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0 if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)): beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16)) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index bc534d625a..396fe62361 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -312,8 +312,8 @@ def _make_hip_dtype(base_type, name, cnt): return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \ f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}" -class HIPRenderer(CStyleLanguage): - device = "HSA" +class AMDRenderer(CStyleLanguage): + device = "AMD" shared_max = 65536 tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501 @@ -381,4 +381,3 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))" class NVRenderer(CUDARenderer): device = "NV" -class AMDRenderer(HIPRenderer): device = "AMD" diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 1edf4f38df..c48244351f 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -1,10 +1,10 @@ from __future__ import annotations from typing import Tuple, List, Any, cast import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time, array -from tinygrad.device import Compiled, BufferOptions, LRUAllocator +from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG from tinygrad.renderer.cstyle import AMDRenderer -from tinygrad.runtime.ops_hsa import HSACompiler +from tinygrad.runtime.driver.hip_comgr import compile_hip import tinygrad.runtime.autogen.kfd as kfd import tinygrad.runtime.autogen.hsa as hsa import tinygrad.runtime.autogen.amd_gpu as amd_gpu @@ -66,6 +66,14 @@ COMPUTE_SHADER_EN, FORCE_START_AT_000, CS_W32_EN = (1 << 0), (1 << 2), (1 << 15) def gfxreg(reg): 
return reg + 0x00001260 - amd_gpu.PACKET3_SET_SH_REG_START def data64_le(data): return (data & 0xFFFFFFFF, data >> 32) +class AMDCompiler(Compiler): + def __init__(self, arch:str): + self.arch = arch + super().__init__(f"compile_hip_{self.arch}") + def compile(self, src:str) -> bytes: + try: return compile_hip(src, self.arch) + except RuntimeError as e: raise CompileError(e) + class HWPM4Queue: def __init__(self): self.q, self.binded_device, self.ptr_to_dispatch_packet = [], None, {} def __del__(self): @@ -548,7 +556,7 @@ class AMDDevice(Compiled): self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q") from tinygrad.runtime.graph.hcq import HCQGraph - super().__init__(device, AMDAllocator(self), AMDRenderer(), HSACompiler(self.arch), + super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self), functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue)) def synchronize(self): diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 05ada20bce..c118348f1e 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -9,7 +9,7 @@ from tinygrad.device import Compiled, Compiler, Allocator from tinygrad.codegen.uops import UOpGraph, UOps from tinygrad.ops import BinaryOps, TernaryOps, exec_alu from tinygrad.renderer import Renderer -from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer +from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer def _load(m, i): if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}") @@ -154,7 +154,7 @@ class PythonProgram: # (i, j), C, D (2 elements on 32 threads): row major same as A/B def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4) ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map) - elif arg[5] == "HSA": + elif arg[5] == "AMD": # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15 def a_elem(x, i, j, goff): assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes" @@ -184,7 +184,7 @@ class PythonRenderer(Renderer): device = "PYTHON" def __init__(self): if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores - if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores + if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores def render(self, name:str, uops:UOpGraph) -> str:
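
---

Not part of the patch itself — a minimal sketch of the emulated-AMD path that the CI changes above exercise, assuming a tinygrad checkout at this commit. It mirrors the `EMULATE_AMD=1 PYTHON=1` `simple_matmul.py` jobs: the AMD tensor-core layout from `AMDRenderer` runs on the pure-Python backend (`ops_python.py`), so no AMD hardware or ROCm install is needed. The matrix size and tolerance below are illustrative, not taken from the diff.

```python
# Sketch: small half-precision GEMM through the Python emulator with the
# AMD tensor-core layout (the path the EMULATE_AMD CI jobs above hit).
import os
os.environ["PYTHON"] = "1"        # select the pure-Python backend (ops_python.py)
os.environ["EMULATE_AMD"] = "1"   # was EMULATE_HSA before this patch

import numpy as np
from tinygrad import Tensor, dtypes

a = Tensor.rand(16, 16, dtype=dtypes.half)
b = Tensor.rand(16, 16, dtype=dtypes.half)
c = (a @ b).numpy()               # emulated AMD WMMA lane mapping, if the optimizer picks it
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
np.testing.assert_allclose(c, ref, atol=1e-1)   # loose tolerance: half accumulation
print(c.shape, c.dtype)
```

The same pattern applies to the real backend: `AMD=1` selects `ops_amd.py`, whose new `AMDCompiler` wraps `compile_hip` and surfaces compiler failures as `CompileError` rather than a raw `RuntimeError`.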