Review notes (free text ahead of the first "diff --git" header; patch tools are expected to skip it, confirm for your tooling).
- gemm.py: can_use_asm_gemm now resolves the renderer arch of the (first) device and rejects shapes that are
  missing from GEMM_ARGS only when that arch is "gfx950"; on any other arch the shape check no longer fails.
- test_asm_gemm.py: moved to test/testextra, default dtype switched to float16, TestGemm sizes divided by
  SCALE (128 when CI is set, otherwise 1), test_gemm_unsupported moved into the gfx950-only TestGemmLarge.
- Fix relative to the original patch: test_simple bound N to the already-scaled size and then divided by SCALE
  again for the last two dimensions, which yields 0 for the default N=4096 under CI. It now passes N unchanged.
- NOTE(review): test_gemm7 and test_gemm_unsupported rely on the default dtype, which this patch changes from
  bfloat16 to float16, while test_gemm1..6 pin bfloat16. Confirm that is intended.
- NOTE(review): line breaks, indentation and blank context lines were reconstructed from the hunk line counts;
  the "index" hash and similarity of the second file predate the test_simple fix. Verify before applying.

diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py
index 49b51f3fca..463f38238c 100644
--- a/extra/gemm/asm/cdna/gemm.py
+++ b/extra/gemm/asm/cdna/gemm.py
@@ -35,9 +35,13 @@ def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
     return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
   batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
   N = b.shape[1]
-  if isinstance(a.device, tuple): batch //= len(a.device)
+  if isinstance(a.device, tuple):
+    batch //= len(a.device)
+    dname = a.device[0]
+  else: dname = a.device
+  arch = getattr(Device[dname].renderer, "arch", "")
   if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
-  if (key:=(M, N, K)) not in GEMM_ARGS: return todo(f"GEMM shape not supported {key}")
+  if (key:=(M, N, K)) not in GEMM_ARGS and arch == "gfx950": return todo(f"GEMM shape not supported {key} on {arch}")
   return True
 
 # ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4
diff --git a/extra/gemm/asm/cdna/test_asm_gemm.py b/test/testextra/test_asm_gemm.py
similarity index 64%
rename from extra/gemm/asm/cdna/test_asm_gemm.py
rename to test/testextra/test_asm_gemm.py
index 53e7b6d5b5..a0607e9cb0 100644
--- a/extra/gemm/asm/cdna/test_asm_gemm.py
+++ b/test/testextra/test_asm_gemm.py
@@ -1,9 +1,10 @@
 import unittest
 from tinygrad import Tensor, Device, dtypes, Context
-from tinygrad.helpers import getenv
+from tinygrad.device import is_dtype_supported
+from tinygrad.helpers import getenv, CI
 from extra.gemm.asm.cdna.gemm import asm_gemm
 
-def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, gpus:int=1) -> None:
+def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None:
   Tensor.manual_seed(0)
   a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype)
   b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype)
@@ -29,26 +30,29 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, gpus:
     assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch"
     assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch"
 
+SCALE = 128 if CI else 1
+
+@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
 class TestGemm(unittest.TestCase):
-  def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half)
-  def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336)
-  def test_gemm_multi(self): verify_asm_gemm(2, 8192, 4096, 14336, gpus=2)
-  def test_gemm_unsupported(self):
-    with self.assertRaisesRegex(AssertionError, "shape not supported"):
-      verify_asm_gemm(8, 8192, 1024, 4096, gpus=8)
+  def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096)//SCALE, N, N, dtype=dtypes.half)  # N is bound after scaling
+  def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE)
+  def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2)
 
 class TestGemmLarge(unittest.TestCase):
   def setUp(self):
     if getattr(Device[Device.DEFAULT].renderer, "arch", "") != "gfx950":
       self.skipTest("very slow on non mi350x")
 
-  def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, gpus=8)
-  def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, gpus=8)
-  def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, gpus=8)
-  def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, gpus=8)
-  def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, gpus=8)
-  def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, gpus=8)
+  def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
+  def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, dtype=dtypes.bfloat16, gpus=8)
+  def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, dtype=dtypes.bfloat16, gpus=8)
+  def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, dtype=dtypes.bfloat16, gpus=8)
+  def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
+  def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, dtype=dtypes.bfloat16, gpus=8)
   def test_gemm7(self): verify_asm_gemm(1, 8192, 128256, 4096)
+  def test_gemm_unsupported(self):
+    with self.assertRaisesRegex(AssertionError, "shape not supported"):
+      verify_asm_gemm(8, 1024, 1024, 4096, gpus=8)
 
 if __name__ == "__main__":
   unittest.main()