remove ASM_GEMM context var (#15645)

This commit is contained in:
qazal
2026-04-08 12:02:40 +03:00
committed by GitHub
parent dc6a51e44d
commit 39a029ec55
3 changed files with 6 additions and 13 deletions

View File

@@ -35,16 +35,15 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> Tensor:
from tinygrad.helpers import ASM_GEMM
if not fp8:
if ASM_GEMM:
if getenv("ASM_GEMM"):
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x, w.T): return asm_gemm(x, w.T)
return x @ w.T
x_fp8, x_scale = quantize_fp8(x, amax_state=amax_x)
w_fp8, w_scale = quantize_fp8(w, amax_state=amax_w)
combined_scale = x_scale * w_scale
if ASM_GEMM:
if getenv("ASM_GEMM"):
from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale)
return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale

View File

@@ -21,9 +21,8 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
a, b = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
if multi: a, b = a.shard(devs, axis=a_shard), b.shard(devs, axis=b_shard)
with Context(ASM_GEMM=1):
tst = asm_gemm(a, b)
tst.sum().backward()
tst = asm_gemm(a, b)
tst.sum().backward()
Tensor.realize(tst, a.grad, b.grad)
a_ref, b_ref = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
@@ -32,9 +31,8 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
a_ref = a_ref.cast(dtypes.bfloat16)
b_ref = b_ref.cast(dtypes.bfloat16)
if multi: a_ref, b_ref = a_ref.shard(devs, axis=a_shard), b_ref.shard(devs, axis=b_shard)
with Context(ASM_GEMM=0):
ref = asm_gemm(a_ref, b_ref)
ref.sum().backward()
ref = a_ref @ b_ref
ref.sum().backward()
Tensor.realize(ref, a_ref.grad, b_ref.grad)
# no validation on the NULL device
@@ -136,10 +134,8 @@ class TestGemmLlama(unittest.TestCase):
if not is_cdna4() or getenv("MOCKGPU"):
self.skipTest("very slow on non mi350x")
@Context(ASM_GEMM=1)
def test_empty(self): asm_gemm(Tensor.empty(N:=getenv("N", 4096), N, dtype=self.dtype), Tensor.empty(N, N, dtype=self.dtype)).realize()
@Context(ASM_GEMM=1)
def test_empty_bw(self):
x = Tensor.empty(1, N:=getenv("N", 4096), N, dtype=self.dtype, requires_grad=True)
y = Tensor.empty((N, N), dtype=self.dtype, requires_grad=True)

View File

@@ -251,8 +251,6 @@ ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
SCACHE = ContextVar("SCACHE", 1)
# allow use of atomics for embedding backward
USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
# allow use of assembly for gemm
ASM_GEMM = ContextVar("ASM_GEMM", 0)
@dataclass(frozen=True)
class Metadata: