prep mi350x gemm for python dsl (#13918)

* start by pruning existing asm

* better branch names

* split to template and real instructions
This commit is contained in:
qazal
2025-12-31 20:00:57 +09:00
committed by GitHub
parent 3f3786ded9
commit b23f4517ab
3 changed files with 98 additions and 112 deletions

View File

@@ -1,10 +1,3 @@
.text
.section .text.
.global gemm
.p2align 8
.type gemm,@function
gemm:
// ** global buffers
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
@@ -221,11 +214,11 @@ gemm:
s_mul_hi_u32 s87, s86, s40 // 000000003288: 96572856
s_mul_i32 s86, s86, s40 // 00000000328C: 92562856
s_and_b32 s84, s50, 0x8000 // 000000003290: 8654FF32 00008000
s_cbranch_scc1 label_GSUC_A // 000000003298: BF850003
s_cbranch_scc1 skip_offset_A // 000000003298: BF850003
s_mul_hi_u32 s85, 64, s6 // 00000000329C: 965506C0
s_mul_i32 s84, 64, s6 // 0000000032A0: 925406C0
label_GSUC_A:
skip_offset_A:
s_add_u32 s86, s86, s84 // 000000003330: 80565456
s_addc_u32 s87, s87, s85 // 000000003334: 82575557
s_mov_b64 s[60:61], 1 // 000000003338: BEBC0181
@@ -259,11 +252,11 @@ label_GSUC_A:
s_mul_hi_u32 s87, s86, s42 // 0000000033B4: 96572A56
s_mul_i32 s86, s86, s42 // 0000000033B8: 92562A56
s_and_b32 s84, s50, 0x8000 // 0000000033BC: 8654FF32 00008000
s_cbranch_scc1 label_GSUC_B // 0000000033C4: BF850003
s_cbranch_scc1 skip_offset_B // 0000000033C4: BF850003
s_mul_hi_u32 s85, 64, s6 // 0000000033C8: 965506C0
s_mul_i32 s84, 64, s6 // 0000000033CC: 925406C0
label_GSUC_B:
skip_offset_B:
s_add_u32 s86, s86, s84 // 00000000345C: 80565456
s_addc_u32 s87, s87, s85 // 000000003460: 82575557
s_mov_b64 s[62:63], 1 // 000000003464: BEBE0181
@@ -308,8 +301,6 @@ label_GSUC_B:
s_and_b32 s87, s10, 0xe000 // 0000000035A4: 8657FF0A 0000E000
s_and_b32 s10, s10, 0xff // 0000000035AC: 860AFF0A 000000FF
s_mov_b32 s84, s10 // 0000000035B4: BED4000A
label_beginStaggerUIter:
s_lshl_b32 s85, s84, s86 // 0000000035B8: 8E555654
s_cmp_ge_u32 s13, s85 // 0000000035BC: BF09550D
s_sub_u32 s85, s84, 1 // 0000000035CC: 80D58154
@@ -344,7 +335,7 @@ label_beginStaggerUIter:
s_cselect_b32 s58, s62, -1 // 0000000036A4: 853AC13E
s_add_u32 s51, s51, 2 // 0000000036A8: 80338233
s_cmp_eq_u32 s12, 0 // 0000000036AC: BF06800C
s_cbranch_scc1 label_ShadowInitStart // 0000000036B0: BF850092
s_cbranch_scc1 init_output_buffers // 0000000036B0: BF850092
s_mov_b32 m0, s46 // 0000000036B4: BEFC002E
buffer_load_dwordx4 v0, s[52:55], 0 offen lds // 0000000036B8: E05D1000 800D0000
s_add_u32 m0, m0, 0x1040 // 0000000036C0: 807CFF7C 00001040
@@ -431,7 +422,7 @@ label_beginStaggerUIter:
s_cmp_eq_u32 s63, 0 // 0000000038F4: BF06803F
s_cselect_b32 s58, s62, -1 // 0000000038F8: 853AC13E
label_ShadowInitStart:
init_output_buffers:
s_mov_b64 s[16:17], s[28:29] // 0000000038FC: BE90011C
s_mov_b32 s18, 0x80000000 // 000000003900: BE9200FF 80000000
s_mov_b32 s19, 0x20000 // 000000003908: BE9300FF 00020000
@@ -476,12 +467,10 @@ label_ShadowInitStart:
s_lshl_b64 s[84:85], s[84:85], 2 // 0000000039C4: 8ED48254
s_add_u32 s16, s16, s84 // 0000000039C8: 80105410
s_addc_u32 s17, s17, s85 // 0000000039CC: 82115511
label_NoBranch_T8JHFHKM7BO5OHXW:
s_xor_b32 s46, s48, s46 // 0000000039F0: 882E2E30
s_xor_b32 s47, s49, s47 // 0000000039F4: 882F2F31
s_cmp_eq_u32 s12, 1 // 0000000039F8: BF06810C
s_cbranch_scc1 label_skipPGR2 // 0000000039FC: BF850040
s_cbranch_scc1 after_prefetch // 0000000039FC: BF850040
s_mov_b32 m0, s46 // 000000003A00: BEFC002E
buffer_load_dwordx4 v0, s[52:55], 0 offen lds // 000000003A04: E05D1000 800D0000
s_add_u32 m0, m0, 0x1040 // 000000003A0C: 807CFF7C 00001040
@@ -517,7 +506,7 @@ label_NoBranch_T8JHFHKM7BO5OHXW:
s_xor_b32 s46, s48, s46 // 000000003AF8: 882E2E30
s_xor_b32 s47, s49, s47 // 000000003AFC: 882F2F31
label_skipPGR2:
after_prefetch:
s_waitcnt vmcnt(24) // 000000003B00: BF8C4F78
s_barrier // 000000003B04: BF8A0000
ds_read_b128 v[4:7], v2 // 000000003B08: D9FE0000 04000002
@@ -539,14 +528,12 @@ label_skipPGR2:
ds_read_b128 v[92:95], v3 offset:768 // 000000003B80: D9FE0300 5C000003
ds_read_b128 v[96:99], v3 offset:896 // 000000003B88: D9FE0380 60000003
s_waitcnt lgkmcnt(0) // 000000003B90: BF8CC07F
label_openLoopL:
s_cmp_eq_u32 s12, 1 // 000000003B94: BF06810C
s_cbranch_scc1 label_toPGR1 // 000000003B98: BF8502E5
s_cbranch_scc1 final_compute // 000000003B98: BF8502E5
s_cmp_le_u32 s12, 2 // 000000003B9C: BF0B820C
s_cbranch_scc1 label_LoopEndL // 000000003BA0: BF85019E
s_cbranch_scc1 loop_epilogue // 000000003BA0: BF85019E
label_LoopBeginL:
main_loop:
v_mfma_f32_16x16x32_bf16 a[0:3], v[68:71], v[4:7], a[0:3] // 000000003BA4: D3B58000 04020944
ds_read_b128 v[36:39], v2 offset:64 // 000000003BAC: D9FE0040 24000002
v_mfma_f32_16x16x32_bf16 a[4:7], v[68:71], v[8:11], a[4:7] // 000000003BB4: D3B58004 04121144
@@ -770,9 +757,9 @@ label_LoopBeginL:
s_cmp_eq_i32 s12, 2 // 000000004208: BF00820C
s_waitcnt lgkmcnt(0) // 00000000420C: BF8CC07F
v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004210: D3B580FC 07F28180
s_cbranch_scc0 label_LoopBeginL // 000000004218: BF84FE62
s_cbranch_scc0 main_loop // 000000004218: BF84FE62
label_LoopEndL:
loop_epilogue:
v_mfma_f32_16x16x32_bf16 a[0:3], v[68:71], v[4:7], a[0:3] // 00000000421C: D3B58000 04020944
ds_read_b128 v[36:39], v2 offset:64 // 000000004224: D9FE0040 24000002
v_mfma_f32_16x16x32_bf16 a[4:7], v[68:71], v[8:11], a[4:7] // 00000000422C: D3B58004 04121144
@@ -939,7 +926,7 @@ label_LoopEndL:
v_mfma_f32_16x16x32_bf16 a[248:251], v[128:131], v[60:63], a[248:251]// 000000004720: D3B580F8 07E27980
v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004728: D3B580FC 07F28180
label_toPGR1:
final_compute:
s_and_b32 s8, s50, 0x3fff // 000000004730: 8608FF32 00003FFF
s_and_b32 s84, 0xff, s24 // 000000004750: 865418FF 000000FF
s_add_u32 s85, -1, s14 // 000000004758: 80550EC1
@@ -1095,7 +1082,6 @@ label_toPGR1:
v_mfma_f32_16x16x32_bf16 a[248:251], v[128:131], v[60:63], a[248:251]// 000000004BFC: D3B580F8 07E27980
v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004C04: D3B580FC 07F28180
label_toPGR1end_OptNLL:
v_lshrrev_b32_e32 v4, 6, v134 // 000000004C0C: 20090C86
v_lshrrev_b32_e32 v5, 1, v4 // 000000004C10: 200A0881
v_mul_lo_u32 v5, 16, v5 // 000000004C14: D2850005 00020A90
@@ -1114,7 +1100,6 @@ label_toPGR1end_OptNLL:
s_mul_i32 s8, 0x100, s3 // 000000004C64: 920803FF 00000100
v_add_u32_e32 v1, s8, v1 // 000000004C6C: 68020208
label_GW_B0_E0:
v_add_lshl_u32 v11, v3, v0, 1 // 000000004C70: D1FE000B 02060103
v_accvgpr_read_b32 v16, a0 // 000000004C78: D3D84010 18000100
v_accvgpr_read_b32 v17, a4 // 000000004C80: D3D84011 18000104
@@ -1633,86 +1618,4 @@ label_GW_B0_E0:
s_addc_u32 s17, s17, 0 // 000000005B14: 82118011
buffer_store_dwordx4 v[40:43], v11, s[16:19], 0 offen nt // 000000005B18: E07E1000 8004280B
s_nop 0 // 000000005B20: BF800000
end:
s_endpgm // 00000001F5D0: BF810000
.section .rodata,"a",@progbits
.p2align 6, 0x0
.amdhsa_kernel gemm
# ---- basic memory requirements ----
.amdhsa_group_segment_fixed_size 133120
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 32
# ---- register usage (RSRC1) ----
.amdhsa_next_free_vgpr 504
.amdhsa_next_free_sgpr 96
# ---- workgroup / workitem IDs (RSRC2) ----
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
# ---- user SGPR enables (descriptor bits >448) ----
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_count 2
.amdhsa_user_sgpr_kernarg_preload_length 0
.amdhsa_user_sgpr_kernarg_preload_offset 0
# ---- gfx90a / gfx940 specific (RSRC3) ----
.amdhsa_accum_offset 248
.amdhsa_uses_dynamic_stack 0
.amdhsa_tg_split 0
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.kernels:
- .args:
- .address_space: global
.name: C
.offset: 0
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .address_space: global
.name: B
.offset: 8
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .address_space: global
.name: A
.offset: 16
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: sz
.offset: 24
.size: 4
.value_kind: by_value
.value_type: u32
- .name: num_wg
.offset: 28
.size: 4
.value_kind: by_value
.value_type: u32
.group_segment_fixed_size: 133120
.kernarg_segment_align: 8
.kernarg_segment_size: 32
.max_flat_workgroup_size: 256
.name: gemm
.private_segment_fixed_size: 0
.sgpr_count: 88
.sgpr_spill_count: 0
.symbol: gemm.kd
.vgpr_count: 248
.vgpr_spill_count: 0
.wavefront_size: 64
amdhsa.version:
- 1
- 0
...
.end_amdgpu_metadata

83
extra/gemm/asm/template.s Normal file
View File

@@ -0,0 +1,83 @@
# Kernel prologue: places the code in the text section and declares the
# global function symbol `gemm`, aligned to 2^8 = 256 bytes.
.text
.section .text.
.global gemm
.p2align 8
.type gemm,@function
gemm:
# The bare token on the next line is a textual placeholder, not an assembler
# directive: the Python loader substitutes the kernel's instruction listing
# for it (plain string replacement) before handing the text to the compiler.
# Every occurrence of that exact uppercase token is replaced, so do not
# spell it out anywhere else in this file, including in comments.
INSTRUCTIONS
# Kernel descriptor for `gemm`, emitted into read-only data with 2^6 = 64 byte
# alignment. The values below must stay in sync with the metadata block that
# follows the descriptor (group segment size and kernarg size appear in both).
.section .rodata,"a",@progbits
.p2align 6, 0x0
.amdhsa_kernel gemm
# basic memory requirements
# 133120 bytes of group segment, no private segment; the 32 byte kernarg size
# equals the argument layout in the metadata (3 x 8 byte buffers + 2 x 4 byte u32).
.amdhsa_group_segment_fixed_size 133120
.amdhsa_private_segment_fixed_size 0
.amdhsa_kernarg_size 32
# register usage (RSRC1)
# NOTE(review): 504 presumably is 248 (the accum offset below) plus 256
# accumulation registers (the kernel body uses a[0:255]) - confirm.
.amdhsa_next_free_vgpr 504
.amdhsa_next_free_sgpr 96
# workgroup / workitem IDs (RSRC2)
# all three workgroup id SGPRs (x, y, z) are enabled
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 1
# user SGPRs, we only specify the kernel args ptr in s[0:1]
# (count 2 = the two SGPRs holding that 64-bit pointer; kernarg preload is disabled)
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_count 2
.amdhsa_user_sgpr_kernarg_preload_length 0
.amdhsa_user_sgpr_kernarg_preload_offset 0
# gfx90a / gfx940 specifics (RSRC3)
# accum offset 248 matches .vgpr_count in the metadata block
.amdhsa_accum_offset 248
.amdhsa_uses_dynamic_stack 0
.amdhsa_tg_split 0
.end_amdhsa_kernel
# Code object metadata (YAML document between the two directives below).
# It describes the single kernel `gemm` (descriptor symbol gemm.kd) and its
# 32 byte argument block: three 8 byte global bf16 buffers at offsets 0, 8, 16
# followed by two by-value u32 scalars (sz at 24, num_wg at 28).
# The group segment size (133120) and kernarg size (32) repeat the values of
# the kernel descriptor and must be changed together with it.
# NOTE(review): the buffer at offset 8 is named B and the one at offset 16 is
# named A here, while the argument-load comment in the kernel body labels the
# 16 byte load from offset 0x8 as "A, B" - confirm which naming is intended.
.amdgpu_metadata
---
amdhsa.kernels:
- .name: gemm
.symbol: gemm.kd
.args:
- .name: C
.address_space: global
.offset: 0
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: B
.address_space: global
.offset: 8
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: A
.address_space: global
.offset: 16
.size: 8
.value_kind: global_buffer
.value_type: bf16
- .name: sz
.offset: 24
.size: 4
.value_kind: by_value
.value_type: u32
- .name: num_wg
.offset: 28
.size: 4
.value_kind: by_value
.value_type: u32
.group_segment_fixed_size: 133120
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.kernarg_segment_size: 32
.max_flat_workgroup_size: 256
.sgpr_count: 88
.sgpr_spill_count: 0
.vgpr_count: 248
.vgpr_spill_count: 0
.wavefront_size: 64
amdhsa.version:
- 1
- 0
...
.end_amdgpu_metadata

View File

@@ -48,7 +48,7 @@ ast = sched[-1].ast
# assembly gemm
@track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret))
def get_asm_prg() -> ProgramSpec:
  """Build the ProgramSpec for the hand-written assembly GEMM kernel.

  The assembly source is produced by reading `template.s` from the directory of this file and
  substituting the text read from `fp` for the template's INSTRUCTIONS placeholder. The result is
  compiled with the default device's compiler and returned as a ProgramSpec named "gemm".
  """
  template = (pathlib.Path(__file__).parent/"template.s").read_text()
  # str.replace substitutes every occurrence of the token, so the template must contain it exactly once
  src = template.replace("INSTRUCTIONS", fp.read_text())
  lib = Device[Device.DEFAULT].compiler.compile(src)
  return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
                     globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])