mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
mi350x assembly gemm cleanups (#13867)
This commit is contained in:
10720
extra/gemm/asm/gemm.s
10720
extra/gemm/asm/gemm.s
File diff suppressed because it is too large
Load Diff
@@ -9,9 +9,12 @@ from tinygrad.helpers import TracingKey
|
||||
|
||||
fp = pathlib.Path(__file__).parent/"gemm.s"
|
||||
|
||||
N = 8192
|
||||
THREADS_PER_WG = 256
|
||||
NUM_WG = N//THREADS_PER_WG * N//THREADS_PER_WG
|
||||
|
||||
# ** generate inputs on CPU
|
||||
|
||||
N = 8192
|
||||
scale = 10.0
|
||||
|
||||
import torch
|
||||
@@ -34,19 +37,19 @@ C_asm.uop.buffer.allocate()
|
||||
|
||||
# ** run gemms
|
||||
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name,), ret=ret))
|
||||
def get_asm_gemm(ast:UOp, fp:pathlib.Path) -> ProgramSpec:
|
||||
src = fp.read_text()
|
||||
lib = Device[Device.DEFAULT].compiler.compile(src)
|
||||
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[1024, 1, 1], local_size=[256, 1, 1], globals=[0, 1, 2])
|
||||
|
||||
# baseline tinygrad
|
||||
sched = C_tiny.schedule()
|
||||
assert len(sched) == 1
|
||||
eis:list[ExecItem] = [sched[-1].lower()]
|
||||
ast = eis[0].ast
|
||||
prg = get_asm_gemm(ast, fp)
|
||||
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], prg=CompiledRunner(prg)))
|
||||
ast = sched[-1].ast
|
||||
|
||||
# assembly gemm
|
||||
@track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret))
|
||||
def get_asm_prg() -> ProgramSpec:
|
||||
src = fp.read_text()
|
||||
lib = Device[Device.DEFAULT].compiler.compile(src)
|
||||
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1], globals=[0, 1, 2])
|
||||
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], prg=CompiledRunner(get_asm_prg())))
|
||||
|
||||
for ei in eis:
|
||||
et = ei.run(wait=True)
|
||||
|
||||
Reference in New Issue
Block a user