diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py
index 579ad3bebf..1c478e182d 100644
--- a/examples/mlperf/models/flat_llama.py
+++ b/examples/mlperf/models/flat_llama.py
@@ -42,25 +42,20 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
   x_clamped = x_scaled + (x_scaled.detach().clamp(-FP8_MAX, FP8_MAX) - x_scaled.detach()) # STE
   return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal(), new_amax
 
-def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None) -> tuple[Tensor,...]:
+def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, w_inv_scale:Tensor|None=None,
+           x_fp8:Tensor|None=None, x_scale:Tensor|None=None, x_new_amax:Tensor|None=None) -> tuple[Tensor,...]:
   if not fp8:
     if getenv("ASM_GEMM"):
       from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
       if can_use_asm_gemm(x, w.T): return (asm_gemm(x, w.T),)
     return (x @ w.T,)
   assert w_inv_scale is not None, "fp8 matmul requires w_inv_scale (weights must be stored in fp8 with per-tensor scale)"
-  x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
+  if x_fp8 is None: x_fp8, x_scale, x_new_amax = quantize_fp8(x, amax_state=amax_x)
   if getenv("ASM_GEMM"):
     from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
     if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
   return x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale, x_new_amax, x_fp8, w
 
-def matmul_fp8_precomputed(x_fp8:Tensor, x_inv_scale:Tensor, x_new_amax:Tensor, w:Tensor, w_inv_scale:Tensor) -> tuple[Tensor,...]:
-  if getenv("ASM_GEMM"):
-    from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
-    if can_use_asm_gemm(x_fp8, w.T): return asm_gemm(x_fp8, w.T, x_scale=x_inv_scale, w_scale=w_inv_scale), x_new_amax, x_fp8, w
-  return x_fp8.dot(w.T, dtype=dtypes.float) * x_inv_scale * w_inv_scale, x_new_amax, x_fp8, w
-
 def _rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
   x = x_in.float()
   rrms = (x.square().mean(-1, keepdim=True) + eps).rsqrt()
@@ -180,10 +175,14 @@ class FlatTransformer:
       new_amaxs.extend(ret[:1])
       saves.extend(ret[1:] + [x_w13])
 
-      x_w1 = x_w13[..., :self.hidden_dim]
-      x_w3 = x_w13[..., self.hidden_dim:]
-
-      out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
+      if FP8 and getenv("FUSED_SILU_W13", 1):
+        from extra.amax.cast_amax import fused_quantize_fp8_w13
+        amax_s = amax_x2 if amax_x2 is not None else Tensor.full((), 1.0, dtype=dtypes.bfloat16, device=x_w13.device)
+        x2_fp8, x2_inv_scale, new_amax_x2 = fused_quantize_fp8_w13(x_w13, amax_s, FP8_DTYPE)
+        out, *ret = matmul(None, w2, w_inv_scale=s_2, x_fp8=x2_fp8, x_scale=x2_inv_scale, x_new_amax=new_amax_x2)
+      else:
+        x_w1, x_w3 = x_w13[..., :self.hidden_dim], x_w13[..., self.hidden_dim:]
+        out, *ret = matmul(x_w1.silu() * x_w3, w2, amax_x=amax_x2, w_inv_scale=s_2)
       new_amaxs.extend(ret[:1])
       saves.extend(ret[1:] + [out])
     return (out, *new_amaxs, *saves)
diff --git a/extra/amax/cast_amax.py b/extra/amax/cast_amax.py
new file mode 100644
index 0000000000..d641255010
--- /dev/null
+++ b/extra/amax/cast_amax.py
@@ -0,0 +1,85 @@
+import functools, pathlib
+from tinygrad import Tensor, dtypes
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.renderer import Estimates
+from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
+
+FP8_MAX = 448.0
+NUM_WG, THREADS_PER_WG = 1024, 256
+
+def _compile(cpp_name:str, n_elems:int, hidden:int):
+  src = (pathlib.Path(__file__).parent/cpp_name).read_text()
+  defines = [f"-DN_ELEMS={n_elems}", f"-DHIDDEN={hidden}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
+  return src, HIPCCCompiler("gfx950", ["-std=c++20", "-ffast-math", *defines]).compile_cached(src)
+
+def _shard_shape(shape:tuple, axis:int, ndev:int) -> list:
+  s = list(shape); s[axis] //= ndev; return s
+
+@functools.cache
+def _custom_fused_bwd_w13(grad_xw13:UOp, xw13:UOp, grad_x2:UOp, amax_state:UOp, dname:str) -> UOp:
+  hidden = xw13.shape[2] // 2
+  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
+  threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
+  # read 2*N bf16 (xw13) + N bf16 (grad_x2) + 1 scalar; write 2*N bf16 (grad_xw13)
+  mem = n_elems * 2 * 5
+  sink = UOp.sink(grad_xw13.base, xw13.base, grad_x2.base, amax_state.base, threads, workgroups,
+                  arg=KernelInfo(f"fused_silu_mul_bwd_w13_{n_elems}", estimates=Estimates(ops=8*n_elems, mem=mem)))
+  src, lib = _compile("cast_amax_bwd_w13.cpp", n_elems, hidden)
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
+                               UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
+
+@functools.cache
+def _custom_fused_cast_amax_w13(fp8_out:UOp, amax_buf:UOp, xw13:UOp, amax_state:UOp, dname:str) -> UOp:
+  hidden = xw13.shape[2] // 2
+  n_elems = xw13.shape[0] * xw13.shape[1] * hidden
+  threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
+  # read 2*N bf16 + 1 scalar, write N fp8 + NUM_WG bf16
+  mem = n_elems * 2 * 2 + n_elems + NUM_WG * 2
+  sink = UOp.sink(fp8_out.base, amax_buf.base, xw13.base, amax_state.base, threads, workgroups,
+                  arg=KernelInfo(f"fused_silu_mul_cast_amax_w13_{n_elems}", estimates=Estimates(ops=5*n_elems, mem=mem)))
+  src, lib = _compile("cast_amax_fwd_w13.cpp", n_elems, hidden)
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
+                               UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
+
+def _fused_quantize_bwd_w13(gradient:UOp, kernel:UOp):
+  # kernel.src[1:] is (fp8_out, amax_buf, xw13, amax_state); only xw13 needs a grad
+  _, _, xw13, amax_state = kernel.src[1:]
+  device = xw13.device
+  if isinstance(device, tuple):
+    axis, ndev = xw13.axis, len(device)
+    assert axis in (0, 1), f"unsupported sharding axis={axis}"
+    grad_xw13 = Tensor(Tensor.invalid(*_shard_shape(xw13.shape, axis, ndev), dtype=dtypes.bfloat16, device=device).uop.multi(axis), device=device)
+    dname = device[0].split(":")[0]
+  else:
+    grad_xw13 = Tensor.invalid(*xw13.shape, dtype=dtypes.bfloat16, device=device)
+    dname = device.split(":")[0] if isinstance(device, str) else device
+  grad_x2_t = Tensor(gradient, device=device).cast(dtypes.bfloat16)
+  fxn = functools.partial(_custom_fused_bwd_w13, dname=dname)
+  grad_xw13, *_ = Tensor.custom_kernel(grad_xw13, Tensor(xw13, device=device), grad_x2_t, Tensor(amax_state, device=device), fxn=fxn)
+  return (None, None, grad_xw13.uop, None)
+
+def fused_quantize_fp8_w13(xw13:Tensor, amax_state:Tensor, fp8_dtype) -> tuple[Tensor, Tensor, Tensor]:
+  # silu(xw1)*xw3 -> fp8 + amax over fused xw13 layout. Returns (fp8, inv_scale, new_amax).
+  assert xw13.dtype == dtypes.bfloat16, f"expected bf16, got {xw13.dtype}"
+  MBS, SEQ, H2 = xw13.shape
+  assert H2 % 2 == 0, f"w13 last-axis must be even, got {H2}"
+  HIDDEN = H2 // 2
+  if isinstance(xw13.device, tuple):
+    axis, ndev = xw13.uop.axis, len(xw13.device)
+    assert axis in (0, 1), f"unsupported sharding axis={axis}"
+    fp8_out = Tensor(Tensor.invalid(*_shard_shape((MBS, SEQ, HIDDEN), axis, ndev), dtype=fp8_dtype, device=xw13.device).uop.multi(axis), device=xw13.device)
+    amax_buf = Tensor(Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device).uop.multi(0), device=xw13.device)
+    dname = xw13.device[0].split(":")[0]
+  else:
+    fp8_out = Tensor.invalid(MBS, SEQ, HIDDEN, dtype=fp8_dtype, device=xw13.device)
+    amax_buf = Tensor.invalid(NUM_WG, dtype=dtypes.bfloat16, device=xw13.device)
+    dname = xw13.device.split(":")[0] if isinstance(xw13.device, str) else xw13.device
+  fxn = functools.partial(_custom_fused_cast_amax_w13, dname=dname)
+  fp8_out, amax_buf, *_ = Tensor.custom_kernel(fp8_out, amax_buf, xw13, amax_state, fxn=fxn, grad_fxn=_fused_quantize_bwd_w13)
+  # per-device scalar amax (no cross-device allreduce, matches _local_abs_max semantics)
+  if isinstance(amax_buf.device, tuple):
+    from examples.mlperf.models.flat_llama import _local_abs_max
+    new_amax = _local_abs_max(amax_buf).detach()
+  else: new_amax = amax_buf.max().detach()
+  inv_scale = (FP8_MAX / (amax_state + 1e-8)).float().reciprocal()
+  return fp8_out, inv_scale, new_amax
diff --git a/extra/amax/cast_amax_bwd_w13.cpp b/extra/amax/cast_amax_bwd_w13.cpp
new file mode 100644
index 0000000000..cffddf9f5b
--- /dev/null
+++ b/extra/amax/cast_amax_bwd_w13.cpp
@@ -0,0 +1,68 @@
+#include <hip/hip_runtime.h>
+#include <hip/hip_bf16.h>
+
+#ifndef N_ELEMS
+#define N_ELEMS 234881024
+#endif
+#ifndef HIDDEN
+#define HIDDEN 14336
+#endif
+#ifndef NUM_WG
+#define NUM_WG 1024
+#endif
+#ifndef THREADS_PER_WG
+#define THREADS_PER_WG 256
+#endif
+
+constexpr int VEC = 8;
+constexpr float FP8_MAX = 448.0f;
+
+static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
+static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC");
+
+extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
+fused_silu_mul_bwd_w13(
+    __hip_bfloat16* __restrict__ grad_xw13_out,     // bf16, 2*N_ELEMS (interleaved layout)
+    const __hip_bfloat16* __restrict__ xw13,        // bf16, 2*N_ELEMS (interleaved)
+    const __hip_bfloat16* __restrict__ grad_x2,     // bf16, N_ELEMS
+    const __hip_bfloat16* __restrict__ amax_state)  // bf16 scalar
+{
+  const int tid = threadIdx.x;
+  const int wg = blockIdx.x;
+  const int gid = wg * THREADS_PER_WG + tid;
+  const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
+
+  const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
+
+  for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
+    const int outer = base / HIDDEN;
+    const int inner = base % HIDDEN;
+    const int xw1_off = outer * 2 * HIDDEN + inner;
+    const int xw3_off = xw1_off + HIDDEN;
+
+    float4 x1_raw = *reinterpret_cast<const float4*>(&xw13[xw1_off]);
+    float4 x3_raw = *reinterpret_cast<const float4*>(&xw13[xw3_off]);
+    float4 g_raw  = *reinterpret_cast<const float4*>(&grad_x2[base]);
+
+    const __hip_bfloat16 *x1 = reinterpret_cast<const __hip_bfloat16*>(&x1_raw);
+    const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);
+    const __hip_bfloat16 *gv = reinterpret_cast<const __hip_bfloat16*>(&g_raw);
+
+    __hip_bfloat16 out1[VEC], out3[VEC];
+    #pragma unroll
+    for (int i = 0; i < VEC; i++) {
+      const float f1 = static_cast<float>(x1[i]);
+      const float f3 = static_cast<float>(x3[i]);
+      const float fg = static_cast<float>(gv[i]);
+      const float sig = 1.0f / (1.0f + __expf(-f1));
+      const float silu = f1 * sig;
+      const float silu_prime = sig + silu * (1.0f - sig);
+      const float gs = fg * scale;
+      out1[i] = static_cast<__hip_bfloat16>(gs * silu_prime * f3);
+      out3[i] = static_cast<__hip_bfloat16>(gs * silu);
+    }
+
+    *reinterpret_cast<float4*>(&grad_xw13_out[xw1_off]) = *reinterpret_cast<float4*>(out1);
+    *reinterpret_cast<float4*>(&grad_xw13_out[xw3_off]) = *reinterpret_cast<float4*>(out3);
+  }
+}
diff --git a/extra/amax/cast_amax_fwd_w13.cpp b/extra/amax/cast_amax_fwd_w13.cpp
new file mode 100644
index 0000000000..0d25157cfc
--- /dev/null
+++ b/extra/amax/cast_amax_fwd_w13.cpp
@@ -0,0 +1,79 @@
+#include <hip/hip_runtime.h>
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp8.h>
+
+#ifndef N_ELEMS
+#define N_ELEMS 234881024
+#endif
+#ifndef HIDDEN
+#define HIDDEN 14336
+#endif
+#ifndef NUM_WG
+#define NUM_WG 1024
+#endif
+#ifndef THREADS_PER_WG
+#define THREADS_PER_WG 256
+#endif
+
+constexpr int VEC = 8;
+constexpr float FP8_MAX = 448.0f;
+
+static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
+static_assert(HIDDEN % VEC == 0, "HIDDEN must be divisible by VEC (so VEC loads don't straddle block boundary)");
+
+extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
+fused_silu_mul_cast_amax_w13(
+    __hip_fp8_storage_t* __restrict__ fp8_out,      // fp8, N_ELEMS
+    __hip_bfloat16* __restrict__ amax_buf,          // bf16, NUM_WG (per-WG amaxes)
+    const __hip_bfloat16* __restrict__ xw13,        // bf16, 2*N_ELEMS
+    const __hip_bfloat16* __restrict__ amax_state)  // bf16 scalar
+{
+  __shared__ float sdata[THREADS_PER_WG];
+
+  const int tid = threadIdx.x;
+  const int wg = blockIdx.x;
+  const int gid = wg * THREADS_PER_WG + tid;
+  const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
+
+  const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
+  float local_max = 0.0f;
+
+  // grid-stride over 8-element groups
+  for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
+    // interleaved xw13 layout: xw1 and xw3 are not contiguous halves
+    const int outer = base / HIDDEN;
+    const int inner = base % HIDDEN;
+    const int xw1_off = outer * 2 * HIDDEN + inner;
+    const int xw3_off = xw1_off + HIDDEN;
+
+    float4 x1_raw = *reinterpret_cast<const float4*>(&xw13[xw1_off]);
+    float4 x3_raw = *reinterpret_cast<const float4*>(&xw13[xw3_off]);
+
+    const __hip_bfloat16 *x1 = reinterpret_cast<const __hip_bfloat16*>(&x1_raw);
+    const __hip_bfloat16 *x3 = reinterpret_cast<const __hip_bfloat16*>(&x3_raw);
+
+    __hip_fp8_storage_t out[VEC];
+    #pragma unroll
+    for (int i = 0; i < VEC; i++) {
+      const float f1 = static_cast<float>(x1[i]);
+      const float f3 = static_cast<float>(x3[i]);
+      const float silu = f1 / (1.0f + __expf(-f1));
+      const float x2 = silu * f3;
+      local_max = fmaxf(local_max, fabsf(x2));
+      const float x_scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, x2 * scale));
+      out[i] = __hip_cvt_float_to_fp8(x_scaled, __HIP_SATFINITE, __HIP_E4M3);
+    }
+
+    // 8 fp8 bytes per group -> one 8-byte vectorized store
+    *reinterpret_cast<float2*>(&fp8_out[base]) = *reinterpret_cast<float2*>(out);
+  }
+
+  // LDS tree reduction: per-workgroup amax
+  sdata[tid] = local_max;
+  __syncthreads();
+  for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) {
+    if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
+    __syncthreads();
+  }
+
+  if (tid == 0) amax_buf[wg] = static_cast<__hip_bfloat16>(sdata[0]);
+}
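A minimal usage sketch of the fused path next to the unfused math it replaces. This is illustrative only, not part of the patch: it assumes an AMD gfx950 device with the HIP toolchain available (the custom kernel is compiled with HIPCCCompiler("gfx950")), and the MBS/SEQ/HIDDEN shapes and the amax seed of 1.0 are made up for the example.

  # Hypothetical check: the fused kernel computes silu(xw1)*xw3, the fp8 cast, and
  # the per-tensor amax in one pass; the reference below redoes the same forward
  # math with plain Tensor ops. Shapes and the amax seed are illustrative only.
  from tinygrad import Tensor, dtypes
  from examples.mlperf.models.flat_llama import FP8_DTYPE
  from extra.amax.cast_amax import fused_quantize_fp8_w13, FP8_MAX

  MBS, SEQ, HIDDEN = 2, 512, 14336
  xw13 = Tensor.randn(MBS, SEQ, 2*HIDDEN, dtype=dtypes.bfloat16)  # w1/w3 packed on the last axis
  amax_state = Tensor.full((), 1.0, dtype=dtypes.bfloat16)        # running amax from the previous step

  # fused path
  x2_fp8, inv_scale, new_amax = fused_quantize_fp8_w13(xw13, amax_state, FP8_DTYPE)

  # unfused reference of the same forward math
  xw1, xw3 = xw13[..., :HIDDEN], xw13[..., HIDDEN:]
  x2 = xw1.silu() * xw3
  scale = FP8_MAX / (amax_state.float() + 1e-8)
  x2_fp8_ref = (x2.float() * scale).clamp(-FP8_MAX, FP8_MAX).cast(FP8_DTYPE)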