correct args order in mi350x gemm (#13949)

This commit is contained in:
qazal
2026-01-01 23:01:46 +09:00
committed by GitHub
parent baff10d32c
commit 6a5430ab00
2 changed files with 3 additions and 2 deletions

View File

@@ -1,6 +1,7 @@
// ** global buffers
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
s_load_dwordx2 s[34:35], s[0:1], 0x08 // A
s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
// ** others kernel args
s_load_dword s24, s[0:1], 0x18 // N
s_load_dword s54, s[0:1], 0x1C // num work groups

View File

@@ -52,7 +52,7 @@ def get_asm_prg() -> ProgramSpec:
lib = Device[Device.DEFAULT].compiler.compile(src)
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(A).uop.buffer, from_torch(B).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
prg=CompiledRunner(get_asm_prg())))
with Context(DEBUG=2):