llama: fused silu fp8 amax (#15798)

* llama: combined w13

* llama: fused swiglu+fp8

* llama: fix amax interleaving

* llama: don't need separate matmul
wozeparrot
2026-04-19 12:03:55 +08:00
committed by GitHub
parent 5bdfd4883f
commit f28ea84de2
4 changed files with 243 additions and 12 deletions
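For reference only (not part of this commit): a minimal NumPy sketch of what the fused forward op computes, under the delayed-scaling convention used by quantize_fp8 below, where the FP8 scale comes from the previous step's amax while the current amax is recorded for the next step.

# Illustrative NumPy reference, not repo code: silu(x_w1) * x_w3, amax tracking,
# and FP8-range clamping in one pass over the combined w13 activation.
import numpy as np

FP8_MAX = 448.0

def silu(x): return x / (1.0 + np.exp(-x))

def reference_fused_silu_fp8(xw13: np.ndarray, amax_state: float):
  hidden = xw13.shape[-1] // 2
  x_w1, x_w3 = xw13[..., :hidden], xw13[..., hidden:]
  x2 = silu(x_w1) * x_w3                          # SwiGLU gate
  new_amax = np.abs(x2).max()                     # recorded for the next step
  scale = FP8_MAX / (amax_state + 1e-8)           # delayed scaling from previous amax
  x2_q = np.clip(x2 * scale, -FP8_MAX, FP8_MAX)   # would be cast to e4m3 on device
  return x2_q, 1.0 / scale, new_amax              # (quantized payload, inv_scale, new_amax)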

examples/mlperf/models/flat_llama.py

@@ -42,25 +42,20 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
   x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
   return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
-def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None) -> tuple[Tensor,...]:
+def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
+           x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None) -> tuple[Tensor,...]:
   if not fp8:
     if getenv("ASM_GEMM"):
       from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
       if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
     return (x @ w.T,)
   assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
-  x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
+  if x_fp8 is None: x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
   if getenv("ASM_GEMM"):
     from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
     if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
   return x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale, x_new_amax, x_fp8, w
-def matmul_fp8_precomputed(x_fp8:Tensor, x_inv_scale:Tensor, x_new_amax:Tensor, w:Tensor, w_inv_scale:Tensor) -> tuple[Tensor,...]:
-  if getenv("ASM_GEMM"):
-    from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
-    if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_inv_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
-  return x_fp8.dot(w.T, dtype=dtypes.float) * x_inv_scale * w_inv_scale, x_new_amax, x_fp8, w
 def _rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
   x = x_in.float()
   rrms = (x.square().mean(-1, keepdim=True) + eps).rsqrt()
@@ -180,10 +175,14 @@ class FlatTransformer:
     new_amaxs.extend(ret[:1])
     saves.extend(ret[1:] + [x_w13])
-    x_w1 = x_w13[..., :self.hidden_dim]
-    x_w3 = x_w13[..., self.hidden_dim:]
-    out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
+    if FP8 and getenv("FUSED_SILU_W13", 1):
+      from extra.amax.cast_amax import fused_quantize_fp8_w13
+      amax_s = amax_x2 if amax_x2 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x_w13.device)
+      x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_s, FP8_DTYPE)
+      out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2)
+    else:
+      x_w1, x_w3 = x_w13[..., :self.hidden_dim], x_w13[..., self.hidden_dim:]
+      out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
     new_amaxs.extend(ret[:1])
     saves.extend(ret[1:] + [out])
     return (out, *new_amaxs, *saves)
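For illustration (toy NumPy stand-ins, not the tinygrad calls above): with x_fp8/x_scale accepted as keyword arguments, the fused branch and the eager branch hand the same FP8 operand to the w2 GEMM, which is why the old matmul_fp8_precomputed helper is no longer needed.

import numpy as np
FP8_MAX = 448.0

def toy_quantize(x, amax_state):
  scale = FP8_MAX / (amax_state + 1e-8)
  return np.clip(x * scale, -FP8_MAX, FP8_MAX), 1.0 / scale, np.abs(x).max()

def toy_matmul(x, w, w_inv_scale, amax_x=None, x_fp8=None, x_scale=None, x_new_amax=None):
  if x_fp8 is None:  # eager path: quantize inside the matmul
    x_fp8, x_scale, x_new_amax = toy_quantize(x, amax_state=amax_x)
  return (x_fp8 @ w.T) * x_scale * w_inv_scale, x_new_amax  # same GEMM either way

x, w = np.random.randn(4, 8), np.random.randn(6, 8)
out_eager, _ = toy_matmul(x, w, w_inv_scale=1.0, amax_x=3.0)
q, s, na = toy_quantize(x, 3.0)  # stands in for the fused silu+cast kernel's outputs
out_fused, _ = toy_matmul(None, w, w_inv_scale=1.0, x_fp8=q, x_scale=s, x_new_amax=na)
assert np.allclose(out_eager, out_fused)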

extra/amax/cast_amax.py Normal file

@@ -0,0 +1,85 @@
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler

FP8_MAX = 448.0
NUM_WG, THREADS_PER_WG = 1024, 256

def _compile(cpp_name:str, n_elems:int, hidden:int):
  src = (pathlib.Path(__file__).parent/cpp_name).read_text()
  defines = [f"-DN_ELEMS={n_elems}", f"-DHIDDEN={hidden}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
  return src, HIPCCCompiler("gfx950", ["-std=c++20", "-ffast-math", *defines]).compile_cached(src)

def _shard_shape(shape:tuple, axis:int, ndev:int) -> list:
  s = list(shape); s[axis] //= ndev; return s

@functools.cache
def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
  hidden = xw13.shape[2] // 2
  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
  threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  # read 2*N bf16 (xw13) + N bf16 (grad_x2) + 1 scalar; write 2*N bf16 (grad_xw13)
  mem = n_elems * 2 * 5
  sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, threads, workgroups,
                  arg=KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=mem)))
  src, lib = _compile("cast_amax_bwd_w13.cpp", n_elems, hidden)
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                               UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))

@functools.cache
def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:UOp, dname:str) -> UOp:
  hidden = xw13.shape[2] // 2
  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
  threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  # read 2*N bf16 + 1 scalar, write N fp8 + NUM_WG bf16
  mem = n_elems * 2 * 2 + n_elems + NUM_WG * 2
  sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, threads, workgroups,
                  arg=KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem)))
  src, lib = _compile("cast_amax_fwd_w13.cpp", n_elems, hidden)
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                               UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))

def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
  # kernel.src[1:] is (fp8_out, amax_buf, xw13, amax_state); only xw13 needs a grad
  _, _, xw13, amax_state = kernel.src[1:]
  device = xw13.device
  if isinstance(device, tuple):
    axis, ndev = xw13.axis, len(device)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    grad_xw13 = Tensor(Tensor.invalid(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
    dname = device[0].split(":")[0]
  else:
    grad_xw13 = Tensor.invalid(*xw13.shape, dtype=dtypes.bfloat16, device=device)
    dname = device.split(":")[0] if isinstance(device, str) else device
  grad_x2_t = Tensor(gradient, device=device).cast(dtypes.bfloat16)
  fxn = functools.partial(_custom_fused_bwd_w13, dname=dname)
  grad_xw13, *_ = Tensor.custom_kernel(grad_xw13, Tensor(xw13, device=device), grad_x2_t, Tensor(amax_state, device=device), fxn=fxn)
  return (None, None, grad_xw13.uop, None)

def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
  # silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax).
  assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
  MBS, SEQ, H2 = xw13.shape
  assert H2 % 2 == 0, f"w13 last-axis must be even, got {H2}"
  HIDDEN = H2 // 2
  if isinstance(xw13.device, tuple):
    axis, ndev = xw13.uop.axis, len(xw13.device)
    assert axis in (0, 1), f"unsupported sharding axis={axis}"
    fp8_out = Tensor(Tensor.invalid(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=xw13.device).uop.multi(axis), device=xw13.device)
    amax_buf = Tensor(Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device).uop.multi(0), device=xw13.device)
    dname = xw13.device[0].split(":")[0]
  else:
    fp8_out = Tensor.invalid(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=xw13.device)
    amax_buf = Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device)
    dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
  fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
  # per-device scalar amax (no cross-device allreduce, matches _local_abs_max semantics)
  if isinstance(amax_buf.device, tuple):
    from examples.mlperf.models.flat_llama import _local_abs_max
    new_amax = _local_abs_max(amax_buf).detach()
  else: new_amax = amax_buf.max().detach()
  inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal()
  return fp8_out, inv_scale, new_amax
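A quick arithmetic check of the Estimates above (illustrative only, using the default N_ELEMS and NUM_WG from the kernel headers below):

N, NUM_WG = 234881024, 1024          # default N_ELEMS and workgroup count from the kernels below
mem_fwd = 2*N*2 + N + NUM_WG*2       # read 2*N bf16 (xw13), write N fp8, write NUM_WG bf16 partial amaxes
mem_bwd = (2*N + N + 2*N) * 2        # read 2*N bf16 (xw13) + N bf16 (grad_x2), write 2*N bf16 (grad_xw13)
assert mem_bwd == N * 2 * 5          # same value as the mem= expression in _custom_fused_bwd_w13
print(f"fwd ~{mem_fwd/2**30:.1f} GiB, bwd ~{mem_bwd/2**30:.1f} GiB per launch")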

extra/amax/cast_amax_bwd_w13.cpp Normal file

@@ -0,0 +1,68 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#ifndef N_ELEMS
#define N_ELEMS 234881024
#endif
#ifndef HIDDEN
#define HIDDEN 14336
#endif
#ifndef NUM_WG
#define NUM_WG 1024
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif
constexpr int VEC = 8;
constexpr float FP8_MAX = 448.0f;
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC");
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
fused_silu_mul_bwd_w13(
    __hip_bfloat16* __restrict__ grad_xw13_out,    // bf16, 2*N_ELEMS (interleaved layout)
    const __hip_bfloat16* __restrict__ xw13,       // bf16, 2*N_ELEMS (interleaved)
    const __hip_bfloat16* __restrict__ grad_x2,    // bf16, N_ELEMS
    const __hip_bfloat16* __restrict__ amax_state) // bf16 scalar
{
  const int tid = threadIdx.x;
  const int wg = blockIdx.x;
  const int gid = wg * THREADS_PER_WG + tid;
  const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
  const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
  for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
    const int outer = base / HIDDEN;
    const int inner = base % HIDDEN;
    const int xw1_off = outer * 2 * HIDDEN + inner;
    const int xw3_off = xw1_off + HIDDEN;
    float4 x1_raw = *reinterpret_cast<const float4*>(&xw13[xw1_off]);
    float4 x3_raw = *reinterpret_cast<const float4*>(&xw13[xw3_off]);
    float4 g_raw = *reinterpret_cast<const float4*>(&grad_x2[base]);
    const __hip_bfloat16 *x1 = reinterpret_cast<const __hip_bfloat16*>(&x1_raw);
    const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);
    const __hip_bfloat16 *gv = reinterpret_cast<const __hip_bfloat16*>(&g_raw);
    __hip_bfloat16 out1[VEC], out3[VEC];
    #pragma unroll
    for (int i = 0; i < VEC; i++) {
      const float f1 = static_cast<float>(x1[i]);
      const float f3 = static_cast<float>(x3[i]);
      const float fg = static_cast<float>(gv[i]);
      const float sig = 1.0f / (1.0f + __expf(-f1));
      const float silu = f1 * sig;
      const float silu_prime = sig + silu * (1.0f - sig);
      const float gs = fg * scale;
      out1[i] = static_cast<__hip_bfloat16>(gs * silu_prime * f3);
      out3[i] = static_cast<__hip_bfloat16>(gs * silu);
    }
    *reinterpret_cast<float4*>(&grad_xw13_out[xw1_off]) = *reinterpret_cast<float4*>(out1);
    *reinterpret_cast<float4*>(&grad_xw13_out[xw3_off]) = *reinterpret_cast<float4*>(out3);
  }
}
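Not from the commit: the per-element derivative the backward kernel applies, spelled out in Python. grad_x2 is the gradient with respect to the scaled, cast output, so it is multiplied by scale (chain rule through q = x2 * scale, straight-through across the FP8 cast) and then routed through the SwiGLU partials.

import math

def swiglu_bwd_elem(x1: float, x3: float, g: float, scale: float):
  sig = 1.0 / (1.0 + math.exp(-x1))
  silu = x1 * sig
  silu_prime = sig + silu * (1.0 - sig)    # d/dx1 [x1 * sigmoid(x1)]
  gs = g * scale                           # chain rule through q = x2 * scale in the forward
  return gs * silu_prime * x3, gs * silu   # (grad wrt x1, grad wrt x3), as in out1[i]/out3[i]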

extra/amax/cast_amax_fwd_w13.cpp Normal file

@@ -0,0 +1,79 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>
#ifndef N_ELEMS
#define N_ELEMS 234881024
#endif
#ifndef HIDDEN
#define HIDDEN 14336
#endif
#ifndef NUM_WG
#define NUM_WG 1024
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif
constexpr int VEC = 8;
constexpr float FP8_MAX = 448.0f;
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC (so VEC loads don't straddle block boundary)");
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
fused_silu_mul_cast_amax_w13(
    __hip_fp8_storage_t* __restrict__ fp8_out,     // fp8, N_ELEMS
    __hip_bfloat16* __restrict__ amax_buf,         // bf16, NUM_WG (per-WG amaxes)
    const __hip_bfloat16* __restrict__ xw13,       // bf16, 2*N_ELEMS
    const __hip_bfloat16* __restrict__ amax_state) // bf16 scalar
{
  __shared__ float sdata[THREADS_PER_WG];
  const int tid = threadIdx.x;
  const int wg = blockIdx.x;
  const int gid = wg * THREADS_PER_WG + tid;
  const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
  const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
  float local_max = 0.0f;
  // grid-stride over 8-element groups
  for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
    // interleaved xw13 layout: xw1 and xw3 are not contiguous halves
    const int outer = base / HIDDEN;
    const int inner = base % HIDDEN;
    const int xw1_off = outer * 2 * HIDDEN + inner;
    const int xw3_off = xw1_off + HIDDEN;
    float4 x1_raw = *reinterpret_cast<const float4*>(&xw13[xw1_off]);
    float4 x3_raw = *reinterpret_cast<const float4*>(&xw13[xw3_off]);
    const __hip_bfloat16 *x1 = reinterpret_cast<const __hip_bfloat16*>(&x1_raw);
    const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);
    __hip_fp8_storage_t out[VEC];
    #pragma unroll
    for (int i = 0; i < VEC; i++) {
      const float f1 = static_cast<float>(x1[i]);
      const float f3 = static_cast<float>(x3[i]);
      const float silu = f1 / (1.0f + __expf(-f1));
      const float x2 = silu * f3;
      local_max = fmaxf(local_max, fabsf(x2));
      const float x_scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, x2 * scale));
      out[i] = __hip_cvt_float_to_fp8(x_scaled, __HIP_SATFINITE, __HIP_E4M3);
    }
    *reinterpret_cast<uint64_t*>(&fp8_out[base]) = *reinterpret_cast<uint64_t*>(out);
  }
  // LDS tree reduction: per-workgroup amax
  sdata[tid] = local_max;
  __syncthreads();
  for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) {
    if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
    __syncthreads();
  }
  if (tid == 0) amax_buf[wg] = static_cast<__hip_bfloat16>(sdata[0]);
}
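Not from the commit: a small NumPy check of the interleaved-offset arithmetic shared by both kernels (the "fix amax interleaving" point), assuming a row-major (MBS, SEQ, 2*HIDDEN) buffer as produced by the combined w13 projection.

import numpy as np

HIDDEN, VEC = 16, 8
xw13 = np.random.randn(3, 5, 2 * HIDDEN).astype(np.float32)
flat = xw13.reshape(-1)                      # flat view, like the GPU buffer
rows = xw13.reshape(-1, 2 * HIDDEN)          # one row per (batch, seq) position

for base in range(0, rows.shape[0] * HIDDEN, VEC):    # flat index into the N_ELEMS x2 values
  outer, inner = divmod(base, HIDDEN)
  xw1_off = outer * 2 * HIDDEN + inner                # x_w1 lives in the first half of its row
  xw3_off = xw1_off + HIDDEN                          # x_w3 in the second half of the same row
  assert np.array_equal(flat[xw1_off:xw1_off + VEC], rows[outer, inner:inner + VEC])
  assert np.array_equal(flat[xw3_off:xw3_off + VEC], rows[outer, HIDDEN + inner:HIDDEN + inner + VEC])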