mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
variable N for asm gemm (#13869)
* variable N for asm gemm * cleanup spacing
This commit is contained in:
@@ -6,29 +6,29 @@
|
||||
|
||||
gemm:
|
||||
// ** global buffers
|
||||
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
|
||||
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
|
||||
s_waitcnt lgkmcnt(0)
|
||||
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
|
||||
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
|
||||
// ** others kernel args
|
||||
// info
|
||||
s_mov_b32 s51, 0x00000001 // gemm_info = 1
|
||||
s_mov_b32 s53, 0x00000001 // kernel_info0 = 1
|
||||
s_load_dword s24, s[0:1], 0x18 // N
|
||||
s_load_dword s54, s[0:1], 0x1C // num work groups
|
||||
s_waitcnt lgkmcnt(0)
|
||||
// "info"
|
||||
s_mov_b32 s51, 1 // gemm_info = 1
|
||||
s_mov_b32 s53, 1 // kernel_info0 = 1
|
||||
s_mov_b32 s11, 0x40010020 // kernel_info1 = 0x40010020
|
||||
s_mov_b32 s54, 0x00000400 // numWG = 1024
|
||||
// sizes / strides
|
||||
s_mov_b32 s24, 0x00002000 // sizesFree0 = M = 8192
|
||||
s_mov_b32 s25, 0x00002000 // sizesFree1 = N = 8192
|
||||
s_mov_b32 s26, 0x00000001 // sizesFree2 = BATCH = 1
|
||||
s_mov_b32 s27, 0x00002000 // sizesSum0 = K = 8192
|
||||
// Strides: major=8192, minor=0 (addr = base + idx0*8192 + idx1*0)
|
||||
s_mov_b32 s36, 0x00002000 // strideD0
|
||||
s_mov_b32 s37, 0x00000000 // strideD1
|
||||
s_mov_b32 s38, 0x00002000 // strideC0
|
||||
s_mov_b32 s39, 0x00000000 // strideC1
|
||||
s_mov_b32 s40, 0x00002000 // strideA0
|
||||
s_mov_b32 s41, 0x00000000 // strideA1
|
||||
s_mov_b32 s42, 0x00002000 // strideB0
|
||||
s_mov_b32 s43, 0x00000000 // strideB1
|
||||
s_mov_b32 s25, s24 // sizesFree1 = N
|
||||
s_mov_b32 s26, 1 // sizesFree2 = BATCH
|
||||
s_mov_b32 s27, s24 // sizesSum0 = K (== N)
|
||||
// Strides: major=N, minor=0 (addr = base + idx0*N + idx1*0)
|
||||
s_mov_b32 s36, s24 // strideD0
|
||||
s_mov_b32 s37, 0 // strideD1
|
||||
s_mov_b32 s38, s24 // strideC0
|
||||
s_mov_b32 s39, 0 // strideC1
|
||||
s_mov_b32 s40, s24 // strideA0
|
||||
s_mov_b32 s41, 0 // strideA1
|
||||
s_mov_b32 s42, s24 // strideB0
|
||||
s_mov_b32 s43, 0 // strideB1
|
||||
// ** workgroup mapping
|
||||
s_lshr_b32 s52, s51, 30 // 000000002924: 8F349E33
|
||||
s_and_b32 s51, 0x3fffffff, s51 // 000000002928: 863333FF 3FFFFFFF
|
||||
@@ -4565,7 +4565,7 @@ end:
|
||||
# ---- basic memory requirements ----
|
||||
.amdhsa_group_segment_fixed_size 133120
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 24
|
||||
.amdhsa_kernarg_size 32
|
||||
|
||||
# ---- register usage (RSRC1) ----
|
||||
.amdhsa_next_free_vgpr 504
|
||||
@@ -4611,9 +4611,19 @@ amdhsa.kernels:
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: sz
|
||||
.offset: 24
|
||||
.size: 4
|
||||
.value_kind: by_value
|
||||
.value_type: u32
|
||||
- .name: num_wg
|
||||
.offset: 28
|
||||
.size: 4
|
||||
.value_kind: by_value
|
||||
.value_type: u32
|
||||
.group_segment_fixed_size: 133120
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 24
|
||||
.kernarg_segment_size: 32
|
||||
.max_flat_workgroup_size: 256
|
||||
.name: gemm
|
||||
.private_segment_fixed_size: 0
|
||||
|
||||
@@ -5,14 +5,16 @@ from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.uop.ops import track_rewrites, UOp
|
||||
from tinygrad.helpers import TracingKey
|
||||
from tinygrad.helpers import TracingKey, getenv
|
||||
|
||||
fp = pathlib.Path(__file__).parent/"gemm.s"
|
||||
|
||||
N = 8192
|
||||
N = getenv("N", 8192)
|
||||
THREADS_PER_WG = 256
|
||||
NUM_WG = N//THREADS_PER_WG * N//THREADS_PER_WG
|
||||
|
||||
assert N % THREADS_PER_WG == 0, "N must be divisible by THREADS_PER_WG"
|
||||
|
||||
# ** generate inputs on CPU
|
||||
|
||||
scale = 10.0
|
||||
@@ -48,8 +50,10 @@ ast = sched[-1].ast
|
||||
def get_asm_prg() -> ProgramSpec:
|
||||
src = fp.read_text()
|
||||
lib = Device[Device.DEFAULT].compiler.compile(src)
|
||||
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1], globals=[0, 1, 2])
|
||||
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], prg=CompiledRunner(get_asm_prg())))
|
||||
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
|
||||
globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
|
||||
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
|
||||
prg=CompiledRunner(get_asm_prg())))
|
||||
|
||||
for ei in eis:
|
||||
et = ei.run(wait=True)
|
||||
|
||||
Reference in New Issue
Block a user