From c35de9bd68c3062fa022252c2ddde5612e4b5f8a Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Tue, 3 Mar 2026 15:16:37 +0800 Subject: [PATCH] asm_gemm: support more sharding (#15002) --- extra/gemm/asm/cdna/gemm.py | 33 +++++++++++++++++++++++++++------ test/backend/test_asm_gemm.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py index 6b7bfc7e8d..b46e93e4a5 100644 --- a/extra/gemm/asm/cdna/gemm.py +++ b/extra/gemm/asm/cdna/gemm.py @@ -3,7 +3,7 @@ from tinygrad import Tensor, Device, dtypes from tinygrad.dtype import AddrSpace from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType from tinygrad.renderer import Estimates -from tinygrad.helpers import getenv, all_same, dedup +from tinygrad.helpers import getenv, all_same, DEBUG from extra.gemm.asm.cdna.asm import build_kernel, TILE_M, TILE_N, TILE_K, NUM_WG # ** CDNA4 assembly gemm @@ -26,17 +26,25 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp: counters = {"used":0, "todos":[]} def todo(msg:str) -> bool: counters["todos"].append(msg); return False -atexit.register(lambda: print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used')) +def _asm_gemm_report(): + print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used') + if DEBUG >= 2 and counters["todos"]: + from collections import Counter + for msg, cnt in Counter(counters["todos"]).most_common(): print(f' {cnt:3d}x {msg}') +atexit.register(_asm_gemm_report) def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool: if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}") if a.dtype not in {dtypes.bfloat16, dtypes.float16}: return todo(f"only bfloat16/float16, got {a.dtype}") batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape N = b.shape[1] - # only sharding on the batch or K is tested, others might work too if isinstance(a.device, tuple): - if a.ndim == 2 and a.uop.axis == 1 and b.uop.axis == 0: K //= len(a.device) + if a.ndim == 2 and a.uop.axis == 0 and b.uop.axis is None: M //= len(a.device) + elif a.ndim == 2 and a.uop.axis == 1 and b.uop.axis == 0: K //= len(a.device) + elif a.ndim == 2 and a.uop.axis is None and b.uop.axis == 1: N //= len(a.device) elif a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None: batch //= len(a.device) + elif a.ndim == 3 and a.uop.axis is None and b.uop.axis == 1: N //= len(a.device) + elif a.ndim == 3 and a.uop.axis == 2 and b.uop.axis == 0: K //= len(a.device) else: return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}") dname = a.device[0] else: dname = a.device @@ -78,6 +86,10 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp): def asm_gemm(a:Tensor, b:Tensor) -> Tensor: assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}" counters["used"] += 1 + unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0 + if unfold_batch: + orig_batch = a.shape[0] + a = a.reshape(a.shape[0]*a.shape[1], a.shape[2]) squeeze = a.ndim == 2 if squeeze: a = a.unsqueeze(0) @@ -85,9 +97,16 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: N = b.shape[1] is_multi = isinstance(a.device, tuple) if (k_sharded:=is_multi and a.uop.axis == 2): K //= len(a.device) + if (m_sharded:=is_multi and a.uop.axis == 1): M //= len(a.device) + n_sharded = is_multi and b.uop.axis == 1 if is_multi: - out = Tensor(Tensor.empty(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device) + if n_sharded: + out = Tensor(Tensor.empty(batch, M, N//len(a.device), dtype=a.dtype, device=a.device).uop.multi(2), device=a.device) + elif m_sharded: + out = Tensor(Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device).uop.multi(1), device=a.device) + else: + out = Tensor(Tensor.empty(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device) else: out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device) @@ -98,4 +117,6 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: else: out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0] if k_sharded: out = out.sum(0) - return out.squeeze(0) if squeeze else out + out = out.squeeze(0) if squeeze else out + if unfold_batch: out = out.reshape(orig_batch, -1, out.shape[-1]) + return out diff --git a/test/backend/test_asm_gemm.py b/test/backend/test_asm_gemm.py index bd97f6d0da..d3179d87f8 100644 --- a/test/backend/test_asm_gemm.py +++ b/test/backend/test_asm_gemm.py @@ -47,6 +47,18 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i def verify_asm_gemm_k_sharded(M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=8) -> None: run_asm_gemm((M, K), (K, N), dtype=dtype, a_shard=1, b_shard=0, gpus=gpus) +def verify_asm_gemm_n_sharded(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None: + run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=None, b_shard=1, gpus=gpus) + +def verify_asm_gemm_m_sharded(M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None: + run_asm_gemm((M, K), (K, N), dtype=dtype, a_shard=0, b_shard=None, gpus=gpus) + +def verify_asm_gemm_n_sharded_2d(M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None: + run_asm_gemm((M, K), (K, N), dtype=dtype, a_shard=None, b_shard=1, gpus=gpus) + +def verify_asm_gemm_k_sharded_3d(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None: + run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=2, b_shard=0, gpus=gpus) + # 128x smaller than usual # uses the UOp GEMM, runs on non CDNA4 and CI @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") @@ -60,6 +72,14 @@ class TestGemm(unittest.TestCase): def test_gemm_multi(self): verify_asm_gemm(2, 64, 32, 32, gpus=2) @needs_second_gpu def test_gemm_k_sharded(self): verify_asm_gemm_k_sharded(64, 64, 2*64, gpus=2) + @needs_second_gpu + def test_gemm_m_sharded(self): verify_asm_gemm_m_sharded(2*64, 64, 32, gpus=2) + @needs_second_gpu + def test_gemm_n_sharded(self): verify_asm_gemm_n_sharded(1, 64, 64, 32, gpus=2) + @needs_second_gpu + def test_gemm_n_sharded_2d(self): verify_asm_gemm_n_sharded_2d(64, 2*64, 32, gpus=2) + @needs_second_gpu + def test_gemm_k_sharded_3d(self): verify_asm_gemm_k_sharded_3d(1, 64, 32, 2*64, gpus=2) # uses the Asm GEMM on CDNA4 only for speed reasons class TestGemmLarge(unittest.TestCase): @@ -101,6 +121,20 @@ class TestGemmLarge(unittest.TestCase): verify_asm_gemm(3, 256, 256, 256) def test_gemm_previously_unsupported(self): verify_asm_gemm(8, 1024, 1024, 4096, gpus=8) + # M-sharded 2D + def test_m_sharded_1(self): verify_asm_gemm_m_sharded(8*8192, 4096, 4096, dtype=dtypes.bfloat16, gpus=8) + def test_m_sharded_2(self): verify_asm_gemm_m_sharded(8*4096, 14336, 4096, dtype=dtypes.bfloat16, gpus=8) + + # N-sharded 2D + def test_n_sharded_2d_1(self): verify_asm_gemm_n_sharded_2d(8192, 8*4096, 4096, dtype=dtypes.bfloat16, gpus=8) + def test_n_sharded_2d_2(self): verify_asm_gemm_n_sharded_2d(4096, 8*14336, 4096, dtype=dtypes.bfloat16, gpus=8) + + # tensor parallel shapes (Llama 8B, MP=8) + def test_tp_n_sharded_wq(self): verify_asm_gemm_n_sharded(1, 8192, 4096, 4096, dtype=dtypes.bfloat16, gpus=8) + def test_tp_n_sharded_w1(self): verify_asm_gemm_n_sharded(1, 8192, 14336, 4096, dtype=dtypes.bfloat16, gpus=8) + def test_tp_k_sharded_wo(self): verify_asm_gemm_k_sharded_3d(1, 8192, 4096, 4096, dtype=dtypes.bfloat16, gpus=8) + def test_tp_k_sharded_w2(self): verify_asm_gemm_k_sharded_3d(1, 8192, 4096, 14336, dtype=dtypes.bfloat16, gpus=8) + # more shapes: vary M, N, K independently def test_shape_small_square(self): verify_asm_gemm(1, 256, 256, 256) def test_shape_small_rect_m(self): verify_asm_gemm(1, 512, 256, 256)