CDNA assembly gemm in tensor.py with flag (#14310)

* work

* work

* the assembly

* remove the old one

* remove ws bufs, assert splitk

* notes cleanup

* work

* gemm args

* gemm in mixins would be nice

* add gemm gradient

* print counters

* the realize is for DEBUG=2 aesthetics

* dedup

* rewrite to python dsl, no list copies

* leave that

* add B, M, N, K to gemm name

* it's M0 not NULL

* fp16 support

* test cleanup + more gemms

* work from viz

* more work

* gemm batch_size

* xccg path work

* tiny comments on the label naming

* s_waitcnt
This commit is contained in:
qazal
2026-01-31 08:34:14 -05:00
committed by GitHub
parent 55f806b713
commit 616e9c1483
11 changed files with 11667 additions and 1775 deletions

View File

@@ -10,6 +10,7 @@ export DEBUG=${DEBUG:-0}
export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}

11517
extra/gemm/asm/cdna/asm.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,95 @@
import atexit, functools
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.renderer import Estimates
from tinygrad.helpers import getenv, all_same, dedup
from extra.gemm.asm.cdna.asm import build_kernel, GEMM_ARGS
# ** CDNA4 assembly gemm
WORKGROUP_SIZE = 256
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp:
  """Wrap the hand-written CDNA assembly gemm into a PROGRAM UOp for device dname."""
  batch, M, K = A.shape
  # B may carry a leading batch dim; only the trailing (K, N) are used here
  Kb, N = B.shape[1:] if B.ndim == 3 else B.shape
  assert K == Kb
  kern = build_kernel(batch, M, N, K, A.dtype.base)
  # 2 flops per MAC; inputs/outputs are 2-byte dtypes (bf16/fp16)
  info = KernelInfo(name=kern.name, estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2))
  sink = UOp.sink(C.base, A.base, B.base, UOp.special(WORKGROUP_SIZE, "lidx0"), UOp.special(wg, "gidx0"), arg=info)
  binary = HIPCompiler(arch).compile(kern.to_asm())
  srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
          UOp(Ops.SOURCE, arg=kern.to_text()), UOp(Ops.BINARY, arg=binary))
  return UOp(Ops.PROGRAM, src=srcs)
# global tally: how often the asm gemm was actually used, and the reasons it was skipped
counters = {"used": 0, "todos": []}

def todo(msg:str) -> bool:
  """Record why the asm gemm can't be used; always returns False for use in guards."""
  counters["todos"].append(msg)
  return False

def _report() -> None:
  # printed once at interpreter exit so a run summarizes its asm gemm coverage
  print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used')
atexit.register(_report)
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
  """Gate for the assembly gemm: each guard records its rejection reason via todo() and returns False."""
  if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
  if a.dtype not in {dtypes.bfloat16, dtypes.float16}: return todo(f"only bfloat16/float16, got {a.dtype}")
  # only sharding on the batch is tested, others might work too
  if isinstance(a.device, tuple):
    batch_sharded = a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None
    if not batch_sharded:
      return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
  batch, M, K = a.shape if a.ndim == 3 else (1, *a.shape)
  N = b.shape[1]
  # when sharded, each device only sees its slice of the batch
  if isinstance(a.device, tuple): batch //= len(a.device)
  if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
  if (key:=(M, N, K)) not in GEMM_ARGS: return todo(f"GEMM shape not supported {key}")
  return True
# ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4
# note: this can be removed after we have GEMM on mixins
def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
  """Plain UOp-DSL gemm (batch folded into M) for backends without the assembly path."""
  M = A.shape[0]*A.shape[1]
  K = A.shape[2]
  # B may carry a leading batch dim; only the trailing (K, N) are used
  Kb, N = B.shape[1:] if B.ndim == 3 else B.shape
  assert K == Kb
  # rm/rn iterate the output, rk is the reduction axis
  rm, rn = UOp.range(M, 1, AxisType.LOOP), UOp.range(N, 2, AxisType.LOOP)
  rk = UOp.range(K, 0, AxisType.REDUCE)
  cK, cN = UOp.const(dtypes.index, K), UOp.const(dtypes.index, N)
  # accumulate the products in fp32, cast back to the output dtype afterwards
  prod = (A.index((rm*cK+rk))*B.index((rk*cN+rn))).cast(dtypes.float32)
  acc = prod.reduce(rk, arg=Ops.ADD, dtype=dtypes.float32).cast(C.dtype.base)
  out = C.index((rm*cN+rn), ptr=True).store(acc).end(rm, rn)
  return out.sink(arg=KernelInfo(name=f'uop_gemm_{M}_{N}_{K}'))
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp):
  """Gradients for the custom gemm kernels; returns (None for out, grad_a, grad_b)."""
  out, a, b = kernel.src
  assert all_same([gradient.device, a.device, b.device, out.device])
  at, bt, gt = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
  # grad_a = g @ B^T
  grad_a = (gt @ bt.T).uop
  # grad_b = sum over batch of A^T @ g, expressed as a broadcasted mul + reduce
  at_T = at.transpose(-2, -1)
  at_T = at_T.reshape(*at_T.shape[:-1], 1, at_T.shape[-1])
  g_bc = gt.reshape(*gt.shape[:-2], 1, *gt.shape[-2:]).transpose(-1, -2)
  grad_b = (at_T * g_bc).sum((-1, 0)).uop
  return (None, grad_a, grad_b)
# ** main gemm function
def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
  """Dispatch a@b to the CDNA assembly gemm (gfx950) or the UOp fallback gemm.

  Callers should check can_use_asm_gemm first; the assert re-checks and surfaces
  the most recent rejection reason on failure.
  """
  assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
  counters["used"] += 1
  # the kernels are batched: promote a 2D matmul to batch=1 and squeeze on the way out
  squeeze = a.ndim == 2
  if squeeze: a = a.unsqueeze(0)
  batch, M, K = a.shape
  N = b.shape[1]
  is_multi = isinstance(a.device, tuple)
  if is_multi:
    # allocate the per-device slice and wrap it in a multi tensor sharded on the batch
    out = Tensor(Tensor.empty(batch//len(a.device), M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
  else:
    out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device)
  dname = a.device[0] if is_multi else a.device
  # BUGFIX: renderer may not expose an arch (getattr default None); guard before startswith
  arch = getattr(Device[dname].renderer, "arch", None)
  if arch is not None and arch.startswith("gfx950") and getenv("USE_ASM", 1):
    numWG = GEMM_ARGS[(M, N, K)][0]
    out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=numWG, arch=arch), grad_fxn=custom_gemm_bw)[0]
  else:
    out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
  return out.squeeze(0) if squeeze else out

File diff suppressed because it is too large Load Diff

View File

@@ -1,78 +0,0 @@
# Standalone HSA code-object template for the hand-written gemm kernel.
# The INSTRUCTIONS placeholder is textually replaced with the contents of gemm.s
# before this file is handed to the compiler/assembler.
.text
.section .text.
.global gemm
.p2align 8
.type gemm,@function
gemm:
INSTRUCTIONS
# kernel descriptor (.kd) consumed by the HSA loader
.section .rodata,"a",@progbits
.p2align 6, 0x0
.amdhsa_kernel gemm
# basic memory requirements
.amdhsa_group_segment_fixed_size 133120
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 28
# register usage (RSRC1)
.amdhsa_next_free_vgpr 504
.amdhsa_next_free_sgpr 96
# workgroup / workitem IDs (RSRC2)
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
# user SGPRs, we only specify the kernel args ptr in s[0:1]
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_count 2
.amdhsa_user_sgpr_kernarg_preload_length 0
.amdhsa_user_sgpr_kernarg_preload_offset 0
# gfx90a / gfx940 specifics (RSRC3)
.amdhsa_accum_offset 248
.amdhsa_uses_dynamic_stack 0
.amdhsa_tg_split 0
.end_amdhsa_kernel
# HSA code-object v3+ metadata (YAML); arg offsets/sizes here must match the
# host-side kernarg layout (C, B, A pointers at 0/8/16, u32 sz at 24)
.amdgpu_metadata
---
amdhsa.kernels:
- .name: gemm
.symbol: gemm.kd
.args:
- .name: C
.address_space: global
.offset: 0
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: B
.address_space: global
.offset: 8
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: A
.address_space: global
.offset: 16
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: sz
.offset: 24
.size: 4
.value_kind: by_value
.value_type: u32
.group_segment_fixed_size: 133120
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.kernarg_segment_size: 28
.max_flat_workgroup_size: 256
.sgpr_count: 88
.sgpr_spill_count: 0
.vgpr_count: 248
.vgpr_spill_count: 0
.wavefront_size: 64
amdhsa.version:
- 1
- 0
...
.end_amdgpu_metadata

View File

@@ -1,73 +0,0 @@
# Run assembly on the AMD runtime and check correctness
# VIZ=2 to profile
import pathlib
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.engine.realize import Estimates
from tinygrad.helpers import getenv
fp = pathlib.Path(__file__).parent/"gemm.s"
N = getenv("N", 8192)
THREADS_PER_WG = 256
NUM_WG = N//THREADS_PER_WG * N//THREADS_PER_WG
assert N % THREADS_PER_WG == 0, "N must be divisible by THREADS_PER_WG"
# ** generate inputs on CPU
scale = 10.0  # keep values small so bf16 products don't blow up the tolerance check
import torch
torch.manual_seed(0)
A = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
B = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
Bt = B.t().contiguous() # transpose B for the asm gemm
C_torch = A@B
# ** copy buffers to AMD
# input creation and validation run on the copy engine for simpler tracing
def from_torch(t:torch.Tensor) -> Tensor:
  # zero-copy wrap of the torch CPU buffer, then materialize on the default device
  return Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu").to(Device.DEFAULT).realize()
C_tiny = from_torch(A) @ from_torch(B)
C_asm = Tensor.empty_like(C_tiny)
# ** assembly custom kernel
def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
  # build a PROGRAM UOp from the raw assembly in gemm.s spliced into template.s
  lidx = UOp.special(THREADS_PER_WG, "lidx0")
  gidx = UOp.special(NUM_WG, "gidx0")
  src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
  # matrix size is a runtime variable so one binary serves multiple N
  sz = UOp.variable("SZ", 256, 8192)
  sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm", estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
# ** run gemms
sched = Tensor.schedule(C_tiny, C_asm)
eis = [si.lower() for si in sched]
with Context(DEBUG=2):
  for ei in eis:
    et = ei.run({"SZ":N}, wait=True)
    print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
# ** correctness
import ctypes
def torch_bf16(t:Tensor) -> torch.Tensor:
  # reinterpret the raw device buffer as bf16 on the torch side (no dtype conversion)
  asm_out = t.to("cpu").realize().uop.buffer._buf
  buf = (ctypes.c_uint16*C_asm.uop.size).from_address(asm_out.va_addr)
  return torch.frombuffer(buf, dtype=torch.bfloat16, count=C_asm.uop.size).reshape(C_asm.shape)
assert torch.allclose(torch_bf16(C_asm), C_torch, rtol=1e-2, atol=1e-3)
assert torch.allclose(torch_bf16(C_tiny), C_torch, rtol=1e-2, atol=1e-3)

View File

@@ -0,0 +1,46 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.helpers import getenv
from extra.gemm.asm.cdna.gemm import asm_gemm
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, multi=False) -> None:
  """Run asm_gemm forward + backward and compare against the ASM_GEMM=0 reference matmul.

  multi=True shards a on the batch axis over 8 devices and replicates b.
  Raises AssertionError on a forward or gradient mismatch (or from asm_gemm itself
  when the shape is unsupported).
  """
  Tensor.manual_seed(0)
  a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
  b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype)
  with Context(DEBUG=0):
    Tensor.realize(a_rand, b_rand)
  devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8)) if multi else None
  # tst and ref get independent leaf tensors built from identical numpy data
  a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
  if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None)
  tst = asm_gemm(a, b)
  tst.sum().backward()
  Tensor.realize(tst, a.grad, b.grad)
  a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype)
  if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None)
  # ASM_GEMM=0 forces the stock matmul path for the reference
  with Context(ASM_GEMM=0): ref = a_ref @ b_ref
  ref.sum().backward()
  Tensor.realize(ref, a_ref.grad, b_ref.grad)
  with Context(DEBUG=0):
    # forward must match tightly; gradients get a looser bound
    assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch"
    assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
    assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
class TestGemm(unittest.TestCase):
  """End-to-end checks of the assembly gemm against the reference matmul."""
  # single-device fp16 square gemm, size overridable with N=
  def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
  # multi=True shards the batch over 8 devices
  def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, multi=True)
  def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, multi=True)
  def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, multi=True)
  def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, multi=True)
  def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, multi=True)
  def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, multi=True)
  def test_gemm_unsupported(self):
    # a shape with no tuned kernel must be rejected with the recorded reason
    with self.assertRaisesRegex(AssertionError, "shape not supported"):
      verify_asm_gemm(8, 8192, 1024, 4096, multi=True)
if __name__ == "__main__":
  unittest.main()

View File

@@ -206,6 +206,8 @@ ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# allow use of assembly for gemm
ASM_GEMM = ContextVar("ASM_GEMM", 0)
@dataclass(frozen=True)
class Metadata:

View File

@@ -7,7 +7,7 @@ if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ASM_GEMM, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
@@ -2431,6 +2431,9 @@ class Tensor(OpMixin):
```
"""
if IMAGE: return self.image_dot(w, dtype)
if ASM_GEMM:
from extra.gemm.asm.cdna.gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(self, w): return asm_gemm(self, w)
x, dx, dw = self, self.ndim, w.ndim
if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D")
if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}")

View File

@@ -206,7 +206,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
match self.op:
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.CUSTOM_KERNEL | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
return None

View File

@@ -558,7 +558,7 @@ def get_render(query:str) -> dict:
rows:dict[int, dict] = {}
for pc, (inst,_) in pc_to_inst.items():
if start_pc is None: start_pc = pc
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "hits":{"cols":inst_columns, "rows":[]}, "type":""}
rows[pc] = {"pc":pc-start_pc, "inst":inst, "hit_count":0, "dur":0, "stall":0, "type":"", "hits":{"cols":inst_columns, "rows":[]}}
for e in w.unpack_insts():
if not (inst:=rows[e.pc]).get("type"): inst["type"] = str(e.typ).split("_")[-1]
inst["hit_count"] += 1