test asm_gemm in CI (#14551)

* test asm_gemm in CI

* default float16

* use a smaller shape for multi

* smaller size

* smaller for CI

* smaller for ci

* need half
This commit is contained in:
qazal
2026-02-04 23:32:22 -05:00
committed by GitHub
parent c0ca7f9c51
commit f9cfb64cd9
2 changed files with 24 additions and 16 deletions

View File

@@ -35,9 +35,13 @@ def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
N = b.shape[1]
if isinstance(a.device, tuple): batch //= len(a.device)
if isinstance(a.device, tuple):
batch //= len(a.device)
dname = a.device[0]
else: dname = a.device
arch = getattr(Device[dname].renderer, "arch", "")
if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
if (key:=(M, N, K)) not in GEMM_ARGS: return todo(f"GEMM shape not supported {key}")
if (key:=(M, N, K)) not in GEMM_ARGS and arch == "gfx950": return todo(f"GEMM shape not supported {key} on {arch}")
return True
# ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4

View File

@@ -1,9 +1,10 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.helpers import getenv
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, CI
from extra.gemm.asm.cdna.gemm import asm_gemm
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, gpus:int=1) -> None:
def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None:
Tensor.manual_seed(0)
a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype)
@@ -29,26 +30,29 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, gpus:
assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
SCALE = 128 if CI else 1
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
class TestGemm(unittest.TestCase):
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)
def test_gemm_multi(self): verify_asm_gemm(2, 8192, 4096, 14336, gpus=2)
def test_gemm_unsupported(self):
with self.assertRaisesRegex(AssertionError, "shape not supported"):
verify_asm_gemm(8, 8192, 1024, 4096, gpus=8)
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096)//SCALE, N//SCALE, N//SCALE, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE)
def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2)
class TestGemmLarge(unittest.TestCase):
def setUp(self):
if getattr(Device[Device.DEFAULT].renderer, "arch", "") != "gfx950":
self.skipTest("very slow on non mi350x")
def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, gpus=8)
def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, gpus=8)
def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, gpus=8)
def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, gpus=8)
def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, gpus=8)
def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, gpus=8)
def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, dtype=dtypes.bfloat16, gpus=8)
def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, dtype=dtypes.bfloat16, gpus=8)
def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, dtype=dtypes.bfloat16, gpus=8)
def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
def test_gemm7(self): verify_asm_gemm(1, 8192, 128256, 4096)
def test_gemm_unsupported(self):
with self.assertRaisesRegex(AssertionError, "shape not supported"):
verify_asm_gemm(8, 1024, 1024, 4096, gpus=8)
if __name__ == "__main__":
unittest.main()