mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
correct args order in mi350x gemm (#13949)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
// ** global buffers
|
||||
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
|
||||
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
|
||||
s_load_dwordx2 s[34:35], s[0:1], 0x08 // A
|
||||
s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
|
||||
// ** other kernel args
|
||||
s_load_dword s24, s[0:1], 0x18 // N
|
||||
s_load_dword s54, s[0:1], 0x1C // num work groups
|
||||
|
||||
@@ -52,7 +52,7 @@ def get_asm_prg() -> ProgramSpec:
|
||||
lib = Device[Device.DEFAULT].compiler.compile(src)
|
||||
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
|
||||
globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
|
||||
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
|
||||
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(A).uop.buffer, from_torch(B).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
|
||||
prg=CompiledRunner(get_asm_prg())))
|
||||
|
||||
with Context(DEBUG=2):
|
||||
|
||||
Reference in New Issue
Block a user