mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
.github/workflows/test.yml
@@ -41,18 +41,18 @@ jobs:
         IMAGE=2 PYTHON=1 python3 test/test_ops.py TestOps.test_simple_conv2d
     - name: Test emulated METAL tensor cores
       run: DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_big_gemm
-    - name: Test emulated HSA tensor cores
+    - name: Test emulated AMD tensor cores
       run: |
-        PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
-        PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
-        PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
-        PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
     - name: Test emulated CUDA tensor cores
       run: DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
     - name: Full test tensor cores
       run: |
         PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
-        PYTHONPATH=. DEBUG=2 EMULATE_HSA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
         PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
     - name: Test dtype with Python emulator
       run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 test/test_dtype.py

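The renamed steps still run on the pure-Python emulator backend; only the EMULATE_* flag changes. As a rough local equivalent of one of those CI steps (a sketch, assuming a tinygrad checkout and no AMD hardware, since the AMD behavior is emulated):

import os
# these flags are read when tinygrad picks a device, so set them before the import
os.environ["PYTHON"] = "1"       # use the pure-Python emulator backend
os.environ["EMULATE_AMD"] = "1"  # emulate AMD tensor-core semantics on it

from tinygrad import Tensor

a, b = Tensor.rand(16, 16).half(), Tensor.rand(16, 16).half()
print((a @ b).numpy())           # small forward-only half matmul, like simple_matmul.py in CI
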
@@ -84,7 +84,8 @@ tinygrad already supports numerous accelerators, including:
 - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
 - [x] [METAL](tinygrad/runtime/ops_metal.py)
 - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
-- [x] [HSA](tinygrad/runtime/ops_hsa.py)
+- [x] [AMD](tinygrad/runtime/ops_amd.py)
+- [x] [NV](tinygrad/runtime/ops_nv.py)
 
 And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.

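Any backend in this list can also be addressed per tensor by name. A minimal sketch (CLANG is used only because it runs on most machines with a C compiler; "AMD" and "NV" work the same way but need the corresponding GPU and driver):

from tinygrad import Tensor

t = Tensor.ones(3, device="CLANG")   # place a tensor on an explicitly named backend
print(t.device, t.numpy())
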
@@ -33,7 +33,8 @@ Variable | Possible Value(s) | Description
 DEBUG | [1-6] | enable debugging output, with 4 you get operations, timings, speed, generated code and more
 GPU | [1] | enable the GPU backend
 CUDA | [1] | enable CUDA backend
-HSA | [1] | enable HSA backend
+AMD | [1] | enable AMD backend
+NV | [1] | enable NV backend
 METAL | [1] | enable Metal backend (for Mac M1 and after)
 METAL_XCODE | [1] | enable Metal using macOS Xcode SDK
 CLANG | [1] | enable Clang backend

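All of these variables are read through tinygrad.helpers.getenv, so they can be exported in the shell (e.g. AMD=1 DEBUG=2 python3 script.py) or set from Python before tinygrad is imported. A small sketch of that mechanism:

import os
os.environ["DEBUG"] = "2"            # must happen before tinygrad is imported

from tinygrad.helpers import getenv, DEBUG
print(getenv("DEBUG"))               # 2 -- getenv returns an int for numeric defaults
print(DEBUG.value)                   # the ContextVar tinygrad checks internally
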
@@ -1,7 +1,7 @@
 import time
 import numpy as np
 from tinygrad.helpers import getenv, prod, flat_mv
-from tinygrad.runtime.ops_hsa import HSAAllocator, HSADevice, HSAProgram
+from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram
 
 # AMD_LOG_LEVEL=3 ./MIOpenDriver gemm --iter 1000 --time 1 --a_w 2048 --a_h 2048 --b_w 2048
 # 5.5: Cijk_Ailk_Bljk_HHS_BH_MT128x128x16_MI16x16x16x1_SN_1LDSB0_APM1_ABV0_ACED0_AF0EM1_AF1EM1_AMAS3_ASE_ASGT_ASAE01_ASCE01_ASEM1_AAC0_BL1_BS1_DTL0_DTVA0_DVO0_ETSP_EPS1_FL0_GRVW8_GSU1_GSUASB_GLS0_ISA1100_IU1_K1_KLA_LBSPP128_LPA0_LPB8_LDL1_LRVW16_LWPMn1_LDW0_FMA_MIAV1_MDA2_NTA0_NTB0_NTC0_NTD0_NEPBS0_NLCA1_NLCB1_ONLL1_OPLV0_PK0_PAP0_PGR1_PLR1_RK0_SIA1_SS1_SU32_SUM0_SUS128_SCIUI1_SPO0_SRVW0_SSO0_SVW4_SNLL0_TT4_64_TLDS1_USFGROn1_VAW2_VSn1_VW4_WSGRA1_WSGRB1_WS32_WG32_4_1_WGM4

@@ -31,9 +31,9 @@ local_size = [32, 1, 1]
 global_size = [N//(KX*16), N//(KY*16), 1]
 num_threads = prod(local_size)
 
-# Can HSAAllocator initialized as device=0 by default?
-device = HSADevice()
-hipallocator = HSAAllocator(device)
+# Can AMDAllocator initialized as device=0 by default?
+device = AMDDevice()
+hipallocator = AMDAllocator(device)
 a = hipallocator.alloc(N*N*4)
 b = hipallocator.alloc(N*N*2)
 c = hipallocator.alloc(N*N*2)

@@ -115,7 +115,7 @@ extern "C" __attribute__((global))void __attribute__((amdgpu_flat_work_group_siz
 
 if DEBUG > 1: print(prog_str)
 lib = device.compiler.compile(prog_str)
-prog = HSAProgram(device, "test", lib)
+prog = AMDProgram(device, "test", lib)
 
 def timeit(fxn):
   st = time.perf_counter()

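Taken together, the three hunks above swap the gemm script from the HSA runtime objects to the AMD (KFD) ones while keeping the same structure. For orientation, a rough sketch of the raw buffer flow those objects provide, assuming a machine with a supported AMD GPU; the copyin/copyout calls follow the allocator interface this script relies on, and the explicit synchronize is just a conservative choice:

import numpy as np
from tinygrad.helpers import flat_mv
from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice

device = AMDDevice()                       # default device, as in the script
allocator = AMDAllocator(device)

na = np.random.rand(16, 16).astype(np.float32)
buf = allocator.alloc(na.nbytes)           # raw device buffer
allocator.copyin(buf, bytearray(na))       # host -> device
device.synchronize()                       # make sure the copy has landed
out = np.empty_like(na)
allocator.copyout(flat_mv(out.data), buf)  # device -> host
print(np.allclose(na, out))
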
@@ -26,7 +26,7 @@ def assert_jit_cache_len(fxn, expected_len):
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
-    return device in {"HSA", "AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
+    return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
   if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
   # for CI LLVM, it segfaults because it can't link to the casting function

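Downstream tests gate on this helper rather than hard-coding device names. A hypothetical example of how it is consumed (assumes the repo root is on PYTHONPATH, since the helper lives under test/):

import unittest
from tinygrad import Tensor, dtypes
from test.helpers import is_dtype_supported

class TestBF16(unittest.TestCase):          # hypothetical test, for illustration only
  @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "backend has no bf16 buffers")
  def test_bf16_cast(self):
    t = Tensor([1.0, 2.0]).cast(dtypes.bfloat16)
    self.assertEqual(t.dtype, dtypes.bfloat16)

if __name__ == "__main__":
  unittest.main()
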
@@ -145,7 +145,7 @@ class TestDTypeALU(unittest.TestCase):
   def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
 
   # Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
-  skip_overflow = CI and (Device.DEFAULT in {"HSA", "AMD", "NV"} or getenv("CUDACPU"))
+  skip_overflow = CI and (Device.DEFAULT in {"AMD", "NV"} or getenv("CUDACPU"))
   @given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))

@@ -288,7 +288,7 @@ class TestJit(unittest.TestCase):
     for i in range(5):
       np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
 
-  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "HSA", "NV", "AMD"}, "no GPU CI")
+  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI")
   def test_jitted_transfers(self):
     d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
 

@@ -29,7 +29,7 @@ class _Device:
   def DEFAULT(self) -> str:
     device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
     if device_from_env: return device_from_env
-    for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
+    for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
       try:
         if self[device]:
           os.environ[device] = "1" # we set this in environment for spawned children

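After this change the probe order is METAL, AMD, CUDA, GPU, CLANG, LLVM, and an explicit env var still short-circuits the probe entirely. A minimal check that runs anywhere:

import os
os.environ["CLANG"] = "1"     # an explicit backend flag wins over probing

from tinygrad import Device
print(Device.DEFAULT)         # -> "CLANG"
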
@@ -41,7 +41,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
     if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
-    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}:
+    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]
 
     graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore

@@ -126,7 +126,7 @@ def beam_search(lin:Linearizer, rawbufs:List[Buffer], amt:int, allow_test_size=T
   beam: List[Tuple[Linearizer, float]] = [(lin, float("inf"))]
   seen_libs = set()
 
-  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "HSA", "AMD", "NV"} else 0
+  default_parallel = multiprocessing.cpu_count() if lin.opts.device in {"CUDA", "AMD", "NV"} else 0
   if beam_pool is None and (workers := getenv("PARALLEL", default_parallel)):
     beam_pool = multiprocessing.get_context("spawn").Pool(workers, _init_worker, (), getenv("BEAM_MAX_TASKS_PER_CHILD", 16))
 

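BEAM search itself is unchanged; only the set of devices that get a parallel compile pool by default shrinks to CUDA/AMD/NV. A sketch of how this code path is reached (the values are arbitrary; PARALLEL=0 would disable the pool on any backend):

import os
os.environ["BEAM"] = "2"       # enable beam search for each compiled kernel
os.environ["PARALLEL"] = "4"   # explicit worker count, overriding the default above

from tinygrad import Tensor
(Tensor.rand(256, 256) @ Tensor.rand(256, 256)).realize()
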
@@ -312,8 +312,8 @@ def _make_hip_dtype(base_type, name, cnt):
   return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
     f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
 
-class HIPRenderer(CStyleLanguage):
-  device = "HSA"
+class AMDRenderer(CStyleLanguage):
+  device = "AMD"
   shared_max = 65536
   tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
 

@@ -381,4 +381,3 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
 
 class NVRenderer(CUDARenderer): device = "NV"
-class AMDRenderer(HIPRenderer): device = "AMD"

@@ -1,10 +1,10 @@
 from __future__ import annotations
 from typing import Tuple, List, Any, cast
 import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time, array
-from tinygrad.device import Compiled, BufferOptions, LRUAllocator
+from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
 from tinygrad.renderer.cstyle import AMDRenderer
-from tinygrad.runtime.ops_hsa import HSACompiler
+from tinygrad.runtime.driver.hip_comgr import compile_hip
 import tinygrad.runtime.autogen.kfd as kfd
 import tinygrad.runtime.autogen.hsa as hsa
 import tinygrad.runtime.autogen.amd_gpu as amd_gpu

@@ -66,6 +66,14 @@ COMPUTE_SHADER_EN, FORCE_START_AT_000, CS_W32_EN = (1 << 0), (1 << 2), (1 << 15)
 def gfxreg(reg): return reg + 0x00001260 - amd_gpu.PACKET3_SET_SH_REG_START
 def data64_le(data): return (data & 0xFFFFFFFF, data >> 32)
 
+class AMDCompiler(Compiler):
+  def __init__(self, arch:str):
+    self.arch = arch
+    super().__init__(f"compile_hip_{self.arch}")
+  def compile(self, src:str) -> bytes:
+    try: return compile_hip(src, self.arch)
+    except RuntimeError as e: raise CompileError(e)
+
 class HWPM4Queue:
   def __init__(self): self.q, self.binded_device, self.ptr_to_dispatch_packet = [], None, {}
   def __del__(self):

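The new AMDCompiler is a thin wrapper over compile_hip (AMD's comgr), so it can be exercised on its own. A sketch, assuming a ROCm/comgr install; the gfx1100 arch string is only an example value:

from tinygrad.runtime.ops_amd import AMDCompiler

src = 'extern "C" __attribute__((global)) void noop() {}'
lib = AMDCompiler("gfx1100").compile(src)   # raises CompileError on bad source
print(len(lib), "bytes of code object")
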
@@ -548,7 +556,7 @@ class AMDDevice(Compiled):
     self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
 
     from tinygrad.runtime.graph.hcq import HCQGraph
-    super().__init__(device, AMDAllocator(self), AMDRenderer(), HSACompiler(self.arch),
+    super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch),
                      functools.partial(AMDProgram, self), functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))
 
   def synchronize(self):

@@ -9,7 +9,7 @@ from tinygrad.device import Compiled, Compiler, Allocator
 from tinygrad.codegen.uops import UOpGraph, UOps
 from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
 from tinygrad.renderer import Renderer
-from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
+from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer
 
 def _load(m, i):
   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")

@@ -154,7 +154,7 @@ class PythonProgram:
           # (i, j), C, D (2 elements on 32 threads): row major same as A/B
           def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
           ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-        elif arg[5] == "HSA":
+        elif arg[5] == "AMD":
          # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
          def a_elem(x, i, j, goff):
            assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"

@@ -184,7 +184,7 @@ class PythonRenderer(Renderer):
   device = "PYTHON"
   def __init__(self):
     if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
-    if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
+    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
     if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
 
   def render(self, name:str, uops:UOpGraph) -> str:

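With the rename, EMULATE_AMD makes the Python emulator borrow AMDRenderer's tensor-core descriptors instead of the old HIPRenderer ones. A quick check (a sketch; needs only an importable tinygrad, no GPU):

import os
os.environ["EMULATE_AMD"] = "1"   # read when the renderer is constructed

from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.runtime.ops_python import PythonRenderer

r = PythonRenderer()
print(r.device)                                    # "AMD"
print(r.tensor_cores is AMDRenderer.tensor_cores)  # True
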