diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py index 1201966715..49b51f3fca 100644 --- a/extra/gemm/asm/cdna/gemm.py +++ b/extra/gemm/asm/cdna/gemm.py @@ -86,7 +86,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device) dname = a.device[0] if is_multi else a.device - arch = getattr(Device[dname].renderer, "arch", None) + arch = getattr(Device[dname].renderer, "arch", "") if arch.startswith("gfx950") and getenv("USE_ASM", 1): numWG = GEMM_ARGS[(M, N, K)][0] out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=numWG, arch=arch), grad_fxn=custom_gemm_bw)[0] diff --git a/extra/gemm/asm/cdna/test_asm_gemm.py b/extra/gemm/asm/cdna/test_asm_gemm.py index 905492cc60..8281446353 100644 --- a/extra/gemm/asm/cdna/test_asm_gemm.py +++ b/extra/gemm/asm/cdna/test_asm_gemm.py @@ -3,14 +3,14 @@ from tinygrad import Tensor, Device, dtypes, Context from tinygrad.helpers import getenv from extra.gemm.asm.cdna.gemm import asm_gemm -def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, multi=False) -> None: +def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, gpus:int=1) -> None: Tensor.manual_seed(0) a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype) b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype) with Context(DEBUG=0): Tensor.realize(a_rand, b_rand) - devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8)) if multi else None + devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(gpus)) if (multi:=gpus>1) else None a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype) if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None) @@ -31,16 +31,23 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.bfloat16, multi class TestGemm(unittest.TestCase): def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096), N, N, dtype=dtypes.half) - - def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, multi=True) - def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, multi=True) - def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, multi=True) - def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, multi=True) - def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, multi=True) - def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, multi=True) + def test_gemm(self): verify_asm_gemm(1, 8192, 4096, 14336) + def test_gemm_multi(self): verify_asm_gemm(2, 8192, 4096, 14336, gpus=2) def test_gemm_unsupported(self): with self.assertRaisesRegex(AssertionError, "shape not supported"): - verify_asm_gemm(8, 8192, 1024, 4096, multi=True) + verify_asm_gemm(8, 8192, 1024, 4096, gpus=8) + +class TestGemmLarge(unittest.TestCase): + def setUp(self): + if getattr(Device[Device.DEFAULT].renderer, "arch", "") != "gfx950": + self.skipTest("very slow on non mi350x") + + def test_gemm1(self): verify_asm_gemm(8, 8192, 4096, 14336, gpus=8) + def test_gemm2(self): verify_asm_gemm(8, 8192, 128256, 4096, gpus=8) + def test_gemm3(self): verify_asm_gemm(8, 8192, 14336, 4096, gpus=8) + def test_gemm4(self): verify_asm_gemm(8, 4096, 14336, 4096, gpus=8) + def test_gemm5(self): verify_asm_gemm(8, 4096, 4096, 14336, gpus=8) + def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, gpus=8) if __name__ == "__main__": unittest.main()