remove ASM_GEMM context var (#15645)
@@ -35,16 +35,15 @@ def quantize_fp8(x:Tensor, amax_state:Tensor|None=None):
   return x_clamped.cast(FP8_DTYPE), scale.float().reciprocal()
 
 def matmul(x:Tensor, w:Tensor, fp8=FP8, amax_x:Tensor|None=None, amax_w:Tensor|None=None) -> Tensor:
-  from tinygrad.helpers import ASM_GEMM
   if not fp8:
-    if ASM_GEMM:
+    if getenv("ASM_GEMM"):
       from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
       if can_use_asm_gemm(x, w.T): return asm_gemm(x, w.T)
     return x @ w.T
   x_fp8, x_scale = quantize_fp8(x, amax_state=amax_x)
   w_fp8, w_scale = quantize_fp8(w, amax_state=amax_w)
   combined_scale = x_scale * w_scale
-  if ASM_GEMM:
+  if getenv("ASM_GEMM"):
     from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm
     if can_use_asm_gemm(x_fp8, w_fp8.T): return asm_gemm(x_fp8, w_fp8.T, combined_scale=combined_scale)
   return x_fp8.dot(w_fp8.T, dtype=dtypes.float) * combined_scale
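For orientation on what this hunk's FP8 path computes: `quantize_fp8` returns the quantized tensor plus the reciprocal of its per-tensor scale, and `matmul` folds both reciprocal scales into a single multiply after the FP8 dot. A rough numpy emulation of that recipe, with the FP8 cast replaced by a clamp (numpy has no FP8 dtype) and the e4m3 max value assumed:

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # assumed max representable value of the FP8 format

def quantize_sketch(x: np.ndarray):
  amax = np.abs(x).max()                                   # per-tensor absolute maximum
  scale = FP8_E4M3_MAX / amax
  x_q = np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)    # real code casts to FP8 here
  return x_q, np.float32(1.0 / scale)                      # value and reciprocal scale

x = np.random.randn(4, 8).astype(np.float32)
w = np.random.randn(16, 8).astype(np.float32)
x_q, x_scale = quantize_sketch(x)
w_q, w_scale = quantize_sketch(w)
out = (x_q @ w_q.T) * (x_scale * w_scale)   # one combined_scale multiply, as in the diff
print(np.abs(out - x @ w.T).max())          # small: only float rounding in this emulation
```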
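The commit's core change is visible twice above: `if ASM_GEMM:` (a `ContextVar` imported from `tinygrad.helpers`) becomes `if getenv("ASM_GEMM"):` (a plain environment lookup). A minimal sketch of `getenv`'s behavior, assuming a tinygrad install:

```python
from tinygrad.helpers import getenv

# getenv(key, default) coerces the env var to type(default); the default
# default is 0, so an unset ASM_GEMM reads as the falsy int 0
flag = getenv("ASM_GEMM")   # 0 unless ASM_GEMM is set in the environment
n = getenv("N", 4096)       # same convention the tests below use
print(flag, n)
```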
@@ -21,9 +21,8 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
 
   a, b = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
   if multi: a, b = a.shard(devs, axis=a_shard), b.shard(devs, axis=b_shard)
-  with Context(ASM_GEMM=1):
-    tst = asm_gemm(a, b)
-    tst.sum().backward()
+  tst = asm_gemm(a, b)
+  tst.sum().backward()
   Tensor.realize(tst, a.grad, b.grad)
 
   a_ref, b_ref = a_rand.clone().requires_grad_(), b_rand.clone().requires_grad_()
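With the `Context(ASM_GEMM=1)` wrapper gone, the test calls `asm_gemm` directly and the flag only matters inside `matmul`, where it is read from the process environment. Note that tinygrad caches `getenv` lookups, so the variable has to be set before the first read; a sketch:

```python
import os

# set the flag before tinygrad reads it: getenv results are cached,
# so flipping os.environ later in the process has no effect
os.environ["ASM_GEMM"] = "1"

from tinygrad.helpers import getenv
assert getenv("ASM_GEMM") == 1  # sticks for the life of the process
```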
@@ -32,9 +31,8 @@ def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=N
   a_ref = a_ref.cast(dtypes.bfloat16)
   b_ref = b_ref.cast(dtypes.bfloat16)
   if multi: a_ref, b_ref = a_ref.shard(devs, axis=a_shard), b_ref.shard(devs, axis=b_shard)
-  with Context(ASM_GEMM=0):
-    ref = asm_gemm(a_ref, b_ref)
-    ref.sum().backward()
+  ref = a_ref @ b_ref
+  ref.sum().backward()
   Tensor.realize(ref, a_ref.grad, b_ref.grad)
 
   # no validation on the NULL device
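The reference side now computes a plain `a_ref @ b_ref` (optionally cast to bfloat16) instead of running `asm_gemm` with the flag forced off. The surrounding test presumably compares `tst` and `ref` values and gradients; a generic sketch of that forward-plus-backward check in tinygrad, with made-up shapes and tolerances:

```python
import numpy as np
from tinygrad import Tensor

a = Tensor.randn(64, 32, requires_grad=True)
b = Tensor.randn(32, 16, requires_grad=True)
tst = a @ b          # would be asm_gemm(a, b) on supported hardware
tst.sum().backward()

a_np, b_np = a.numpy(), b.numpy()
np.testing.assert_allclose(tst.numpy(), a_np @ b_np, rtol=1e-4, atol=1e-4)
# d(sum(a@b))/da = ones @ b.T, d(sum(a@b))/db = a.T @ ones
np.testing.assert_allclose(a.grad.numpy(), np.ones((64, 16), np.float32) @ b_np.T, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(b.grad.numpy(), a_np.T @ np.ones((64, 16), np.float32), rtol=1e-4, atol=1e-4)
```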
@@ -136,10 +134,8 @@ class TestGemmLlama(unittest.TestCase):
     if not is_cdna4() or getenv("MOCKGPU"):
       self.skipTest("very slow on non mi350x")
 
-  @Context(ASM_GEMM=1)
   def test_empty(self): asm_gemm(Tensor.empty(N:=getenv("N", 4096), N, dtype=self.dtype), Tensor.empty(N, N, dtype=self.dtype)).realize()
 
-  @Context(ASM_GEMM=1)
   def test_empty_bw(self):
     x = Tensor.empty(1, N:=getenv("N", 4096), N, dtype=self.dtype, requires_grad=True)
     y = Tensor.empty((N, N), dtype=self.dtype, requires_grad=True)
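The `test_empty` pair now runs without the decorator; enabling the asm path is left to the environment. These tests build uninitialized tensors just to compile and launch the kernel: `Tensor.empty` allocates without writing data, which is the cheap way to exercise a large GEMM. The same pattern outside a test class, with a smaller assumed default size:

```python
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv

# N:=getenv("N", 4096) in the diff walrus-binds the size inline
x = Tensor.empty(N := getenv("N", 512), N, dtype=dtypes.float16)
y = Tensor.empty(N, N, dtype=dtypes.float16)
(x @ y).realize()  # forces compilation and a kernel launch; the values are garbage
```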
@@ -251,8 +251,6 @@ ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
 SCACHE = ContextVar("SCACHE", 1)
 # allow use of atomics for embedding backward
 USE_ATOMICS = ContextVar("USE_ATOMICS", 0)
-# allow use of assembly for gemm
-ASM_GEMM = ContextVar("ASM_GEMM", 0)
 
 @dataclass(frozen=True)
 class Metadata:
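For reference, the deleted pair followed the standard `tinygrad/helpers.py` pattern: a commented, module-level `ContextVar` with default 0 whose value can be overridden per-scope. A sketch of that pattern, and of the `@Context(...)` decorator form the removed tests used, with a hypothetical variable name:

```python
from tinygrad.helpers import Context, ContextVar

MY_FLAG = ContextVar("MY_FLAG", 0)  # hypothetical; an env var of the same name overrides the default

with Context(MY_FLAG=1):            # scoped override, the with-form the tests used
  assert MY_FLAG.value == 1
assert MY_FLAG.value == 0           # restored on exit

@Context(MY_FLAG=1)                 # Context doubles as a decorator
def f(): return MY_FLAG.value
assert f() == 1
```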