mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
amd fp8 llvm (#13186)
* amd fp8 llvm support * fix max * clean * add test_mi350.sh --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
12
extra/test_mi350.sh
Executable file
12
extra/test_mi350.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
AMD=1 AMD_LLVM=1 python -m pytest -n=1 test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/test_jit.py test/test_graph.py test/test_multitensor.py --durations=20
|
||||
AMD=1 AMD_LLVM=0 python -m pytest -n=1 test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/test_jit.py test/test_graph.py test/test_multitensor.py --durations=20
|
||||
|
||||
CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=1 HALF=0 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
|
||||
CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=0 HALF=1 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
|
||||
CNT=1 AMD_LLVM=0 DEBUG=2 FP8E4M3=0 HALF=0 BFLOAT16=1 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
|
||||
|
||||
CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=0 HALF=1 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
|
||||
CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=0 HALF=0 BFLOAT16=1 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
|
||||
CNT=1 AMD_LLVM=1 DEBUG=2 FP8E4M3=1 HALF=0 BFLOAT16=0 SHOULD_USE_TC=1 python extra/gemm/simple_matmul.py
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Generic, TypeVar, Iterator, Sequence, cast, Generator
|
||||
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM
|
||||
from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
|
||||
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, AMD_LLVM, select_first_inited, VIZ
|
||||
from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -329,7 +329,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
|
||||
return device in {"AMD", "PYTHON", "NULL"}
|
||||
if dtype in dtypes.fp8s:
|
||||
if device in {"CUDA", "NV"}: return not CI and not getenv(f"{device}_PTX") and not getenv("NV_NAK")
|
||||
if device == "AMD": return not CI and not AMD_LLVM and getattr(Device["AMD"], "target") in {(9,4,2), (9,5,0)}
|
||||
if device == "AMD": return not CI and getattr(Device["AMD"], "target") in {(9,4,2), (9,5,0)}
|
||||
return device in {"PYTHON", "NULL"}
|
||||
if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
|
||||
dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
|
||||
|
||||
@@ -2,21 +2,22 @@ from typing import cast
|
||||
import math, struct, sys
|
||||
from tinygrad.codegen.opt import tc
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.renderer.cstyle import AMDRenderer
|
||||
from tinygrad.renderer.cstyle import AMDRenderer, create_non_native_float_pats
|
||||
from tinygrad.uop.decompositions import xexp2, xlog2
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, GroupOp, range_str
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
|
||||
from tinygrad.dtype import dtypes, float_to_fp8, DType, PtrDType, truncate
|
||||
from tinygrad.helpers import prod, AMX
|
||||
|
||||
def ldt(dt:DType):
|
||||
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
|
||||
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
|
||||
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
|
||||
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
|
||||
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", dtypes.fp8e4m3: "i8", dtypes.fp8e5m2: "i8",
|
||||
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
|
||||
|
||||
def lconst(x, dtype:DType):
|
||||
if dtype in dtypes.floats:
|
||||
if dtype in dtypes.fp8s: return float_to_fp8(x, dtype)
|
||||
if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
|
||||
return truncate[dtype](x)
|
||||
return int(x)
|
||||
@@ -47,13 +48,14 @@ def render_wmma_amx(ctx, wmma: UOp) -> str:
|
||||
f' {ctx[wmma]} = load {ldt(wmma.dtype)}, ptr {ctx[wmma]}_amx2, align {wmma.dtype.itemsize}'])
|
||||
|
||||
def render_wmma_amd(ctx, wmma: UOp, cdna=False) -> str:
|
||||
dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16"}
|
||||
dt_map = {dtypes.half: "f16", dtypes.float: "f32", dtypes.ushort: "bf16.1k" if cdna else "bf16", dtypes.bfloat16: "bf16.1k" if cdna else "bf16",
|
||||
dtypes.fp8e4m3: ".fp8.fp8", dtypes.fp8e5m2: ".bf8.bf8"}
|
||||
# https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl
|
||||
N,M,K = wmma.arg[1]
|
||||
if cdna:
|
||||
if K == 32: dt_map.update({dtypes.half: ".f16", dtypes.bfloat16: ".bf16"})
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.mfma.{dt_map[wmma.src[-1].dtype.scalar()]}" + \
|
||||
f".{N}x{M}x{K}{dt_map[wmma.src[0].dtype.scalar()]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
|
||||
f".{N}x{M}x{K}{dt_map[wmma.arg[2]]}(" + ", ".join([f"{ldt(w.dtype)} {ctx[w]}" for w in wmma.src]) + ", i32 0, i32 0, i32 0)"
|
||||
# https://github.com/llvm/llvm-project/blob/main/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
|
||||
# example: %wmma0 = call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half> %v99,<16 x half> %v100,<8 x float> %v101)
|
||||
return f" {ctx[wmma]} = call {ldt(wmma.dtype)} @llvm.amdgcn.wmma.{dt_map[wmma.src[-1].dtype.scalar()]}.16x16x16." + \
|
||||
@@ -136,27 +138,16 @@ class LLVMRenderer(Renderer):
|
||||
has_local = False
|
||||
global_max: tuple[int, ...] | None = None
|
||||
string_rewrite = base_rewrite + PatternMatcher([(UPat(Ops.WMMA, name="wmma"), render_wmma_amx)])
|
||||
code_for_op = {Ops.FDIV: lambda: None}
|
||||
code_for_op = {Ops.FDIV: lambda: None, Ops.CMPLT: lambda: None}
|
||||
if AMX: tensor_cores = tc.amx
|
||||
|
||||
extra_matcher = PatternMatcher([
|
||||
# rewrite MAX to CMPLT + WHERE
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
|
||||
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
|
||||
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
|
||||
# copied from cstyle.py, add float intermediate casting
|
||||
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
|
||||
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
|
||||
])
|
||||
|
||||
extra_matcher = create_non_native_float_pats((dtypes.bfloat16,))
|
||||
def render(self, uops: list[UOp]) -> str: return "\n".join((k:=self._render_kernel(uops))[0] + (k[1], self._render_footer(uops)))
|
||||
def _render_footer(self, uops: list[UOp]) -> str: return 'attributes #0 = { alwaysinline nounwind "no-builtins" "no-trapping-math"="true" }'
|
||||
def _render_fn(self, name:str, args:list[tuple[str,DType]], kernel:list[str], prefix:list[str]|None=None) -> str:
|
||||
# NOTE: CPUAllocator promises 0x20 alignment
|
||||
sargs = ", ".join([f"{ldt(dt)}{' noalias align 32' if isinstance(dt, PtrDType) else ''} {name}" for name,dt in args])
|
||||
sprefix = "".join([f" {x}" for x in (prefix or []) + [self.abi] if x is not None])
|
||||
return "\n".join([f"define{sprefix} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
|
||||
return "\n".join((prefix or []) + [f"define{' ' + self.abi if self.abi else ''} void @{name}({sargs}) #0", "{"] + kernel + [" ret void\n}"])
|
||||
def _render_kernel(self, uops: list[UOp], prefix:list[str]|None=None) -> tuple[tuple[str, ...], str]:
|
||||
r: dict[UOp, str] = {}
|
||||
args: list[tuple[str, DType]] = []
|
||||
@@ -226,8 +217,13 @@ class AMDLLVMRenderer(LLVMRenderer):
|
||||
(UPat(tuple(llvm_intrinsics), name="x"),
|
||||
lambda ctx, x: f" {ctx[x]} = call {ldt(x.dtype)} @llvm.{llvm_intrinsics[x.op]}.{ldt(x.dtype.scalar())}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BARRIER), lambda ctx: barrier),
|
||||
(UPat(Ops.CAST, dtypes.fp8s, (UPat.var("y", dtypes.float),), name="x",), lambda ctx,x,y:
|
||||
f" {ctx[x]} = call i8 @f32_to_fp8({ldt(x.src[0].dtype)} {ctx[x.src[0]]}, i1 {'1' if x.dtype == dtypes.fp8e5m2 else '0'})"),
|
||||
(UPat(Ops.CAST, dtypes.float, (UPat.var("y", dtypes.fp8s),), name="x",), lambda ctx,x,y:
|
||||
f" {ctx[x.src[0]]}_i32 = zext i8 {ctx[x.src[0]]} to i32\n"
|
||||
f" {ctx[x]} = call float @llvm.amdgcn.cvt.f32.{'bf8' if y.dtype == dtypes.fp8e5m2 else 'fp8'}(i32 {ctx[x.src[0]]}_i32, i32 0)"),
|
||||
]) + base_rewrite
|
||||
extra_matcher = LLVMRenderer.extra_matcher + PatternMatcher([
|
||||
extra_matcher = LLVMRenderer.extra_matcher + create_non_native_float_pats(dtypes.fp8s) + PatternMatcher([
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(16), src=UPat.var("y", dtypes.half.vec(8))),
|
||||
lambda x, y: UOp(Ops.VECTORIZE, dtypes.half.vec(16), tuple(y.gep(i // 2) if i % 2 == 0 else UOp.const(dtypes.half, 0.0) for i in range(16)))),
|
||||
(UPat(Ops.CAST, name="x", dtype=dtypes.half.vec(8), src=UPat.var("y", dtypes.half.vec(16))),
|
||||
@@ -236,6 +232,19 @@ class AMDLLVMRenderer(LLVMRenderer):
|
||||
(UPat(Ops.LOG2, dtype=dtypes.double, src=(UPat.var("d"),)), xlog2),
|
||||
(UPat(Ops.EXP2, dtype=dtypes.double, src=(UPat.var("d"),)), xexp2),
|
||||
])
|
||||
def render(self, uops: list[UOp]) -> str:
|
||||
prefix = ["""define i8 @f32_to_fp8(float %val, i1 %is_bf8) {
|
||||
entry: %ival = bitcast float %val to i32\n %exp = and i32 %ival, 2139095040\n %is_special = icmp eq i32 %exp, 2139095040
|
||||
br i1 %is_special, label %select_clip, label %clip
|
||||
clip: br i1 %is_bf8, label %bf8_clip, label %fp8_clip
|
||||
bf8_clip: %clamped_bf8 = call float @llvm.amdgcn.fmed3.f32(float %val, float 57344.0, float -57344.0)\n br label %select_clip
|
||||
fp8_clip: %clamped_fp8 = call float @llvm.amdgcn.fmed3.f32(float %val, float 448.0, float -448.0) \n br label %select_clip
|
||||
select_clip: %phi_val = phi float [%val, %entry], [%clamped_bf8, %bf8_clip], [%clamped_fp8, %fp8_clip]\n br i1 %is_bf8, label %do_bf8, label %do_fp8
|
||||
do_bf8: %packed_bf8 = call i32 @llvm.amdgcn.cvt.pk.bf8.f32(float %phi_val, float %phi_val, i32 0, i1 false)\n br label %exit
|
||||
do_fp8: %packed_fp8 = call i32 @llvm.amdgcn.cvt.pk.fp8.f32(float %phi_val, float %phi_val, i32 0, i1 false)\n br label %exit
|
||||
exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc = trunc i32 %packed to i8\n ret i8 %trunc
|
||||
}""".replace(": ", ":\n ")] if any(u.dtype in dtypes.fp8s for u in uops) else []
|
||||
return "\n".join((k:=self._render_kernel(uops, prefix))[0] + (k[1], self._render_footer(uops)))
|
||||
def _render_footer(self, uops: list[UOp]) -> str:
|
||||
# TODO: this is copied from cstyle
|
||||
local_dims = [u.src[0] for u in uops if u.op is Ops.SPECIAL and u.arg[0] == "l"]
|
||||
@@ -252,7 +261,10 @@ class AMDLLVMRenderer(LLVMRenderer):
|
||||
self.extra_matcher += PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint16.vec(4)), x.src[1].bitcast(dtypes.uint16.vec(4)),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None)
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype == dtypes.bfloat16.vec(4) else None),
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.float.vec(4)),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
|
||||
])
|
||||
if self.arch.split(":")[0] == "gfx1100":
|
||||
self.extra_matcher += PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user