* cuda fp8

* tensor core

* tc test

* clean

* clean pm
b1tg committed 2025-10-22 03:05:25 +08:00 (committed by GitHub)
parent 587ccc0e5c, commit 60d7e232f2
7 changed files with 43 additions and 15 deletions

View File

@@ -204,7 +204,7 @@ jobs:
           DEBUG=2 EMULATE=CUDA FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
           DEBUG=2 EMULATE=CUDA ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
           DEBUG=2 EMULATE=CUDA_SM75 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
-          DEBUG=2 EMULATE=CUDA ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
+          DEBUG=2 EMULATE=CUDA_SM89 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
     - name: Test emulated INTEL OpenCL tensor cores
       run: DEBUG=2 EMULATE=INTEL FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py
     - name: Test emulated AMX tensor cores

View File

@@ -5,8 +5,10 @@ from tinygrad.dtype import _to_np_dtype
 from tinygrad.codegen.opt import OptOps
 from tinygrad.engine.realize import lower_schedule
-dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
+dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
+            dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
+acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
+             dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
 if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
 if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32
@@ -14,8 +16,10 @@ N = getenv("N", 4096)
 M = getenv("M", N)
 K = getenv("K", N)
 CNT = getenv("CNT", 10)
-ATOL = getenv("ATOL", 1e-4)
-RTOL = getenv("RTOL", 3e-2)
+atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
+ATOL, RTOL = getenv("ATOL", atol), getenv("RTOL", rtol)
 INT_LOW = getenv("INT_LOW", 0)
 INT_HIGH = getenv("INT_HIGH", 10)
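Editor's note: the looser fp8 defaults follow directly from the formats' precision. e4m3 keeps 3 mantissa bits and e5m2 only 2, so one ulp near 1.0 is already 0.125 and 0.25 respectively. A quick sanity check of that arithmetic (standalone Python, not part of the diff):

# machine epsilon per format, eps = 2**-mantissa_bits
for name, mbits in {"half": 10, "bfloat16": 7, "fp8e4m3": 3, "fp8e5m2": 2}.items():
  print(f"{name}: eps = {2.0**-mbits}")
# fp8e4m3 eps = 0.125, fp8e5m2 eps = 0.25 -> the rtol defaults of 1e-1 and 5e-1
# are on the order of a single input-rounding step in these formats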

View File

@@ -75,7 +75,10 @@ def universal_test_unary(a, dtype, op):
   out: Tensor = op[0](ta)
   tensor_value = out.numpy()
   numpy_value = op[1](ta.numpy())
-  if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
+  if dtype in dtypes.fp8s:
+    # CUDA casts f32 inf to fp8 MAX; AMD casts it to nan (E4M3) / inf (E5M2)
+    if math.isinf(numpy_value): return
+    numpy_value = truncate[dtype](numpy_value)
   if dtype in dtypes.floats:
     atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
                    dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
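Editor's note: the E4M3 ("fn") encoding reserves no infinity and tops out at 448, so there is no single correct fp8 image of an f32 inf; as the new comment says, CUDA saturates to the fp8 MAX while AMD yields nan (E4M3) or inf (E5M2), hence the test skips inf results. A minimal sketch of the CUDA-style saturating behavior (the maxima are the standard fp8 constants, not taken from this diff):

import math
E4M3_MAX, E5M2_MAX = 448.0, 57344.0  # largest finite e4m3 / e5m2 values

def cuda_style_e4m3_cast(x: float) -> float:
  # saturate inf to MAX, as the comment describes CUDA doing
  if math.isinf(x): return math.copysign(E4M3_MAX, x)
  return x  # rounding onto the e4m3 grid omitted for brevity

print(cuda_style_e4m3_cast(math.inf))  # 448.0, where AMD would give nan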

View File

@@ -80,6 +80,10 @@ cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4)
   swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('u1', 'r3'), ('l0', 'l1', 'u0', 'r0')),
            (('r1', 'r2', 'u0', 'l0', 'l1'), ('r0', 'r3'), ('l2', 'l3', 'l4', 'u1'))))
   for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), (dtypes.half,dtypes.half)]]
+cuda_81632_f8 = [TensorCore(dims=(8,16,32), threads=32, elements_per_thread=(16,8,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+  swizzle=((('r2', 'r3', 'l2', 'l3', 'l4'), ('u1', 'r4'), ('l0', 'l1', 'u0', 'r0', 'r1')),
+           (('r2', 'r3', 'u0', 'l0', 'l1'), ('r1', 'r4'), ('l2', 'l3', 'l4', 'u1', 'r0'))))
+  for di,do in [(dtypes.fp8e4m3,dtypes.float),(dtypes.fp8e5m2,dtypes.float)]]
 cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
   swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('r0', 'u1'), ('l0', 'l1', 'u0')),
            (('r1', 'r2', 'u0', 'l0', 'l1'), ('u1', 'r0'), ('l2', 'l3', 'l4'))))
@@ -87,9 +91,10 @@ cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,
 cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
   swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
            (('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
-cuda_sm75: list[TensorCore] = cuda_8168_f16
 cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
 if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
+cuda_sm75: list[TensorCore] = cuda_8168_f16
+cuda_sm89: list[TensorCore] = cuda_sm80 + cuda_81632_f8
 # ***** AMD *****
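Editor's note: the new elements_per_thread=(16,8,4) is fixed by the fragment sizes. With dims=(N,M,K)=(8,16,32) spread over a 32-thread warp (assuming the same (N,M,K) convention as the neighboring entries), A carries M*K=512 values, B carries N*K=256, and the accumulator M*N=128, i.e. 16, 8, and 4 per thread. A standalone check across the CUDA entries:

# per-thread fragment sizes implied by dims=(N,M,K) and a 32-thread warp
for dims, ept in [((8,16,16),(8,4,4)), ((8,16,32),(16,8,4)), ((8,16,8),(4,2,4))]:
  n, m, k = dims
  assert (m*k//32, n*k//32, m*n//32) == ept, dims
print("elements_per_thread consistent with dims for all CUDA tensor cores")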

View File

@@ -331,7 +331,9 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
     if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
     if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not getenv("CPU_LVP")
     return device in {"AMD", "PYTHON", "NULL"}
-  if dtype in dtypes.fp8s: return device in {"PYTHON", "NULL"}
+  if dtype in dtypes.fp8s:
+    if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
+    return device in {"PYTHON", "NULL"}
   if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
                                           dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported

View File

@@ -37,7 +37,7 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
   (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
   # consts are rendered to larger type and casted
-  (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
+  (UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
   (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
   (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
   # default const render
@@ -345,7 +345,8 @@ class CUDARenderer(CStyleLanguage):
   shared_max = 49152
   def __init__(self, arch:str):
-    self.tensor_cores, self.arch = tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else [], arch
+    self.arch = arch
+    self.tensor_cores = tc.cuda_sm89 if int(arch[3:]) >= 89 else tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else []
   def __reduce__(self): return self.__class__, (self.arch,)
   # language options
@@ -364,8 +365,14 @@ class CUDARenderer(CStyleLanguage):
     Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
     Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
     Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
-  type_map = {dtypes.bfloat16: "nv_bfloat16"}
+  type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"}
+  extra_matcher = PatternMatcher([
+    (UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),
+    (UPat(GroupOp.ALU, dtype=dtypes.fp8s, name="x"),
+      lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(x.dtype)),
+    (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.fp8s), UPat.var("y", dtype=dtypes.fp8s))),
+      lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
+  ]) + extra_pm
   def render_vector_prefix(self, dt:DType) -> str:
     vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
     elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
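Editor's note: sm_89 exposes fp8 only as a matrix-multiply input; there is no general fp8 ALU, so the matcher above computes every fp8 ALU op in f32 and casts back, and routes e4m3<->e5m2 casts through f32 as well. A runnable miniature of what that round-trip means, assuming the standard e4m3 "fn" encoding (illustration only, not code from the diff):

import math

def decode_e4m3(bits: int) -> float:
  # e4m3 "fn": bias 7, no inf; only exp=0b1111 with mant=0b111 is nan
  s, e, m = bits >> 7, (bits >> 3) & 0xf, bits & 0x7
  if e == 0xf and m == 0x7: return math.nan
  mag = (m / 8) * 2.0**-6 if e == 0 else (1 + m / 8) * 2.0**(e - 7)
  return -mag if s else mag

E4M3 = sorted({v for b in range(256) if not math.isnan(v := decode_e4m3(b))})

def fp8_add(a: float, b: float) -> float:
  # the promotion rule in miniature: do the math in f32, snap back to fp8
  return min(E4M3, key=lambda v: abs(v - (a + b)))

print(fp8_add(0.1, 0.2))  # 0.3125, the nearest e4m3 value to 0.3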
@@ -376,11 +383,12 @@ class CUDARenderer(CStyleLanguage):
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
     used_dtypes = uops_to_dtypes(uops)
+    if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
     if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
-    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
-    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
+               or (dt.count in (8,16) and dt.scalar() in dtypes.fp8s)]
+    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2" }
     dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
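Editor's note: the new map entries are the dtype suffixes for the WMMA wrappers the renderer emits; for the fp8 tensor cores the underlying sm_89 instruction shape is m16n8k32 with f32 accumulation. A sketch of how the pieces combine into the PTX mnemonic (the wrapper plumbing is an assumption; only the suffix strings come from this hunk):

dt_in, dt_out = "e4m3", "f32"  # from dt_map_in / dt_map_out above
N, M, K = 8, 16, 32            # dims of cuda_81632_f8
print(f"mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_out}.{dt_in}.{dt_in}.{dt_out}")
# -> mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32, the sm_89 fp8 MMA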

View File

@@ -177,6 +177,11 @@ class PythonProgram:
           def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
           ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
+        elif dims == (8,16,32):
+          def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
+          def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
+          ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
         elif dims == (8,16,8) and dtype_in == dtypes.half:
           def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
           def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
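Editor's note: the a_elem/b_elem lambdas encode how the emulator strides the m16n8k32 fragments across registers and lanes; per thread A holds 16 values and B holds 8, matching the wmma_helper(32, 32, 16, 8, 4, ...) call and elements_per_thread in tc.py. One standalone sanity check: with goff=0, a_elem must map the 16x32 A tile one-to-one onto 16 registers x 32 lanes:

# check the (8,16,32) a_elem layout is a bijection from (row, k) to (reg, lane)
seen = set()
for row in range(16):      # M
  for k in range(32):      # K
    reg = k%4 + (row//8)*4 + (k//16)*8
    lane = (k//4)%4 + (row%8)*4  # the goff + ... lane index, with goff = 0
    seen.add((reg, lane))
assert len(seen) == 16*32  # every A element owns a distinct register slot
print("a_elem covers 16 regs x 32 lanes exactly once")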
@@ -220,6 +225,7 @@ class PythonRenderer(Renderer):
       case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4
       case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
       case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
+      case "CUDA_SM89": self.device, self.tensor_cores = "CUDA", tc.cuda_sm89
       case "INTEL": self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
       case "AMX": self.device, self.tensor_cores = "CPU", tc.amx
       case "": pass