From 087dab4c3b602df91e3f169ae79eca97123776eb Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 8 Feb 2026 07:33:42 -0500 Subject: [PATCH] gemm/asm: split out cdna tests from CI (#14619) * gemm/asm: split out cdna tests from CI * reorder * work --- test/testextra/test_asm_gemm.py | 36 +++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/test/testextra/test_asm_gemm.py b/test/testextra/test_asm_gemm.py index 65f698ece8..e1dac741c7 100644 --- a/test/testextra/test_asm_gemm.py +++ b/test/testextra/test_asm_gemm.py @@ -5,6 +5,10 @@ from tinygrad.helpers import getenv from extra.gemm.asm.cdna.gemm import asm_gemm from test.helpers import needs_second_gpu +# On non CDNA4 it will only validate the Tensor.custom_kernel integration +# Use NULL=1 EMULATE=AMD_CDNA4 to also test the assembly +def is_cdna4(): return getattr(Device[Device.DEFAULT].renderer, "arch", "").startswith("gfx950") + def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None: Tensor.manual_seed(0) a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype) @@ -16,37 +20,47 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype) if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None) - tst = asm_gemm(a, b) - tst.sum().backward() + with Context(ASM_GEMM=1): + tst = asm_gemm(a, b) + tst.sum().backward() Tensor.realize(tst, a.grad, b.grad) a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype) if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None) - with Context(ASM_GEMM=0): ref = a_ref @ b_ref - ref.sum().backward() + with Context(ASM_GEMM=0): + ref = asm_gemm(a_ref, b_ref) + ref.sum().backward() Tensor.realize(ref, a_ref.grad, b_ref.grad) + # no validation on the NULL device + if a_rand.device.startswith("NULL"): return None with Context(DEBUG=0): assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch" assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch" assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch" # 128x smaller than usual -SCALE = 128 - +# uses the UOp GEMM, runs on non CDNA4 and CI @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") class TestGemm(unittest.TestCase): - def test_simple(self): verify_asm_gemm(1, N:=(getenv("N", 4096)//SCALE), N, N, dtype=dtypes.half) - def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE) - def test_gemm_batched(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE) + def setUp(self): + if is_cdna4(): self.skipTest("shapes are too small for the assembly GEMM") + def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 32), N, N, dtype=dtypes.half) + def test_gemm(self): verify_asm_gemm(1, 64, 32, 112) + def test_gemm_batched(self): verify_asm_gemm(2, 64, 32, 32) @needs_second_gpu - def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2) + def test_gemm_multi(self): verify_asm_gemm(2, 64, 32, 32, gpus=2) +# uses the Asm GEMM on CDNA4 only for speed reasons class TestGemmLarge(unittest.TestCase): def setUp(self): - if getattr(Device[Device.DEFAULT].renderer, "arch", "") != "gfx950": + if not is_cdna4(): self.skipTest("very slow on non mi350x") + def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half) + def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336) + def test_gemm_batched(self): verify_asm_gemm(2, 8192, 4096, 4096) + def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, dtype=dtypes.bfloat16, gpus=8) @unittest.skip("disabled, asm in this shape is slower than tinygrad") def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, dtype=dtypes.bfloat16, gpus=8)