From f081f154aec2a3422bd0ed5c39ecd31e287cc9ff Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 17 Feb 2026 11:35:18 +0800
Subject: [PATCH] parameterize the CDNA asm gemm (#14813)

* parameterize the CDNA asm gemm

* fix llama test

* fix

* add more gemm tests

* confirm all match

* test these asm gemms
---
 extra/gemm/asm/cdna/asm.py                  | 47 ++++++++++--------
 extra/gemm/asm/cdna/gemm.py                 |  8 +--
 test/external/external_test_llama3_layer.py |  9 ++--
 test/testextra/test_asm_gemm.py             | 54 +++++++++++++++++++--
 4 files changed, 87 insertions(+), 31 deletions(-)

diff --git a/extra/gemm/asm/cdna/asm.py b/extra/gemm/asm/cdna/asm.py
index 118e753fcb..c23084bbcf 100644
--- a/extra/gemm/asm/cdna/asm.py
+++ b/extra/gemm/asm/cdna/asm.py
@@ -4,24 +4,32 @@ from tinygrad.dtype import dtypes
 # M0 is encoded with 124 (NULL in RDNA) in CDNA
 M0 = NULL
 
-# (M, N, K) -> (numWG, iters, total)
-GEMM_ARGS = {
-  (8192, 4096, 4096): (256, 64, 32768),
-  (8192, 14336, 4096): (256, 64, 114688),
-  (8192, 4096, 14336): (256, 224, 114688),
-  # TODO: get a fast gemm for this shape
-  #(8192, 128256, 4096): (16032, 64, 1026048),
-  (8192, 8192, 8192): (256, 128, 131072),
-  (4096, 4096, 4096): (256, 64, 16384),
-  (4096, 14336, 4096): (256, 64, 57344),
-  (4096, 14336, 8192): (256, 128, 114688),
-  (4096, 4096, 14336): (256, 224, 57344),
-  (14336, 4096, 8192): (256, 128, 114688),
-  (4096, 8192, 14336): (256, 224, 114688),
-  (4096, 4096, 8192): (256, 128, 32768),
-  (4096, 8192, 4096): (256, 64, 32768),
-}
-ITERS_ARGS = {64: (67108864, 0), 128: (33554432, 0), 224: (613566757, 2147483656)}
+TILE_M, TILE_N, TILE_K, NUM_WG = 256, 256, 64, 256
+
+def _magicgu_mulhi(d:int, vmax:int) -> tuple[int,int]:
+  """Compute magic number and shift for mul_hi-based unsigned division by d, valid for all n up to vmax.
+  Adapted from magicgu in tinygrad.uop.decompositions (Hacker's Delight, Chapter 10) but targeting the mul_hi encoding:
+  - If shift bit 31 is clear: result = mul_hi(n, magic) >> shift
+  - If shift bit 31 is set: result = (mul_hi(n, magic) + n) >> (shift & 0x7FFFFFFF) (wrapping 32-bit add)
+  """
+  if d == 1: return 0, (1 << 31)  # (mul_hi(n, 0) + n) >> 0 = n
+  nc = (1 << 32) // d * d - 1
+  for s in range(32, 65):
+    if 2**s > nc * (d - 1 - (2**s - 1) % d):
+      m = (2**s + d - 1 - (2**s - 1) % d) // d
+      shift = s - 32
+      if m < (1 << 32): return m, shift
+      if m < (1 << 33):
+        m_enc = m - (1 << 32)
+        if ((((vmax * m_enc) >> 32) + vmax) & 0xFFFFFFFF) >> shift == vmax // d: return m_enc, shift | (1 << 31)
+  raise AssertionError(f"cannot compute magic for d={d}, vmax={vmax}")
+
+def compute_gemm_args(M:int, N:int, K:int, batch:int) -> tuple[int, int, int, int, int]:
+  assert M % TILE_M == 0 and N % TILE_N == 0 and K % TILE_K == 0, f"shape ({M},{N},{K}) not a multiple of ({TILE_M},{TILE_N},{TILE_K})"
+  iters = K // TILE_K
+  total = (M // TILE_M) * (N // TILE_N) * iters
+  magic, shift = _magicgu_mulhi(iters, total * batch)
+  return NUM_WG, iters, total, magic, shift
 
 class Kernel:
   def __init__(self, name="gemm"): self.name, self.instructions, self.labels, self.label_at_pos, self.pos = name, [], {}, {}, 0
@@ -79,9 +87,8 @@ class Kernel:
     return "\n".join(lines)
 
 def build_kernel(batch, M, N, K, dtype):
-  numWG, iters, total = GEMM_ARGS[(M, N, K)]
+  numWG, iters, total, magic, shift = compute_gemm_args(M, N, K, batch)
   total *= batch
-  magic, shift = ITERS_ARGS[iters]
   v_mfma_16x16x32 = {dtypes.half:v_mfma_f32_16x16x32_f16, dtypes.bfloat16:v_mfma_f32_16x16x32_bf16}[dtype]
   v_cvt_pk = {dtypes.half:v_cvt_pk_f16_f32, dtypes.bfloat16:v_cvt_pk_bf16_f32}[dtype]
   v_cvt = {dtypes.half:v_cvt_f32_f16_e32, dtypes.bfloat16:v_cvt_f32_bf16_e32}[dtype]
diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py
index 787349a085..2d249e0322 100644
--- a/extra/gemm/asm/cdna/gemm.py
+++ b/extra/gemm/asm/cdna/gemm.py
@@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes
 from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
 from tinygrad.renderer import Estimates
 from tinygrad.helpers import getenv, all_same, dedup
-from extra.gemm.asm.cdna.asm import build_kernel, GEMM_ARGS
+from extra.gemm.asm.cdna.asm import build_kernel, TILE_M, TILE_N, TILE_K, NUM_WG
 
 # ** CDNA4 assembly gemm
 
@@ -43,7 +43,8 @@ def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
   else: dname = a.device
   arch = getattr(Device[dname].renderer, "arch", "")
   if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
-  if (key:=(M, N, K)) not in GEMM_ARGS and arch == "gfx950": return todo(f"GEMM shape not supported {key} on {arch}")
+  if M % TILE_M != 0 or N % TILE_N != 0 or K % TILE_K != 0:
+    return todo(f"GEMM shape ({M},{N},{K}) not a multiple of ({TILE_M},{TILE_N},{TILE_K})")
   return True
 
 # ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4
@@ -94,8 +95,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
   renderer = Device[a.device[0] if is_multi else a.device].renderer
   dname, arch = renderer.device, getattr(renderer, "arch", "")
   if arch.startswith("gfx950") and getenv("USE_ASM", 1):
-    numWG = GEMM_ARGS[(M, N, K)][0]
-    out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=numWG, arch=arch), grad_fxn=custom_gemm_bw)[0]
+    out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=NUM_WG, arch=arch), grad_fxn=custom_gemm_bw)[0]
   else: out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
   if k_sharded: out = out.sum(0)
 
diff --git a/test/external/external_test_llama3_layer.py b/test/external/external_test_llama3_layer.py
index 0ec028620e..eb8508ccd0 100644
--- a/test/external/external_test_llama3_layer.py
+++ b/test/external/external_test_llama3_layer.py
@@ -1,17 +1,18 @@
 #!/usr/bin/env python3
-from tinygrad import Tensor, TinyJit, nn
+from tinygrad import Tensor, TinyJit, nn, dtypes
 from tinygrad.helpers import getenv
 from extra.models.llama import TransformerBlock, precompute_freqs_cis
 
 BS = getenv("BS", 1)
 SEQLEN = getenv("SEQLEN", 128)
 
-# SEQLEN=8192 ASM_GEMM=1 HK_FLASH_ATTENTION=1 EMULATE=AMD_CDNA4 NULL=1 DEBUG=2 VIZ=1 PYTHONPATH="." python test/external/external_test_llama3_layer.py
+# DEFAULT_FLOAT=bfloat16 SEQLEN=8192 ASM_GEMM=1 HK_FLASH_ATTENTION=1 EMULATE=AMD_CDNA4 NULL=1 DEBUG=2 VIZ=1 PYTHONPATH="."
+# python test/external/external_test_llama3_layer.py
 
 if __name__ == "__main__":
   dim, hidden_dim, n_heads, n_kv_heads, norm_eps = 4096, 14336, 32, 8, 1e-5
   layer = TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context=0)
-  for x in nn.state.get_parameters(layer): x.replace(x.half()).realize()
+  for x in nn.state.get_parameters(layer): x.replace(x.cast(dtypes.default_float)).realize()
 
   freqs_cis = precompute_freqs_cis(dim // n_heads, SEQLEN, theta=500000.0).contiguous().requires_grad_(False).realize()
 
@@ -20,4 +21,4 @@ if __name__ == "__main__":
 
   for i in range(5):
     print(f"*** run {i}")
-    run(Tensor.rand(BS, SEQLEN, dim).half().realize())
+    run(Tensor.rand(BS, SEQLEN, dim, dtype=dtypes.default_float).realize())
diff --git a/test/testextra/test_asm_gemm.py b/test/testextra/test_asm_gemm.py
index 6a76f1c84d..bb53bf6382 100644
--- a/test/testextra/test_asm_gemm.py
+++ b/test/testextra/test_asm_gemm.py
@@ -86,9 +86,57 @@ class TestGemmLarge(unittest.TestCase):
   def test_k_sharded_1(self): verify_asm_gemm_k_sharded(14336, 4096, 8*8192, gpus=8)
   def test_k_sharded_2(self): verify_asm_gemm_k_sharded(4096, 14336, 8*8192, gpus=8)
   def test_k_sharded_3(self): verify_asm_gemm_k_sharded(4096, 4096, 8*8192, gpus=8)
-  def test_gemm_unsupported(self):
-    with self.assertRaisesRegex(AssertionError, "shape not supported"):
-      verify_asm_gemm(8, 1024, 1024, 4096, gpus=8)
+  def test_unsupported_k(self):
+    with self.assertRaisesRegex(AssertionError, "not a multiple"):
+      verify_asm_gemm(1, 1024, 1024, 100)
+  def test_unsupported_m(self):
+    with self.assertRaisesRegex(AssertionError, "not a multiple"):
+      verify_asm_gemm(1, 1000, 256, 256)
+  def test_unsupported_n(self):
+    with self.assertRaisesRegex(AssertionError, "not a multiple"):
+      verify_asm_gemm(1, 256, 1000, 256)
+  def test_unsupported_batch(self):
+    with self.assertRaisesRegex(AssertionError, "batch size"):
+      verify_asm_gemm(3, 256, 256, 256)
+  def test_gemm_previously_unsupported(self): verify_asm_gemm(8, 1024, 1024, 4096, gpus=8)
+
+  # more shapes: vary M, N, K independently
+  def test_shape_small_square(self): verify_asm_gemm(1, 256, 256, 256)
+  def test_shape_small_rect_m(self): verify_asm_gemm(1, 512, 256, 256)
+  def test_shape_small_rect_n(self): verify_asm_gemm(1, 256, 512, 256)
+  def test_shape_small_rect_k(self): verify_asm_gemm(1, 256, 256, 512)
+  def test_shape_tall(self): verify_asm_gemm(1, 2048, 256, 256)
+  def test_shape_wide(self): verify_asm_gemm(1, 256, 2048, 256)
+  def test_shape_deep(self): verify_asm_gemm(1, 256, 256, 4096)
+  def test_shape_non_square(self): verify_asm_gemm(1, 1024, 2048, 512)
+  def test_shape_batched_small(self): verify_asm_gemm(2, 256, 256, 256)
+  def test_shape_batched_rect(self): verify_asm_gemm(2, 512, 1024, 256)
+  # K edge cases: iters=1,2,3 exercise different loop paths
+  def test_shape_k64(self): verify_asm_gemm(1, 256, 256, 64)
+  def test_shape_k128(self): verify_asm_gemm(1, 256, 256, 128)
+  def test_shape_k192(self): verify_asm_gemm(1, 256, 256, 192)
+
+  def test_llama3_out1(self): verify_asm_gemm(1, 8192, 128256, 4096)
+  def test_llama3_out2(self): verify_asm_gemm(1, 8192, 4096, 128256)
+  def test_llama3_out3(self): verify_asm_gemm(1, 4096, 128256, 8192)
+
+class TestMagicGu(unittest.TestCase):
+  def test_magicgu_matches_old(self):
+    from extra.gemm.asm.cdna.asm import _magicgu_mulhi, TILE_M, TILE_N, TILE_K
+    old_iters_args = {64: (67108864, 0), 128: (33554432, 0), 224: (613566757, 2147483656)}
+    old_gemm_shapes = [
+      (8192, 4096, 4096), (8192, 14336, 4096), (8192, 4096, 14336),
+      (8192, 8192, 8192), (4096, 4096, 4096), (4096, 14336, 4096),
+      (4096, 14336, 8192), (4096, 4096, 14336), (14336, 4096, 8192),
+      (4096, 8192, 14336), (4096, 4096, 8192), (4096, 8192, 4096),
+    ]
+    for M, N, K in old_gemm_shapes:
+      iters = K // TILE_K
+      total = (M // TILE_M) * (N // TILE_N) * iters
+      for batch in [1, 2]:
+        magic, shift = _magicgu_mulhi(iters, total * batch)
+        old_magic, old_shift = old_iters_args[iters]
+        self.assertEqual((magic, shift), (old_magic, old_shift), f"mismatch for ({M},{N},{K}) batch={batch} iters={iters}")
 
 if __name__ == "__main__":
   unittest.main()
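
A minimal sketch (not part of the patch) of the mul_hi division encoding that _magicgu_mulhi produces, brute-force checked against plain integer division. mul_hi models the upper 32 bits of an unsigned 32x32 multiply (v_mul_hi_u32 on CDNA); mulhi_div and MASK32 are hypothetical names introduced only for this example, and the import assumes the patch is applied with the repo root on PYTHONPATH.

# sketch only: evaluate the two encodings documented in _magicgu_mulhi's docstring
MASK32 = 0xFFFFFFFF

def mul_hi(a:int, b:int) -> int:
  # upper 32 bits of the unsigned 32x32 product, like v_mul_hi_u32
  return ((a & MASK32) * (b & MASK32)) >> 32

def mulhi_div(n:int, magic:int, shift:int) -> int:
  # bit 31 of shift selects the add encoding, where magic is stored minus 2**32
  if shift & (1 << 31): return ((mul_hi(n, magic) + n) & MASK32) >> (shift & 0x7FFFFFFF)
  return mul_hi(n, magic) >> shift

if __name__ == "__main__":
  from extra.gemm.asm.cdna.asm import _magicgu_mulhi
  # 64 and 224 were keys in the removed ITERS_ARGS table; 3 exercises an odd divisor
  for d, vmax in [(64, 262144), (224, 229376), (3, 100000)]:
    magic, shift = _magicgu_mulhi(d, vmax)
    assert all(mulhi_div(n, magic, shift) == n // d for n in range(vmax + 1)), (d, vmax)
    print(f"d={d}: magic={magic} shift={shift:#010x} ok for n <= {vmax}")

Packing the add-encoding flag into bit 31 of shift lets build_kernel select between the plain and add variants from a single scalar argument, so the kernel never does a runtime division by iters.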