amd fp8 llvm (#13186)

* amd fp8 llvm support

* fix max

* clean

* add test_mi350.sh

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: b1tg
Date: 2025-11-21 01:35:57 +08:00
Committed by: GitHub
Parent: 1058748440
Commit: 91e289cb14
3 changed files with 47 additions and 23 deletions

extra/test_mi350.sh (new executable file, +12)

@@ -0,0 +1,12 @@
+#!/bin/bash
+AMD=1 AMD_LLVM=1 python -m pytest -n=1 test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/test_jit.py test/test_graph.py test/test_multitensor.py --durations=20
+AMD=1 AMD_LLVM=0 python -m pytest -n=1 test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/test_jit.py test/test_graph.py test/test_multitensor.py --durations=20
+CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=1 HALF=0 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
+CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=0 HALF=1 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
+CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=0 HALF=0 BFLOAT16=1 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
+CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=0 HALF=1 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
+CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=0 HALF=0 BFLOAT16=1 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
+CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=1 HALF=0 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
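Note: the first two lines run the test suites against both AMD backends (AMD_LLVM=1 exercises the LLVM IR renderer this PR extends, AMD_LLVM=0 the default C-style renderer); the remaining lines sanity-check the tensor-core matmul path in extra/gemm/simple_matmul.py for fp8e4m3, half, and bfloat16 inputs under each backend.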

tinygrad/device.py

@@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
 import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
 from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
 from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
-from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited, VIZ
+from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ
 from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer
@@ -329,7 +329,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
return device in {"AMD", "PYTHON", "NULL"}
if dtype in dtypes.fp8s:
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
if device == "AMD": return not CI and not AMD_LLVM and getattr(Device["AMD"], "target") in {(9,4,2), (9,5,0)}
if device == "AMD": return not CI and getattr(Device["AMD"], "target") in {(9,4,2), (9,5,0)}
return device in {"PYTHON", "NULL"}
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
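With the `not AMD_LLVM` guard dropped, fp8 availability on AMD is gated only by CI and the GPU target. A minimal sketch of what the relaxed check amounts to (assuming a checkout with this patch and an MI300/MI350-class part, where these names live as shown in this diff):

  from tinygrad.device import Device, is_dtype_supported
  from tinygrad.dtype import dtypes

  print(Device["AMD"].target)                       # (9,4,2) on gfx942, (9,5,0) on gfx950
  print(is_dtype_supported(dtypes.fp8e4m3, "AMD"))  # now True under both renderers outside CI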

tinygrad/renderer/llvmir.py

@@ -2,21 +2,22 @@ from typing import cast
 import math, struct, sys
 from tinygrad.codegen.opt import tc
 from tinygrad.renderer import Renderer
-from tinygrad.renderer.cstyle import AMDRenderer
+from tinygrad.renderer.cstyle import AMDRenderer, create_non_native_float_pats
 from tinygrad.uop.decompositions import xexp2, xlog2
 from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
-from tinygrad.dtype import dtypes, DType, PtrDType, truncate
+from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
 from tinygrad.helpers import prod, AMX
 def ldt(dt:DType):
   if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
   if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
   return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
-          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
+          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8",
           dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
 def lconst(x, dtype:DType):
   if dtype in dtypes.floats:
+    if dtype in dtypes.fp8s: return float_to_fp8(x, dtype)
     if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
     return truncate[dtype](x)
   return int(x)
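For illustration: fp8 has no native LLVM type, so fp8 scalars travel as raw bytes and fp8 constants are emitted as their encoded bit pattern (a quick check, assuming float_to_fp8 returns the encoding as an int, as its use in lconst suggests):

  from tinygrad.dtype import dtypes
  from tinygrad.renderer.llvmir import ldt, lconst

  print(ldt(dtypes.fp8e4m3))          # "i8": an fp8 value is just a byte in the IR
  print(lconst(1.0, dtypes.fp8e4m3))  # 56 == 0b0_0111_000, the e4m3 encoding of 1.0 (bias 7)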
@@ -47,13 +48,14 @@ def render_wmma_amx(ctx, wmma: UOp) -> str:
     f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
 def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
-  dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16"}
+  dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16",
+            dtypes.fp8e4m3: ".fp8.fp8", dtypes.fp8e5m2: ".bf8.bf8"}
   # https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
   N,M,K = wmma.arg[1]
   if cdna:
     if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"})
     return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
-      f".{N}x{M}x{K}{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
+      f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
   # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
   # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
   return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
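Note: switching the second half of the intrinsic name from wmma.src[0].dtype.scalar() to wmma.arg[2] matters for fp8: the extra_matcher below bitcasts the 8-wide fp8 sources to a single i64 before rendering, so the source dtype no longer names the element type, while arg[2] still carries the original input dtype. For an fp8e4m3 16x16x32 MFMA this should yield a call to @llvm.amdgcn.mfma.f32.16x16x32.fp8.fp8 (the leading dot comes from the map entry).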
@@ -136,27 +138,16 @@ class LLVMRenderer(Renderer):
   has_local = False
   global_max: tuple[int, ...] | None = None
   string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
-  code_for_op = {Ops.FDIV: lambda: None}
+  code_for_op = {Ops.FDIV: lambda: None, Ops.CMPLT: lambda: None}
   if AMX: tensor_cores = tc.amx
-  extra_matcher = PatternMatcher([
-    # rewrite MAX to CMPLT + WHERE
-    (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
-    # copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
-    (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
-     lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
-    # copied from cstyle.py, add float intermediate casting
-    (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
-    (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
-  ])
+  extra_matcher = create_non_native_float_pats((dtypes.bfloat16,))
   def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
   def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
   def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
     # NOTE: CPUAllocator promises 0x20 alignment
     sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
-    sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None])
-    return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
+    return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
   def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
     r: dict[UOp, str] = {}
     args: list[tuple[str, DType]] = []
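Note: the patterns previously copied from cstyle.py are replaced by the shared create_non_native_float_pats factory; called with (dtypes.bfloat16,) it should generate the same upcast-to-float32 rewrites for ops and casts that lack native bfloat16 support, and the same factory is reused below with dtypes.fp8s.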
@@ -226,8 +217,13 @@ class AMDLLVMRenderer(LLVMRenderer):
     (UPat(tuple(llvm_intrinsics), name="x"),
      lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
     (UPat(Ops.BARRIER), lambda ctx: barrier),
+    (UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",), lambda ctx,x,y:
+      f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
+    (UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",), lambda ctx,x,y:
+      f" {ctx[x.src[0]]}_i32 = zext i8 {ctx[x.src[0]]} to i32\n"
+      f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"),
   ]) + base_rewrite
-  extra_matcher = LLVMRenderer.extra_matcher + PatternMatcher([
+  extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([
     (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
      lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
     (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
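These two string_rewrite patterns route fp8 casts through hardware conversion: fp8 to float zero-extends the byte to i32 and calls llvm.amdgcn.cvt.f32.fp8 (or .bf8 for e5m2) with byte-select 0, while float to fp8 calls the @f32_to_fp8 helper that the render() override below prepends to the module.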
@@ -236,6 +232,19 @@ class AMDLLVMRenderer(LLVMRenderer):
     (UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
     (UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),
   ])
+  def render(self, uops: list[UOp]) -> str:
+    prefix = ["""define i8 @f32_to_fp8(float %val, i1 %is_bf8) {
+entry: %ival = bitcast float %val to i32\n %exp = and i32 %ival, 2139095040\n %is_special = icmp eq i32 %exp, 2139095040
+br i1 %is_special, label %select_clip, label %clip
+clip: br i1 %is_bf8, label %bf8_clip, label %fp8_clip
+bf8_clip: %clamped_bf8 = call float @llvm.amdgcn.fmed3.f32(float %val, float 57344.0, float -57344.0)\n br label %select_clip
+fp8_clip: %clamped_fp8 = call float @llvm.amdgcn.fmed3.f32(float %val, float 448.0, float -448.0) \n br label %select_clip
+select_clip: %phi_val = phi float [%val, %entry], [%clamped_bf8, %bf8_clip], [%clamped_fp8, %fp8_clip]\n br i1 %is_bf8, label %do_bf8, label %do_fp8
+do_bf8: %packed_bf8 = call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %phi_val, float %phi_val, i32 0, i1 false)\n br label %exit
+do_fp8: %packed_fp8 = call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %phi_val, float %phi_val, i32 0, i1 false)\n br label %exit
+exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc = trunc i32 %packed to i8\n ret i8 %trunc
+}""".replace(": ", ":\n ")] if any(u.dtype in dtypes.fp8s for u in uops) else []
+    return "\n".join((k:=self._render_kernel(uops, prefix))[0] + (k[1], self._render_footer(uops)))
   def _render_footer(self, uops: list[UOp]) -> str:
     # TODO: this is copied from cstyle
     local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
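The helper saturates before converting: values whose exponent bits are all ones (inf/NaN) bypass the clamp, finite values are clamped with fmed3 (median of three) to the largest finite fp8 magnitude, then llvm.amdgcn.cvt.pk.{fp8,bf8}.f32 packs the result and the low byte is truncated out. A rough Python model of the intended numerics (illustrative names, not from the patch):

  import math

  def f32_to_fp8_saturate(val: float, is_bf8: bool) -> float:
      # inf/nan skip clamping, mirroring the %is_special test on exponent mask 2139095040 (0x7F800000)
      if math.isinf(val) or math.isnan(val): return val
      lim = 57344.0 if is_bf8 else 448.0  # max finite e5m2 vs e4m3
      # fmed3(val, +lim, -lim) is the median of three, i.e. clamp to [-lim, +lim]
      return max(-lim, min(lim, val))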
@@ -252,7 +261,10 @@ class AMDLLVMRenderer(LLVMRenderer):
       self.extra_matcher += PatternMatcher([
         (UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
          lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint16.vec(4)), x.src[1].bitcast(dtypes.uint16.vec(4)),
-           x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None)
+           x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None),
+        (UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
+         lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
+           x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
       ])
     if self.arch.split(":")[0] == "gfx1100":
       self.extra_matcher += PatternMatcher([
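Note: this mirrors the bfloat16 bitcast pattern above it: an 8-element fp8 vector is 64 bits, so casting the A/B sources to uint64 should match the i64 operands the CDNA fp8 MFMA intrinsics take.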