retire hsa (#4885)

* retire hsa

* EMULATE_AMD
nimlgen authored 2024-06-09 11:33:03 +03:00, committed by GitHub
parent e33efd6a3d
commit 654a8b9ef7
16 changed files with 37 additions and 28 deletions

View File

@@ -41,18 +41,18 @@ jobs:
IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d
- name: Test emulated METAL tensor cores
run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm
- name: Test emulated HSA tensor cores
- name: Test emulated AMD tensor cores
run: |
PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
- name: Test emulated CUDA tensor cores
run: DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
- name: Full test tensor cores
run: |
PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
- name: Test dtype with Python emulator
run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 test/test_dtype.py
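
For reference, the same emulated-AMD path can be exercised outside CI with a few lines of Python. This is a minimal sketch assuming a local tinygrad checkout at this commit; the environment variables (EMULATE_AMD, PYTHON, as used in the steps above) must be set before tinygrad is imported.

import os
os.environ["PYTHON"] = "1"       # run kernels on the pure-Python emulator
os.environ["EMULATE_AMD"] = "1"  # emulate AMD tensor cores, as in the CI step above

import numpy as np
from tinygrad import Tensor, dtypes

a = Tensor.rand(16, 16, dtype=dtypes.half)
b = Tensor.rand(16, 16, dtype=dtypes.half)
out = (a @ b).numpy()
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
np.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
print("emulated AMD matmul matches numpy")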

View File

@@ -84,7 +84,8 @@ tinygrad already supports numerous accelerators, including:
- [x] [LLVM](tinygrad/runtime/ops_llvm.py)
- [x] [METAL](tinygrad/runtime/ops_metal.py)
- [x] [CUDA](tinygrad/runtime/ops_cuda.py)
- [x] [HSA](tinygrad/runtime/ops_hsa.py)
- [x] [AMD](tinygrad/runtime/ops_amd.py)
- [x] [NV](tinygrad/runtime/ops_nv.py)
And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
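
A backend from this list can be selected by naming it in the device string; a quick hedged check of the renamed backend (assuming a machine where the AMD runtime initializes):

from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0], device="AMD")   # device string matches ops_amd.py above
print((t + 1).numpy())                      # -> [2. 3. 4.]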

View File

@@ -33,7 +33,8 @@ Variable | Possible Value(s) | Description
DEBUG | [1-6] | enable debugging output, with 4 you get operations, timings, speed, generated code and more
GPU | [1] | enable the GPU backend
CUDA | [1] | enable CUDA backend
HSA | [1] | enable HSA backend
AMD | [1] | enable AMD backend
NV | [1] | enable NV backend
METAL | [1] | enable Metal backend (for Mac M1 and after)
METAL_XCODE | [1] | enable Metal using macOS Xcode SDK
CLANG | [1] | enable Clang backend
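
These flags are read from the environment at startup; a small sketch of how they compose (the script itself is ordinary tinygrad code, the behavior comes entirely from the variables in the table above):

# save as debug_demo.py and run, for example:
#   DEBUG=4 AMD=1 python3 debug_demo.py
# DEBUG=4 prints timings and the generated kernel code; AMD=1 forces the AMD backend.
from tinygrad import Tensor

(Tensor.rand(64, 64) @ Tensor.rand(64, 64)).realize()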

View File

@@ -1,7 +1,7 @@
import time
import numpy as np
from tinygrad.helpers import getenv, prod, flat_mv
from tinygrad.runtime.ops_hsa import HSAAllocator, HSADevice, HSAProgram
from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram
# AMD_LOG_LEVEL=3 ./MIOpenDriver gemm --iter 1000 --time 1 --a_w 2048 --a_h 2048 --b_w 2048
# 5.5: Cijk_Ailk_Bljk_HHS_BH_MT128x128x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_DTL0_DTVA0_DVO0_ETSP_EPS1_FL0_GRVW8_GSU1_GSUASB_GLS0_ISA1100_IU1_K1_KLA_LBSPP128_LPA0_LPB8_LDL1_LRVW16_LWPMn1_LDW0_FMA_MIAV1_MDA2_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA1_NLCB1_ONLL1_OPLV0_PK0_PAP0_PGR1_PLR1_RK0_SIA1_SS1_SU32_SUM0_SUS128_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TT4_64_TLDS1_USFGROn1_VAW2_VSn1_VW4_WSGRA1_WSGRB1_WS32_WG32_4_1_WGM4
@@ -31,9 +31,9 @@ local_size = [32, 1, 1]
global_size = [N//(KX*16), N//(KY*16), 1]
num_threads = prod(local_size)
# Can HSAAllocator initialized as device=0 by default?
device = HSADevice()
hipallocator = HSAAllocator(device)
# Can AMDAllocator initialized as device=0 by default?
device = AMDDevice()
hipallocator = AMDAllocator(device)
a = hipallocator.alloc(N*N*4)
b = hipallocator.alloc(N*N*2)
c = hipallocator.alloc(N*N*2)
@@ -115,7 +115,7 @@ extern "C" __attribute__((global))void __attribute__((amdgpu_flat_work_group_siz
if DEBUG > 1: print(prog_str)
lib = device.compiler.compile(prog_str)
prog = HSAProgram(device, "test", lib)
prog = AMDProgram(device, "test", lib)
def timeit(fxn):
st = time.perf_counter()
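
The script above drives the renamed low-level objects directly. As an untested sketch of that flow with a trivial kernel: the kernel source, launch arguments, and gfx handling are illustrative, and the launch/copy call signatures are assumed to match this commit's AMDProgram and AMDAllocator as used in this file.

import numpy as np
from tinygrad.helpers import flat_mv
from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram

device = AMDDevice()
allocator = AMDAllocator(device)

src = """
extern "C" __attribute__((global)) void add_one(float* out, const float* inp) {
  int i = __builtin_amdgcn_workitem_id_x();
  out[i] = inp[i] + 1.0f;
}
"""
lib = device.compiler.compile(src)                 # comgr-based compile, as in the script above
prog = AMDProgram(device, "add_one", lib)

n = 64
inp = np.arange(n, dtype=np.float32)
a = allocator.alloc(n * 4)
b = allocator.alloc(n * 4)
allocator.copyin(a, bytearray(inp))                # mirrors the copyin/copyout usage in this file
prog(b, a, global_size=(1, 1, 1), local_size=(n, 1, 1), wait=True)  # assumed launch signature
out = np.empty(n, dtype=np.float32)
allocator.copyout(flat_mv(out.data), b)
print(out[:4])                                     # expected: [1. 2. 3. 4.]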

View File

@@ -26,7 +26,7 @@ def assert_jit_cache_len(fxn, expected_len):
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"HSA", "AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
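
A small usage sketch of the helper changed above, assuming it lives in test/helpers.py as the surrounding context suggests (run from the repo root with PYTHONPATH=.):

from tinygrad import dtypes
from test.helpers import is_dtype_supported

# bfloat16 buffers are now expected only on AMD (plus CUDA/NV outside CI, non-PTX)
print(is_dtype_supported(dtypes.bfloat16, "AMD"))   # -> True
print(is_dtype_supported(dtypes.bfloat16, "HSA"))   # -> False after this commit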

View File

@@ -145,7 +145,7 @@ class TestDTypeALU(unittest.TestCase):
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
# Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
skip_overflow = CI and (Device.DEFAULT in {"HSA", "AMD", "NV"} or getenv("CUDACPU"))
skip_overflow = CI and (Device.DEFAULT in {"AMD", "NV"} or getenv("CUDACPU"))
@given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))

View File

@@ -288,7 +288,7 @@ class TestJit(unittest.TestCase):
for i in range(5):
np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "HSA", "NV", "AMD"}, "no GPU CI")
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI")
def test_jitted_transfers(self):
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
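
The skipped test exercises JIT'd cross-device transfers; a hedged sketch of that pattern (assuming two instances of the default backend are available) looks roughly like:

from tinygrad import Tensor, Device, TinyJit

d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"

@TinyJit
def f(x: Tensor) -> Tensor:
    return (x + 1).to(d1).realize()   # compute on d0, then transfer the result to d1

for i in range(5):
    out = f(Tensor([float(i)] * 3, device=d0))
    print(out.device, out.numpy())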

View File

@@ -29,7 +29,7 @@ class _Device:
def DEFAULT(self) -> str:
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
if device_from_env: return device_from_env
for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
try:
if self[device]:
os.environ[device] = "1" # we set this in environment for spawned children
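
In practice the probe above means an explicit environment override wins, and otherwise the first backend in the list that initializes becomes the default; a quick check (output depends on the machine):

from tinygrad import Device

print(Device.DEFAULT)   # e.g. "AMD" where the AMD backend comes up, "METAL" on a Mac
# an explicit flag short-circuits the probe order METAL, AMD, CUDA, GPU, CLANG, LLVM:
#   CLANG=1 python3 -c "from tinygrad import Device; print(Device.DEFAULT)"   # -> CLANG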

View File

@@ -41,7 +41,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}:
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
ji_graph_dev = Device[ji.bufs[0].device]
graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore

View File

@@ -126,7 +126,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
seen_libs = set()
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
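
The default worker count above only matters when beam search runs; a minimal way to trigger it, assuming the usual BEAM flag (PARALLEL is read from the environment as shown above):

# run, for example, as:
#   BEAM=2 DEBUG=2 python3 beam_demo.py        # parallel compiles by default on CUDA/AMD/NV
#   BEAM=2 PARALLEL=0 python3 beam_demo.py     # force single-process search
from tinygrad import Tensor

(Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()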

View File

@@ -312,8 +312,8 @@ def _make_hip_dtype(base_type, name, cnt):
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
class HIPRenderer(CStyleLanguage):
device = "HSA"
class AMDRenderer(CStyleLanguage):
device = "AMD"
shared_max = 65536
tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
@@ -381,4 +381,3 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
class NVRenderer(CUDARenderer): device = "NV"
class AMDRenderer(HIPRenderer): device = "AMD"
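
Since HIPRenderer is now AMDRenderer with device "AMD", a quick inspection of the renamed class (attribute names taken from the definition above):

from tinygrad.renderer.cstyle import AMDRenderer

r = AMDRenderer()
print(r.device, r.shared_max)                 # -> AMD 65536
for tc in r.tensor_cores:                     # two wmma variants: half->float and half->half
    print(tc.dims, tc.dtype_in, tc.dtype_out)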

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
from typing import Tuple, List, Any, cast
import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time, array
from tinygrad.device import Compiled, BufferOptions, LRUAllocator
from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.runtime.ops_hsa import HSACompiler
from tinygrad.runtime.driver.hip_comgr import compile_hip
import tinygrad.runtime.autogen.kfd as kfd
import tinygrad.runtime.autogen.hsa as hsa
import tinygrad.runtime.autogen.amd_gpu as amd_gpu
@@ -66,6 +66,14 @@ COMPUTE_SHADER_EN, FORCE_START_AT_000, CS_W32_EN = (1 << 0), (1 << 2), (1 << 15)
def gfxreg(reg): return reg + 0x00001260 - amd_gpu.PACKET3_SET_SH_REG_START
def data64_le(data): return (data & 0xFFFFFFFF, data >> 32)
class AMDCompiler(Compiler):
def __init__(self, arch:str):
self.arch = arch
super().__init__(f"compile_hip_{self.arch}")
def compile(self, src:str) -> bytes:
try: return compile_hip(src, self.arch)
except RuntimeError as e: raise CompileError(e)
class HWPM4Queue:
def __init__(self): self.q, self.binded_device, self.ptr_to_dispatch_packet = [], None, {}
def __del__(self):
@@ -548,7 +556,7 @@ class AMDDevice(Compiled):
self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, AMDAllocator(self), AMDRenderer(), HSACompiler(self.arch),
super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch),
functools.partial(AMDProgram, self), functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))
def synchronize(self):
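
The new AMDCompiler can also be used on its own; an untested sketch (assumes a working ROCm/comgr install; the gfx arch string is illustrative, normally AMDDevice passes its own self.arch):

from tinygrad.runtime.ops_amd import AMDCompiler

compiler = AMDCompiler("gfx1100")
lib = compiler.compile('extern "C" __attribute__((global)) void noop() {}')
print(len(lib), "bytes of compiled GPU code object")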

View File

@@ -9,7 +9,7 @@ from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
def _load(m, i):
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
@@ -154,7 +154,7 @@ class PythonProgram:
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
elif arg[5] == "HSA":
elif arg[5] == "AMD":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, i, j, goff):
assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -184,7 +184,7 @@ class PythonRenderer(Renderer):
device = "PYTHON"
def __init__(self):
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
def render(self, name:str, uops:UOpGraph) -> str:
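
A quick way to confirm what the emulator flag selects (a sketch; set the variable before importing so getenv sees it):

import os
os.environ["EMULATE_AMD"] = "1"

from tinygrad.runtime.ops_python import PythonRenderer

r = PythonRenderer()
print(r.device)               # -> "AMD"
print(len(r.tensor_cores))    # -> 2, the AMD wmma variants taken from AMDRenderer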