mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
asm_gemm: support more sharding (#15002)
This commit is contained in:
@@ -3,7 +3,7 @@ from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.helpers import getenv, all_same, dedup
|
||||
from tinygrad.helpers import getenv, all_same, DEBUG
|
||||
from extra.gemm.asm.cdna.asm import build_kernel, TILE_M, TILE_N, TILE_K, NUM_WG
|
||||
|
||||
# ** CDNA4 assembly gemm
|
||||
@@ -26,17 +26,25 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
||||
|
||||
counters = {"used":0, "todos":[]}
|
||||
def todo(msg:str) -> bool: counters["todos"].append(msg); return False
|
||||
atexit.register(lambda: print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used'))
|
||||
def _asm_gemm_report():
|
||||
print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used')
|
||||
if DEBUG >= 2 and counters["todos"]:
|
||||
from collections import Counter
|
||||
for msg, cnt in Counter(counters["todos"]).most_common(): print(f' {cnt:3d}x {msg}')
|
||||
atexit.register(_asm_gemm_report)
|
||||
|
||||
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
|
||||
if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
|
||||
if a.dtype not in {dtypes.bfloat16, dtypes.float16}: return todo(f"only bfloat16/float16, got {a.dtype}")
|
||||
batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
|
||||
N = b.shape[1]
|
||||
# only sharding on the batch or K is tested, others might work too
|
||||
if isinstance(a.device, tuple):
|
||||
if a.ndim == 2 and a.uop.axis == 1 and b.uop.axis == 0: K //= len(a.device)
|
||||
if a.ndim == 2 and a.uop.axis == 0 and b.uop.axis is None: M //= len(a.device)
|
||||
elif a.ndim == 2 and a.uop.axis == 1 and b.uop.axis == 0: K //= len(a.device)
|
||||
elif a.ndim == 2 and a.uop.axis is None and b.uop.axis == 1: N //= len(a.device)
|
||||
elif a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None: batch //= len(a.device)
|
||||
elif a.ndim == 3 and a.uop.axis is None and b.uop.axis == 1: N //= len(a.device)
|
||||
elif a.ndim == 3 and a.uop.axis == 2 and b.uop.axis == 0: K //= len(a.device)
|
||||
else: return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
|
||||
dname = a.device[0]
|
||||
else: dname = a.device
|
||||
@@ -78,6 +86,10 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
|
||||
def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
|
||||
assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
|
||||
counters["used"] += 1
|
||||
unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
|
||||
if unfold_batch:
|
||||
orig_batch = a.shape[0]
|
||||
a = a.reshape(a.shape[0]*a.shape[1], a.shape[2])
|
||||
squeeze = a.ndim == 2
|
||||
if squeeze: a = a.unsqueeze(0)
|
||||
|
||||
@@ -85,9 +97,16 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
|
||||
N = b.shape[1]
|
||||
is_multi = isinstance(a.device, tuple)
|
||||
if (k_sharded:=is_multi and a.uop.axis == 2): K //= len(a.device)
|
||||
if (m_sharded:=is_multi and a.uop.axis == 1): M //= len(a.device)
|
||||
n_sharded = is_multi and b.uop.axis == 1
|
||||
|
||||
if is_multi:
|
||||
out = Tensor(Tensor.empty(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
|
||||
if n_sharded:
|
||||
out = Tensor(Tensor.empty(batch, M, N//len(a.device), dtype=a.dtype, device=a.device).uop.multi(2), device=a.device)
|
||||
elif m_sharded:
|
||||
out = Tensor(Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device).uop.multi(1), device=a.device)
|
||||
else:
|
||||
out = Tensor(Tensor.empty(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device)
|
||||
else:
|
||||
out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device)
|
||||
|
||||
@@ -98,4 +117,6 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
|
||||
else:
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
|
||||
if k_sharded: out = out.sum(0)
|
||||
return out.squeeze(0) if squeeze else out
|
||||
out = out.squeeze(0) if squeeze else out
|
||||
if unfold_batch: out = out.reshape(orig_batch, -1, out.shape[-1])
|
||||
return out
|
||||
|
||||
@@ -47,6 +47,18 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i
|
||||
def verify_asm_gemm_k_sharded(M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=8) -> None:
|
||||
run_asm_gemm((M, K), (K, N), dtype=dtype, a_shard=1, b_shard=0, gpus=gpus)
|
||||
|
||||
def verify_asm_gemm_n_sharded(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None:
|
||||
run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=None, b_shard=1, gpus=gpus)
|
||||
|
||||
def verify_asm_gemm_m_sharded(M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None:
|
||||
run_asm_gemm((M, K), (K, N), dtype=dtype, a_shard=0, b_shard=None, gpus=gpus)
|
||||
|
||||
def verify_asm_gemm_n_sharded_2d(M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None:
|
||||
run_asm_gemm((M, K), (K, N), dtype=dtype, a_shard=None, b_shard=1, gpus=gpus)
|
||||
|
||||
def verify_asm_gemm_k_sharded_3d(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=2) -> None:
|
||||
run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=2, b_shard=0, gpus=gpus)
|
||||
|
||||
# 128x smaller than usual
|
||||
# uses the UOp GEMM, runs on non CDNA4 and CI
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
@@ -60,6 +72,14 @@ class TestGemm(unittest.TestCase):
|
||||
def test_gemm_multi(self): verify_asm_gemm(2, 64, 32, 32, gpus=2)
|
||||
@needs_second_gpu
|
||||
def test_gemm_k_sharded(self): verify_asm_gemm_k_sharded(64, 64, 2*64, gpus=2)
|
||||
@needs_second_gpu
|
||||
def test_gemm_m_sharded(self): verify_asm_gemm_m_sharded(2*64, 64, 32, gpus=2)
|
||||
@needs_second_gpu
|
||||
def test_gemm_n_sharded(self): verify_asm_gemm_n_sharded(1, 64, 64, 32, gpus=2)
|
||||
@needs_second_gpu
|
||||
def test_gemm_n_sharded_2d(self): verify_asm_gemm_n_sharded_2d(64, 2*64, 32, gpus=2)
|
||||
@needs_second_gpu
|
||||
def test_gemm_k_sharded_3d(self): verify_asm_gemm_k_sharded_3d(1, 64, 32, 2*64, gpus=2)
|
||||
|
||||
# uses the Asm GEMM on CDNA4 only for speed reasons
|
||||
class TestGemmLarge(unittest.TestCase):
|
||||
@@ -101,6 +121,20 @@ class TestGemmLarge(unittest.TestCase):
|
||||
verify_asm_gemm(3, 256, 256, 256)
|
||||
def test_gemm_previously_unsupported(self): verify_asm_gemm(8, 1024, 1024, 4096, gpus=8)
|
||||
|
||||
# M-sharded 2D
|
||||
def test_m_sharded_1(self): verify_asm_gemm_m_sharded(8*8192, 4096, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
def test_m_sharded_2(self): verify_asm_gemm_m_sharded(8*4096, 14336, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
|
||||
# N-sharded 2D
|
||||
def test_n_sharded_2d_1(self): verify_asm_gemm_n_sharded_2d(8192, 8*4096, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
def test_n_sharded_2d_2(self): verify_asm_gemm_n_sharded_2d(4096, 8*14336, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
|
||||
# tensor parallel shapes (Llama 8B, MP=8)
|
||||
def test_tp_n_sharded_wq(self): verify_asm_gemm_n_sharded(1, 8192, 4096, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
def test_tp_n_sharded_w1(self): verify_asm_gemm_n_sharded(1, 8192, 14336, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
def test_tp_k_sharded_wo(self): verify_asm_gemm_k_sharded_3d(1, 8192, 4096, 4096, dtype=dtypes.bfloat16, gpus=8)
|
||||
def test_tp_k_sharded_w2(self): verify_asm_gemm_k_sharded_3d(1, 8192, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
|
||||
|
||||
# more shapes: vary M, N, K independently
|
||||
def test_shape_small_square(self): verify_asm_gemm(1, 256, 256, 256)
|
||||
def test_shape_small_rect_m(self): verify_asm_gemm(1, 512, 256, 256)
|
||||
|
||||
Reference in New Issue
Block a user