From 91e289cb1457dccac3e5fbc0583d0d720d213470 Mon Sep 17 00:00:00 2001 From: b1tg <33436708+b1tg@users.noreply.github.com> Date: Fri, 21 Nov 2025 01:35:57 +0800 Subject: [PATCH] amd fp8 llvm (#13186) * amd fp8 llvm support * fix max * clean * add test_mi350.sh --------- Co-authored-by: chenyu --- extra/test_mi350.sh | 12 +++++++++ tinygrad/device.py | 4 +-- tinygrad/renderer/llvmir.py | 54 ++++++++++++++++++++++--------------- 3 files changed, 47 insertions(+), 23 deletions(-) create mode 100755 extra/test_mi350.sh diff --git a/extra/test_mi350.sh b/extra/test_mi350.sh new file mode 100755 index 0000000000..de2068419d --- /dev/null +++ b/extra/test_mi350.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +AMD=1 AMD_LLVM=1 python -m pytest -n=1 test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/test_jit.py test/test_graph.py test/test_multitensor.py --durations=20 +AMD=1 AMD_LLVM=0 python -m pytest -n=1 test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/test_jit.py test/test_graph.py test/test_multitensor.py --durations=20 + +CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=1 HALF=0 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py +CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=0 HALF=1 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py +CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=0 HALF=0 BFLOAT16=1 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py + +CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=0 HALF=1 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py +CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=0 HALF=0 BFLOAT16=1 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py +CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=1 HALF=0 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py \ No newline at end of file diff --git a/tinygrad/device.py b/tinygrad/device.py index b2b8e1fbc4..00a61d8db6 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup -from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited, VIZ +from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer @@ -329,7 +329,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool: return device in {"AMD", "PYTHON", "NULL"} if dtype in dtypes.fp8s: if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK") - if device == "AMD": return not CI and not AMD_LLVM and getattr(Device["AMD"], "target") in {(9,4,2), (9,5,0)} + if device == "AMD": return not CI and getattr(Device["AMD"], "target") in {(9,4,2), (9,5,0)} return device in {"PYTHON", "NULL"} if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half] diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 7fd3dd207e..d468a3055f 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -2,21 +2,22 @@ from typing import cast import math, struct, sys from tinygrad.codegen.opt import tc from tinygrad.renderer import Renderer -from tinygrad.renderer.cstyle import AMDRenderer +from tinygrad.renderer.cstyle import AMDRenderer, create_non_native_float_pats from tinygrad.uop.decompositions import xexp2, xlog2 from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str -from tinygrad.dtype import dtypes, DType, PtrDType, truncate +from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate from tinygrad.helpers import prod, AMX def ldt(dt:DType): if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>" if isinstance(dt, PtrDType): return ldt(dt.base) + "*" return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64", - dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", + dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8", dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt] def lconst(x, dtype:DType): if dtype in dtypes.floats: + if dtype in dtypes.fp8s: return float_to_fp8(x, dtype) if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1]) return truncate[dtype](x) return int(x) @@ -47,13 +48,14 @@ def render_wmma_amx(ctx, wmma: UOp) -> str: f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}']) def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str: - dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16"} + dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16", + dtypes.fp8e4m3: ".fp8.fp8", dtypes.fp8e5m2: ".bf8.bf8"} # https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl N,M,K = wmma.arg[1] if cdna: if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"}) return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \ - f".{N}x{M}x{K}{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)" + f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)" # https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll # example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101) return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \ @@ -136,27 +138,16 @@ class LLVMRenderer(Renderer): has_local = False global_max: tuple[int, ...] | None = None string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)]) - code_for_op = {Ops.FDIV: lambda: None} + code_for_op = {Ops.FDIV: lambda: None, Ops.CMPLT: lambda: None} if AMX: tensor_cores = tc.amx - extra_matcher = PatternMatcher([ - # rewrite MAX to CMPLT + WHERE - (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), - # copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16 - (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"), - lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))), - # copied from cstyle.py, add float intermediate casting - (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None), - (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None), - ]) - + extra_matcher = create_non_native_float_pats((dtypes.bfloat16,)) def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops))) def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }' def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str: # NOTE: CPUAllocator promises 0x20 alignment sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args]) - sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None]) - return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"]) + return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"]) def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]: r: dict[UOp, str] = {} args: list[tuple[str, DType]] = [] @@ -226,8 +217,13 @@ class AMDLLVMRenderer(LLVMRenderer): (UPat(tuple(llvm_intrinsics), name="x"), lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"), (UPat(Ops.BARRIER), lambda ctx: barrier), + (UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",), lambda ctx,x,y: + f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"), + (UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",), lambda ctx,x,y: + f" {ctx[x.src[0]]}_i32 = zext i8 {ctx[x.src[0]]} to i32\n" + f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"), ]) + base_rewrite - extra_matcher = LLVMRenderer.extra_matcher + PatternMatcher([ + extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([ (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))), lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))), (UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))), @@ -236,6 +232,19 @@ class AMDLLVMRenderer(LLVMRenderer): (UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2), (UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2), ]) + def render(self, uops: list[UOp]) -> str: + prefix = ["""define i8 @f32_to_fp8(float %val, i1 %is_bf8) { +entry: %ival = bitcast float %val to i32\n %exp = and i32 %ival, 2139095040\n %is_special = icmp eq i32 %exp, 2139095040 +br i1 %is_special, label %select_clip, label %clip +clip: br i1 %is_bf8, label %bf8_clip, label %fp8_clip +bf8_clip: %clamped_bf8 = call float @llvm.amdgcn.fmed3.f32(float %val, float 57344.0, float -57344.0)\n br label %select_clip +fp8_clip: %clamped_fp8 = call float @llvm.amdgcn.fmed3.f32(float %val, float 448.0, float -448.0) \n br label %select_clip +select_clip: %phi_val = phi float [%val, %entry], [%clamped_bf8, %bf8_clip], [%clamped_fp8, %fp8_clip]\n br i1 %is_bf8, label %do_bf8, label %do_fp8 +do_bf8: %packed_bf8 = call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %phi_val, float %phi_val, i32 0, i1 false)\n br label %exit +do_fp8: %packed_fp8 = call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %phi_val, float %phi_val, i32 0, i1 false)\n br label %exit +exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc = trunc i32 %packed to i8\n ret i8 %trunc +}""".replace(": ", ":\n ")] if any(u.dtype in dtypes.fp8s for u in uops) else [] + return "\n".join((k:=self._render_kernel(uops, prefix))[0] + (k[1], self._render_footer(uops))) def _render_footer(self, uops: list[UOp]) -> str: # TODO: this is copied from cstyle local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"] @@ -252,7 +261,10 @@ class AMDLLVMRenderer(LLVMRenderer): self.extra_matcher += PatternMatcher([ (UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)), lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint16.vec(4)), x.src[1].bitcast(dtypes.uint16.vec(4)), - x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None) + x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None), + (UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)), + lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64), + x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None), ]) if self.arch.split(":")[0] == "gfx1100": self.extra_matcher += PatternMatcher([