cuda fp8 (#12782)
* cuda fp8
* tensor core
* tc test
* clean
* clean pm
.github/workflows/test.yml (vendored, 2 changed lines)
@@ -204,7 +204,7 @@ jobs:
         DEBUG=2 EMULATE=CUDA FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
         DEBUG=2 EMULATE=CUDA ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
         DEBUG=2 EMULATE=CUDA_SM75 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
-        DEBUG=2 EMULATE=CUDA ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
+        DEBUG=2 EMULATE=CUDA_SM89 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
       - name: Test emulated INTEL OpenCL tensor cores
         run: DEBUG=2 EMULATE=INTEL FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py
       - name: Test emulated AMX tensor cores
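The only functional change in the workflow is the EMULATE value: the emulated tensor-core job now impersonates an sm_89 GPU instead of the generic CUDA (sm_80) profile, since only the sm_89 list carries the new fp8 shapes (see the PythonRenderer hunk at the end of this diff). A minimal sketch of the intent, with the mapping taken from that hunk:

```python
# Which emulated profile exposes fp8 tensor cores (per the tc lists later in this diff).
has_fp8_tc = {"CUDA": False, "CUDA_SM75": False, "CUDA_SM89": True}
assert has_fp8_tc["CUDA_SM89"] and not has_fp8_tc["CUDA"]   # hence the EMULATE switch in CI
```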
@@ -5,8 +5,10 @@ from tinygrad.dtype import _to_np_dtype
 from tinygrad.codegen.opt import OptOps
 from tinygrad.engine.realize import lower_schedule

-dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
+dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
+            dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
+acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
+             dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
 if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
 if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32
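This hunk appears to come from extra/gemm/simple_matmul.py, the benchmark invoked by the workflow above: FP8E4M3=1 or FP8E5M2=1 now selects fp8 inputs, with matching ACC_* switches for the accumulator. A minimal Tensor-level sketch of an fp8 matmul follows; the dtype keyword of matmul as the accumulation dtype is an assumption about the current API, and whether it runs at all depends on the fp8 support check further down this diff.

```python
from tinygrad import Tensor, dtypes

# Tiny fp8 GEMM sketch: e4m3 inputs, f32 accumulation (shapes kept small for illustration).
a = Tensor.randn(16, 32).cast(dtypes.fp8e4m3)
b = Tensor.randn(32, 8).cast(dtypes.fp8e4m3)
c = a.matmul(b, dtype=dtypes.float)   # analogous to dtype_in=fp8e4m3 with a float acc_dtype above
print(c.numpy().shape)                # (16, 8)
```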
@@ -14,8 +16,10 @@ N = getenv("N", 4096)
 M = getenv("M", N)
 K = getenv("K", N)
 CNT = getenv("CNT", 10)
-ATOL = getenv("ATOL", 1e-4)
-RTOL = getenv("RTOL", 3e-2)
+atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
+ATOL, RTOL = getenv("ATOL", atol), getenv("RTOL", rtol)

 INT_LOW = getenv("INT_LOW", 0)
 INT_HIGH = getenv("INT_HIGH", 10)
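The looser default tolerances follow from the fp8 formats themselves; a quick back-of-the-envelope check using the standard e4m3/e5m2 parameters (not part of the diff):

```python
# e4m3 keeps 3 mantissa bits and e5m2 keeps 2, so the worst-case relative rounding error
# is about 2^-4 and 2^-3 respectively; accumulating K=4096 products amplifies this,
# which is why the defaults jump to rtol=1e-1 (e4m3) and 5e-1 (e5m2).
for name, mantissa_bits in [("fp8e4m3", 3), ("fp8e5m2", 2)]:
  rel_step = 2.0 ** -(mantissa_bits + 1)
  print(f"{name}: ~{rel_step:.3f} relative rounding error per value")   # 0.062 / 0.125
```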
@@ -75,7 +75,10 @@ def universal_test_unary(a, dtype, op):
   out: Tensor = op[0](ta)
   tensor_value = out.numpy()
   numpy_value = op[1](ta.numpy())
-  if dtype in dtypes.fp8s: numpy_value = truncate[dtype](numpy_value)
+  if dtype in dtypes.fp8s:
+    # cuda cast f32 inf to f8 MAX, amd cast it to nan(E4M3)/inf(E5M2)
+    if math.isinf(numpy_value): return
+    numpy_value = truncate[dtype](numpy_value)
   if dtype in dtypes.floats:
     atol, rtol = { dtypes.float16:(1e-3, 1e-2), dtypes.bfloat16:(1e-3, 2e-2),
                    dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2: (1.0, 5e-1)}.get(dtype, (1e-6, 1e-5))
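The unary test (this hunk looks like test/test_dtype_alu.py) now returns early when the reference value overflows to infinity, because backends disagree on how an out-of-range float casts to fp8. An illustration, with the format maxima as assumed constants:

```python
import math

E4M3_MAX, E5M2_MAX = 448.0, 57344.0     # largest finite fp8 values (standard, not from the diff)
ref = math.exp(30.0)                    # reference computed in full precision, ~1.1e13
print(ref > E4M3_MAX, ref > E5M2_MAX)   # True True: both fp8 formats overflow
# Per the comment in the hunk, CUDA saturates the cast to the format max while AMD yields
# nan (e4m3) or inf (e5m2); with no single expected value, the test skips the sample.
```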
@@ -80,6 +80,10 @@ cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4)
                             swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('u1', 'r3'), ('l0', 'l1', 'u0', 'r0')),
                                      (('r1', 'r2', 'u0', 'l0', 'l1'), ('r0', 'r3'), ('l2', 'l3', 'l4', 'u1'))))
               for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), (dtypes.half,dtypes.half)]]
+cuda_81632_f8 = [TensorCore(dims=(8,16,32), threads=32, elements_per_thread=(16,8,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+                            swizzle=((('r2', 'r3', 'l2', 'l3', 'l4'), ('u1', 'r4'), ('l0', 'l1', 'u0', 'r0', 'r1')),
+                                     (('r2', 'r3', 'u0', 'l0', 'l1'), ('r1', 'r4'), ('l2', 'l3', 'l4', 'u1', 'r0'))))
+               for di,do in [(dtypes.fp8e4m3,dtypes.float),(dtypes.fp8e5m2,dtypes.float)]]
 cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
                             swizzle=((('r1', 'r2', 'l2', 'l3', 'l4'), ('r0', 'u1'), ('l0', 'l1', 'u0')),
                                      (('r1', 'r2', 'u0', 'l0', 'l1'), ('u1', 'r0'), ('l2', 'l3', 'l4'))))
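The per-thread fragment sizes of the new 8x16x32 fp8 shape follow from splitting each operand tile across the 32 lanes of a warp; a quick check, treating dims as (N, M, K) as the existing f16 entries suggest:

```python
# dims = (N, M, K) = (8, 16, 32), 32 threads per warp; fragments are split evenly per lane.
N, M, K, THREADS = 8, 16, 32, 32
a_per_lane = M * K // THREADS   # 16x32 A tile of fp8   -> 16 values per thread
b_per_lane = K * N // THREADS   # 32x8  B tile of fp8   -> 8 values per thread
c_per_lane = M * N // THREADS   # 16x8  f32 accumulator -> 4 values per thread
assert (a_per_lane, b_per_lane, c_per_lane) == (16, 8, 4)   # matches elements_per_thread above
```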
@@ -87,9 +91,10 @@ cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,
 cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
                              swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
                                       (('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
+cuda_sm75: list[TensorCore] = cuda_8168_f16
 cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
 if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
-cuda_sm75: list[TensorCore] = cuda_8168_f16
+cuda_sm89: list[TensorCore] = cuda_sm80 + cuda_81632_f8

 # ***** AMD *****
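Reordering the cuda_sm75 definition is cosmetic; the substantive change is cuda_sm89, which layers the fp8 shapes on top of everything cuda_sm80 advertises (including tf32 when ALLOW_TF32=1, since the += runs before cuda_sm89 is built). A rough containment sketch with labels standing in for the TensorCore objects:

```python
# Coarse view of which MMA shapes each capability list ends up with.
sm75 = {"m16n8k8 f16"}
sm80 = sm75 | {"m16n8k16 f16", "m16n8k16 bf16"}    # plus "m16n8k8 tf32" when ALLOW_TF32=1
sm89 = sm80 | {"m16n8k32 e4m3", "m16n8k32 e5m2"}   # new in this PR
assert sm75 <= sm80 <= sm89
```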
@@ -331,7 +331,9 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
     if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
     if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"} and not getenv("CPU_LVP")
     return device in {"AMD", "PYTHON", "NULL"}
-  if dtype in dtypes.fp8s: return device in {"PYTHON", "NULL"}
+  if dtype in dtypes.fp8s:
+    if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
+    return device in {"PYTHON", "NULL"}
   if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
                                           dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
   # for CI GPU and OSX, cl_khr_fp16 isn't supported
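With this change fp8 dtypes count as supported on real CUDA/NV devices under the same conditions used for bfloat16 above (not in CI, not under the PTX renderer), in addition to the PYTHON emulator and NULL. A standalone stand-in for the new branch (names mirror the diff; this is not the repo helper itself):

```python
def fp8_supported(device: str, ci: bool, using_ptx: bool) -> bool:
  if device in {"CUDA", "NV"}: return not ci and not using_ptx   # real fp8 kernels, local runs only
  return device in {"PYTHON", "NULL"}                            # emulator / no-op backend

assert fp8_supported("CUDA", ci=False, using_ptx=False)
assert not fp8_supported("CUDA", ci=True, using_ptx=False)
assert not fp8_supported("AMD", ci=False, using_ptx=False)       # unchanged by this hunk
```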
@@ -37,7 +37,7 @@ base_rewrite = PatternMatcher([
   (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{truncate[x.dtype](x.arg)}u"),
   (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
   # consts are rendered to larger type and casted
-  (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
+  (UPat(Ops.CONST, (*dtypes.fp8s, dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"),
   (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"),
   (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, str(x.arg))})"),
   # default const render
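fp8 constants reuse the existing rule for bfloat16/half: render the literal as a float and cast it to the target type. With the CUDA type_map later in this diff, the output should look roughly like the string below (a sketch of the shape of render_cast's output, not the renderer itself):

```python
# C-style backends render a cast as "(<type name>)(<value>)"; the pattern wraps it in parens.
def render_cast(type_name: str, value: str) -> str: return f"({type_name})({value})"
print(f"({render_cast('__nv_fp8_e4m3', '1.5f')})")   # -> ((__nv_fp8_e4m3)(1.5f))
```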
@@ -345,7 +345,8 @@ class CUDARenderer(CStyleLanguage):
   shared_max = 49152

   def __init__(self, arch:str):
-    self.tensor_cores, self.arch = tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else [], arch
+    self.arch = arch
+    self.tensor_cores = tc.cuda_sm89 if int(arch[3:]) >= 89 else tc.cuda_sm80 if int(arch[3:]) >= 80 else tc.cuda_sm75 if int(arch[3:]) >= 75 else []
   def __reduce__(self): return self.__class__, (self.arch,)

   # language options
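Splitting the assignment is cosmetic; the functional change is the new >= 89 tier. CUDA arch strings look like "sm_XX", so int(arch[3:]) is the compute capability:

```python
# Worked example of the tier selection in CUDARenderer.__init__ (list names from tc).
for arch in ["sm_70", "sm_75", "sm_86", "sm_89", "sm_90"]:
  cc = int(arch[3:])                   # "sm_89" -> 89
  tier = "cuda_sm89" if cc >= 89 else "cuda_sm80" if cc >= 80 else "cuda_sm75" if cc >= 75 else "[]"
  print(arch, "->", tier)              # sm_86 stays on cuda_sm80; sm_90 also gets the fp8 shapes
```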
@@ -364,8 +365,14 @@ class CUDARenderer(CStyleLanguage):
     Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",
     Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
     Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" }
-  type_map = {dtypes.bfloat16: "nv_bfloat16"}
-
+  type_map = {dtypes.bfloat16: "nv_bfloat16", dtypes.fp8e4m3: "__nv_fp8_e4m3", dtypes.fp8e5m2: "__nv_fp8_e5m2"}
+  extra_matcher = PatternMatcher([
+    (UPat(Ops.CAST, dtypes.fp8s, UPat.var("x", dtypes.fp8s), name='y'), lambda x,y: x.cast(dtypes.float).cast(y.dtype) if x.dtype!=y.dtype else None),
+    (UPat(GroupOp.ALU, dtype=dtypes.fp8s, name="x"),
+     lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(x.dtype)),
+    (UPat(GroupOp.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.fp8s), UPat.var("y", dtype=dtypes.fp8s))),
+     lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
+  ]) + extra_pm
   def render_vector_prefix(self, dt:DType) -> str:
     vec, scal = self.render_dtype(dt), self.render_dtype(dt.scalar()),
     elems, header = ', '.join(_nms[:dt.count]), ', '.join([f"{scal} {x}" for x in _nms[:dt.count]])
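CUDA exposes no native fp8 arithmetic, so the new extra_matcher rules promote fp8 ALU ops (and fp8-to-fp8 casts between different formats) to float and cast the result back; comparisons that produce bool only promote the operands. A pure-Python stand-in for the effect of the rewrite, not tinygrad UOps:

```python
def fp8_alu(op, a, b, truncate_fp8):
  # compute in f32, then cast the result back to the fp8 dtype (truncate_fp8 plays that role)
  return truncate_fp8(op(float(a), float(b)))

def fp8_compare(op, a, b):
  return op(float(a), float(b))   # bool result, so no cast back is needed
```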
@@ -376,11 +383,12 @@ class CUDARenderer(CStyleLanguage):
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]

     used_dtypes = uops_to_dtypes(uops)
+    if any(dt.scalar() in dtypes.fp8s for dt in used_dtypes): prefix.append("#include <cuda_fp8.h>")
     if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
-    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
-
-    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
+               or (dt.count in (8,16) and dt.scalar() in dtypes.fp8s)]
+    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2" }
     dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
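dt_map_in/dt_map_out supply the dtype tokens of the generated MMA instruction, so for the new fp8 tensor core the rendered WMMA should target something like the PTX name below; how the string is actually assembled lives in the loop after this hunk and is an assumption here:

```python
shape, din, dout = "m16n8k32", "e4m3", "f32"   # from dims (8,16,32) and the dt_map entries
print(f"mma.sync.aligned.{shape}.row.col.{dout}.{din}.{din}.{dout}")
# -> mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 (the sm_89 fp8 MMA shape in PTX)
```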
@@ -177,6 +177,11 @@ class PythonProgram:
         def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
         ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)

+      elif dims == (8,16,32):
+        def a_elem(x, k, row, goff): return x[k%4 + (row//8)*4 + (k//16)*8][goff + (k//4)%4 + (row%8)*4]
+        def b_elem(x, col, k, goff): return x[k%4 + (k//16)*4][goff + (k//4)%4 + col*4]
+        ul[i] = wmma_helper(32, 32, 16, 8, 4, a_elem, b_elem, c_map)
+
       elif dims == (8,16,8) and dtype_in == dtypes.half:
         def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
         def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
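The new (8,16,32) branch in the Python emulator mirrors the f16 one: per the 32, 32, 16, 8, 4 arguments to wmma_helper, each lane holds 16 A values, 8 B values, and 4 accumulators, matching elements_per_thread above. A quick coverage check of the element indexing:

```python
# The first index into x picks which per-lane element is read; over all (k, row/col)
# pairs it must cover exactly the 16 A slots and 8 B slots each lane owns.
a_elem_ids = {k % 4 + (row // 8) * 4 + (k // 16) * 8 for k in range(32) for row in range(16)}
b_elem_ids = {k % 4 + (k // 16) * 4 for k in range(32)}
assert a_elem_ids == set(range(16)) and b_elem_ids == set(range(8))
```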
@@ -220,6 +225,7 @@ class PythonRenderer(Renderer):
       case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4
       case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
       case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
+      case "CUDA_SM89": self.device, self.tensor_cores = "CUDA", tc.cuda_sm89
       case "INTEL": self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
       case "AMX": self.device, self.tensor_cores = "CPU", tc.amx
       case "": pass
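The EMULATE dispatch gains the matching case, which is what the workflow change at the top of this diff relies on. A minimal sketch of the mapping, with the equivalent local invocation as a comment:

```python
# EMULATE value -> tensor core list reported by the Python emulator ("CUDA_SM89" is new).
emulate_map = {"CUDA": "tc.cuda_sm80", "CUDA_SM75": "tc.cuda_sm75", "CUDA_SM89": "tc.cuda_sm89"}
print(emulate_map["CUDA_SM89"])
# Local equivalent of the CI line:
#   DEBUG=2 EMULATE=CUDA_SM89 ALLOW_TF32=1 FORWARD_ONLY=1 PYTHON=1 python3 test/opt/test_tensor_cores.py
```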