diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py
index fd6aa1e812..6373888088 100644
--- a/examples/mlperf/models/flat_llama.py
+++ b/examples/mlperf/models/flat_llama.py
@@ -35,16 +35,15 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
   return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
 
 def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> Tensor:
-  from tinygrad.helpers import ASM_GEMM
   if not fp8:
-    if ASM_GEMM:
+    if getenv("ASM_GEMM"):
       from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
       if can_use_asm_gemm(x, w.T): return asm_gemm(x, w.T)
     return x @ w.T
   x_fp8, x_scale = quantize_fp8(x, amax_state=amax_x)
   w_fp8, w_scale = quantize_fp8(w, amax_state=amax_w)
   combined_scale = x_scale * w_scale
-  if ASM_GEMM:
+  if getenv("ASM_GEMM"):
     from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
     if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale)
   return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale
diff --git a/test/backend/test_asm_gemm.py b/test/backend/test_asm_gemm.py
index 3567360109..ab91dc37e0 100644
--- a/test/backend/test_asm_gemm.py
+++ b/test/backend/test_asm_gemm.py
@@ -21,9 +21,8 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
 
   a, b = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
   if multi: a, b = a.shard(devs, axis=a_shard), b.shard(devs, axis=b_shard)
-  with Context(ASM_GEMM=1):
-    tst = asm_gemm(a, b)
-    tst.sum().backward()
+  tst = asm_gemm(a, b)
+  tst.sum().backward()
   Tensor.realize(tst, a.grad, b.grad)
 
   a_ref, b_ref = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
@@ -32,9 +31,8 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
     a_ref = a_ref.cast(dtypes.bfloat16)
     b_ref = b_ref.cast(dtypes.bfloat16)
   if multi: a_ref, b_ref = a_ref.shard(devs, axis=a_shard), b_ref.shard(devs, axis=b_shard)
-  with Context(ASM_GEMM=0):
-    ref = asm_gemm(a_ref, b_ref)
-    ref.sum().backward()
+  ref = a_ref @ b_ref
+  ref.sum().backward()
   Tensor.realize(ref, a_ref.grad, b_ref.grad)
 
   # no validation on the NULL device
@@ -136,10 +134,8 @@ class TestGemmLlama(unittest.TestCase):
     if not is_cdna4() or getenv("MOCKGPU"): self.skipTest("very slow on non mi350x")
 
-  @Context(ASM_GEMM=1)
   def test_empty(self):
     asm_gemm(Tensor.empty(N:=getenv("N", 4096), N, dtype=self.dtype), Tensor.empty(N, N, dtype=self.dtype)).realize()
 
-  @Context(ASM_GEMM=1)
   def test_empty_bw(self):
     x = Tensor.empty(1, N:=getenv("N", 4096), N, dtype=self.dtype, requires_grad=True)
     y = Tensor.empty((N, N), dtype=self.dtype, requires_grad=True)
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 22f2213835..ab09e1c6ad 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -251,8 +251,6 @@ ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
 SCACHE = ContextVar("SCACHE", 1)
 # allow use of atomics for embedding backward
 USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
-# allow use of assembly for gemm
-ASM_GEMM = ContextVar("ASM_GEMM", 0)
 
 @dataclass(frozen=True)
 class Metadata:
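
For reference, a minimal standalone sketch of how the toggle behaves after this patch: ASM_GEMM is read from the environment via tinygrad.helpers.getenv rather than the removed ContextVar. This is illustrative, not part of the change; it assumes a device supported by can_use_asm_gemm, extra/ importable on PYTHONPATH, and arbitrary 4096x4096 sizes.

# sketch.py -- run as: ASM_GEMM=1 python sketch.py
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv

a = Tensor.rand(4096, 4096, dtype=dtypes.float16)
b = Tensor.rand(4096, 4096, dtype=dtypes.float16)

if getenv("ASM_GEMM"):
  # same guard-then-dispatch pattern as matmul() in flat_llama.py above:
  # use the assembly kernel when eligible, otherwise fall back to a plain matmul
  from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
  out = asm_gemm(a, b) if can_use_asm_gemm(a, b) else a @ b
else:
  out = a @ b
out.realize()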