variable N for asm gemm (#13869)

* variable N for asm gemm

* cleanup spacing
This commit is contained in:
qazal
2025-12-29 19:35:50 +09:00
committed by GitHub
parent c6769badc2
commit f541540129
2 changed files with 40 additions and 26 deletions

View File

@@ -6,29 +6,29 @@
gemm:
// ** global buffers
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
s_waitcnt lgkmcnt(0)
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
// ** other kernel args
// info
s_mov_b32 s51, 0x00000001 // gemm_info = 1
s_mov_b32 s53, 0x00000001 // kernel_info0 = 1
s_load_dword s24, s[0:1], 0x18 // N
s_load_dword s54, s[0:1], 0x1C // num work groups
s_waitcnt lgkmcnt(0)
// "info"
s_mov_b32 s51, 1 // gemm_info = 1
s_mov_b32 s53, 1 // kernel_info0 = 1
s_mov_b32 s11, 0x40010020 // kernel_info1 = 0x40010020
s_mov_b32 s54, 0x00000400 // numWG = 1024
// sizes / strides
s_mov_b32 s24, 0x00002000 // sizesFree0 = M = 8192
s_mov_b32 s25, 0x00002000 // sizesFree1 = N = 8192
s_mov_b32 s26, 0x00000001 // sizesFree2 = BATCH = 1
s_mov_b32 s27, 0x00002000 // sizesSum0 = K = 8192
// Strides: major=8192, minor=0 (addr = base + idx0*8192 + idx1*0)
s_mov_b32 s36, 0x00002000 // strideD0
s_mov_b32 s37, 0x00000000 // strideD1
s_mov_b32 s38, 0x00002000 // strideC0
s_mov_b32 s39, 0x00000000 // strideC1
s_mov_b32 s40, 0x00002000 // strideA0
s_mov_b32 s41, 0x00000000 // strideA1
s_mov_b32 s42, 0x00002000 // strideB0
s_mov_b32 s43, 0x00000000 // strideB1
s_mov_b32 s25, s24 // sizesFree1 = N
s_mov_b32 s26, 1 // sizesFree2 = BATCH
s_mov_b32 s27, s24 // sizesSum0 = K (== N)
// Strides: major=N, minor=0 (addr = base + idx0*N + idx1*0)
s_mov_b32 s36, s24 // strideD0
s_mov_b32 s37, 0 // strideD1
s_mov_b32 s38, s24 // strideC0
s_mov_b32 s39, 0 // strideC1
s_mov_b32 s40, s24 // strideA0
s_mov_b32 s41, 0 // strideA1
s_mov_b32 s42, s24 // strideB0
s_mov_b32 s43, 0 // strideB1
// ** workgroup mapping
s_lshr_b32 s52, s51, 30 // 000000002924: 8F349E33
s_and_b32 s51, 0x3fffffff, s51 // 000000002928: 863333FF 3FFFFFFF
@@ -4565,7 +4565,7 @@ end:
# ---- basic memory requirements ----
.amdhsa_group_segment_fixed_size 133120
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 24
.amdhsa_kernarg_size 32
# ---- register usage (RSRC1) ----
.amdhsa_next_free_vgpr 504
@@ -4611,9 +4611,19 @@ amdhsa.kernels:
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: sz
.offset: 24
.size: 4
.value_kind: by_value
.value_type: u32
- .name: num_wg
.offset: 28
.size: 4
.value_kind: by_value
.value_type: u32
.group_segment_fixed_size: 133120
.kernarg_segment_align: 8
.kernarg_segment_size: 24
.kernarg_segment_size: 32
.max_flat_workgroup_size: 256
.name: gemm
.private_segment_fixed_size: 0

View File

@@ -5,14 +5,16 @@ from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.renderer import ProgramSpec
from tinygrad.uop.ops import track_rewrites, UOp
from tinygrad.helpers import TracingKey
from tinygrad.helpers import TracingKey, getenv
fp = pathlib.Path(__file__).parent/"gemm.s"
N = 8192
N = getenv("N", 8192)
THREADS_PER_WG = 256
NUM_WG = N//THREADS_PER_WG * N//THREADS_PER_WG
assert N % THREADS_PER_WG == 0, "N must be divisible by THREADS_PER_WG"
# ** generate inputs on CPU
scale = 10.0
@@ -48,8 +50,10 @@ ast = sched[-1].ast
def get_asm_prg() -> ProgramSpec:
src = fp.read_text()
lib = Device[Device.DEFAULT].compiler.compile(src)
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1], globals=[0, 1, 2])
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], prg=CompiledRunner(get_asm_prg())))
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
prg=CompiledRunner(get_asm_prg())))
for ei in eis:
et = ei.run(wait=True)