mi350x assembly gemm cleanups (#13867)

This commit is contained in:
qazal
2025-12-29 18:47:23 +09:00
committed by GitHub
parent f07c39cfa4
commit fc5278746f
2 changed files with 19 additions and 10724 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -9,9 +9,12 @@ from tinygrad.helpers import TracingKey
# path of the assembly source, expected next to this script
fp = pathlib.Path(__file__).parent/"gemm.s"
# NOTE(review): N is presumably the square matrix dimension of the gemm -- confirm against the input setup
N = 8192
# launch geometry handed to the assembly kernel's ProgramSpec (local_size / global_size)
THREADS_PER_WG = 256
# Explicit parentheses: `a//b * a//b` evaluates left to right as ((a//b)*a)//b, which equals
# (a//b)*(a//b) only when a is a multiple of b. Same value (1024) for the constants above.
NUM_WG = (N//THREADS_PER_WG) * (N//THREADS_PER_WG)
# ** generate inputs on CPU
# NOTE(review): N is already assigned the same value earlier in this file; this second
# assignment looks like the pre-change line of the diff -- confirm which one is meant to stay
N = 8192
# NOTE(review): presumably a multiplier applied to the generated inputs; its use is outside this view
scale = 10.0
import torch
@@ -34,19 +37,19 @@ C_asm.uop.buffer.allocate()
# ** run gemms
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name,), ret=ret))
def get_asm_gemm(ast:UOp, fp:pathlib.Path) -> ProgramSpec:
  """Build the ProgramSpec for the assembly gemm stored at `fp`.

  The file's text is read, compiled with the default device's compiler, and
  returned as a ProgramSpec named "gemm" that carries `ast` and the compiled
  library. The global/local sizes are hard-coded in the returned spec.
  """
  asm_text = fp.read_text()
  compiled = Device[Device.DEFAULT].compiler.compile(asm_text)
  return ProgramSpec(
    "gemm", asm_text, Device.DEFAULT, ast,
    lib=compiled,
    global_size=[1024, 1, 1],
    local_size=[256, 1, 1],
    globals=[0, 1, 2],
  )
# baseline tinygrad
# schedule the tinygrad result; the assert below requires exactly one schedule item
sched = C_tiny.schedule()
assert len(sched) == 1
# lower that item; it becomes the first entry of the list of items that get run
eis:list[ExecItem] = [sched[-1].lower()]
# take the ast from the lowered item and build the assembly program for it
ast = eis[0].ast
prg = get_asm_gemm(ast, fp)
# NOTE(review): buffers are passed in the order C, B, A -- presumably the kernel's argument order; confirm against gemm.s
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], prg=CompiledRunner(prg)))
# NOTE(review): ast is assigned a second time here, now from the schedule item itself;
# this looks like the post-change line of the diff replacing `ast = eis[0].ast` -- confirm
ast = sched[-1].ast
# assembly gemm
@track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret))
def get_asm_prg() -> ProgramSpec:
  """Build the ProgramSpec for the assembly gemm.

  Takes no arguments: the source path `fp`, the `ast`, and the launch sizes
  `NUM_WG` / `THREADS_PER_WG` are read from module level. The file's text is
  compiled with the default device's compiler and returned as a ProgramSpec
  named "gemm".
  """
  asm_text = fp.read_text()
  compiled = Device[Device.DEFAULT].compiler.compile(asm_text)
  return ProgramSpec(
    "gemm", asm_text, Device.DEFAULT, ast,
    lib=compiled,
    global_size=[NUM_WG, 1, 1],
    local_size=[THREADS_PER_WG, 1, 1],
    globals=[0, 1, 2],
  )
# queue the assembly gemm after the baseline item, writing into C_asm's buffer
# NOTE(review): buffers are passed in the order C, B, A -- presumably the kernel's argument order; confirm against gemm.s
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], prg=CompiledRunner(get_asm_prg())))
for ei in eis:
et = ei.run(wait=True)