mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
prep mi350x gemm for python dsl (#13918)
* start by pruning existing asm * better branch names * split to template and real instructions
This commit is contained in:
@@ -1,10 +1,3 @@
|
||||
.text
|
||||
.section .text.
|
||||
.global gemm
|
||||
.p2align 8
|
||||
.type gemm,@function
|
||||
|
||||
gemm:
|
||||
// ** global buffers
|
||||
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
|
||||
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
|
||||
@@ -221,11 +214,11 @@ gemm:
|
||||
s_mul_hi_u32 s87, s86, s40 // 000000003288: 96572856
|
||||
s_mul_i32 s86, s86, s40 // 00000000328C: 92562856
|
||||
s_and_b32 s84, s50, 0x8000 // 000000003290: 8654FF32 00008000
|
||||
s_cbranch_scc1 label_GSUC_A // 000000003298: BF850003
|
||||
s_cbranch_scc1 skip_offset_A // 000000003298: BF850003
|
||||
s_mul_hi_u32 s85, 64, s6 // 00000000329C: 965506C0
|
||||
s_mul_i32 s84, 64, s6 // 0000000032A0: 925406C0
|
||||
|
||||
label_GSUC_A:
|
||||
skip_offset_A:
|
||||
s_add_u32 s86, s86, s84 // 000000003330: 80565456
|
||||
s_addc_u32 s87, s87, s85 // 000000003334: 82575557
|
||||
s_mov_b64 s[60:61], 1 // 000000003338: BEBC0181
|
||||
@@ -259,11 +252,11 @@ label_GSUC_A:
|
||||
s_mul_hi_u32 s87, s86, s42 // 0000000033B4: 96572A56
|
||||
s_mul_i32 s86, s86, s42 // 0000000033B8: 92562A56
|
||||
s_and_b32 s84, s50, 0x8000 // 0000000033BC: 8654FF32 00008000
|
||||
s_cbranch_scc1 label_GSUC_B // 0000000033C4: BF850003
|
||||
s_cbranch_scc1 skip_offset_B // 0000000033C4: BF850003
|
||||
s_mul_hi_u32 s85, 64, s6 // 0000000033C8: 965506C0
|
||||
s_mul_i32 s84, 64, s6 // 0000000033CC: 925406C0
|
||||
|
||||
label_GSUC_B:
|
||||
skip_offset_B:
|
||||
s_add_u32 s86, s86, s84 // 00000000345C: 80565456
|
||||
s_addc_u32 s87, s87, s85 // 000000003460: 82575557
|
||||
s_mov_b64 s[62:63], 1 // 000000003464: BEBE0181
|
||||
@@ -308,8 +301,6 @@ label_GSUC_B:
|
||||
s_and_b32 s87, s10, 0xe000 // 0000000035A4: 8657FF0A 0000E000
|
||||
s_and_b32 s10, s10, 0xff // 0000000035AC: 860AFF0A 000000FF
|
||||
s_mov_b32 s84, s10 // 0000000035B4: BED4000A
|
||||
|
||||
label_beginStaggerUIter:
|
||||
s_lshl_b32 s85, s84, s86 // 0000000035B8: 8E555654
|
||||
s_cmp_ge_u32 s13, s85 // 0000000035BC: BF09550D
|
||||
s_sub_u32 s85, s84, 1 // 0000000035CC: 80D58154
|
||||
@@ -344,7 +335,7 @@ label_beginStaggerUIter:
|
||||
s_cselect_b32 s58, s62, -1 // 0000000036A4: 853AC13E
|
||||
s_add_u32 s51, s51, 2 // 0000000036A8: 80338233
|
||||
s_cmp_eq_u32 s12, 0 // 0000000036AC: BF06800C
|
||||
s_cbranch_scc1 label_ShadowInitStart // 0000000036B0: BF850092
|
||||
s_cbranch_scc1 init_output_buffers // 0000000036B0: BF850092
|
||||
s_mov_b32 m0, s46 // 0000000036B4: BEFC002E
|
||||
buffer_load_dwordx4 v0, s[52:55], 0 offen lds // 0000000036B8: E05D1000 800D0000
|
||||
s_add_u32 m0, m0, 0x1040 // 0000000036C0: 807CFF7C 00001040
|
||||
@@ -431,7 +422,7 @@ label_beginStaggerUIter:
|
||||
s_cmp_eq_u32 s63, 0 // 0000000038F4: BF06803F
|
||||
s_cselect_b32 s58, s62, -1 // 0000000038F8: 853AC13E
|
||||
|
||||
label_ShadowInitStart:
|
||||
init_output_buffers:
|
||||
s_mov_b64 s[16:17], s[28:29] // 0000000038FC: BE90011C
|
||||
s_mov_b32 s18, 0x80000000 // 000000003900: BE9200FF 80000000
|
||||
s_mov_b32 s19, 0x20000 // 000000003908: BE9300FF 00020000
|
||||
@@ -476,12 +467,10 @@ label_ShadowInitStart:
|
||||
s_lshl_b64 s[84:85], s[84:85], 2 // 0000000039C4: 8ED48254
|
||||
s_add_u32 s16, s16, s84 // 0000000039C8: 80105410
|
||||
s_addc_u32 s17, s17, s85 // 0000000039CC: 82115511
|
||||
|
||||
label_NoBranch_T8JHFHKM7BO5OHXW:
|
||||
s_xor_b32 s46, s48, s46 // 0000000039F0: 882E2E30
|
||||
s_xor_b32 s47, s49, s47 // 0000000039F4: 882F2F31
|
||||
s_cmp_eq_u32 s12, 1 // 0000000039F8: BF06810C
|
||||
s_cbranch_scc1 label_skipPGR2 // 0000000039FC: BF850040
|
||||
s_cbranch_scc1 after_prefetch // 0000000039FC: BF850040
|
||||
s_mov_b32 m0, s46 // 000000003A00: BEFC002E
|
||||
buffer_load_dwordx4 v0, s[52:55], 0 offen lds // 000000003A04: E05D1000 800D0000
|
||||
s_add_u32 m0, m0, 0x1040 // 000000003A0C: 807CFF7C 00001040
|
||||
@@ -517,7 +506,7 @@ label_NoBranch_T8JHFHKM7BO5OHXW:
|
||||
s_xor_b32 s46, s48, s46 // 000000003AF8: 882E2E30
|
||||
s_xor_b32 s47, s49, s47 // 000000003AFC: 882F2F31
|
||||
|
||||
label_skipPGR2:
|
||||
after_prefetch:
|
||||
s_waitcnt vmcnt(24) // 000000003B00: BF8C4F78
|
||||
s_barrier // 000000003B04: BF8A0000
|
||||
ds_read_b128 v[4:7], v2 // 000000003B08: D9FE0000 04000002
|
||||
@@ -539,14 +528,12 @@ label_skipPGR2:
|
||||
ds_read_b128 v[92:95], v3 offset:768 // 000000003B80: D9FE0300 5C000003
|
||||
ds_read_b128 v[96:99], v3 offset:896 // 000000003B88: D9FE0380 60000003
|
||||
s_waitcnt lgkmcnt(0) // 000000003B90: BF8CC07F
|
||||
|
||||
label_openLoopL:
|
||||
s_cmp_eq_u32 s12, 1 // 000000003B94: BF06810C
|
||||
s_cbranch_scc1 label_toPGR1 // 000000003B98: BF8502E5
|
||||
s_cbranch_scc1 final_compute // 000000003B98: BF8502E5
|
||||
s_cmp_le_u32 s12, 2 // 000000003B9C: BF0B820C
|
||||
s_cbranch_scc1 label_LoopEndL // 000000003BA0: BF85019E
|
||||
s_cbranch_scc1 loop_epilogue // 000000003BA0: BF85019E
|
||||
|
||||
label_LoopBeginL:
|
||||
main_loop:
|
||||
v_mfma_f32_16x16x32_bf16 a[0:3], v[68:71], v[4:7], a[0:3] // 000000003BA4: D3B58000 04020944
|
||||
ds_read_b128 v[36:39], v2 offset:64 // 000000003BAC: D9FE0040 24000002
|
||||
v_mfma_f32_16x16x32_bf16 a[4:7], v[68:71], v[8:11], a[4:7] // 000000003BB4: D3B58004 04121144
|
||||
@@ -770,9 +757,9 @@ label_LoopBeginL:
|
||||
s_cmp_eq_i32 s12, 2 // 000000004208: BF00820C
|
||||
s_waitcnt lgkmcnt(0) // 00000000420C: BF8CC07F
|
||||
v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004210: D3B580FC 07F28180
|
||||
s_cbranch_scc0 label_LoopBeginL // 000000004218: BF84FE62
|
||||
s_cbranch_scc0 main_loop // 000000004218: BF84FE62
|
||||
|
||||
label_LoopEndL:
|
||||
loop_epilogue:
|
||||
v_mfma_f32_16x16x32_bf16 a[0:3], v[68:71], v[4:7], a[0:3] // 00000000421C: D3B58000 04020944
|
||||
ds_read_b128 v[36:39], v2 offset:64 // 000000004224: D9FE0040 24000002
|
||||
v_mfma_f32_16x16x32_bf16 a[4:7], v[68:71], v[8:11], a[4:7] // 00000000422C: D3B58004 04121144
|
||||
@@ -939,7 +926,7 @@ label_LoopEndL:
|
||||
v_mfma_f32_16x16x32_bf16 a[248:251], v[128:131], v[60:63], a[248:251]// 000000004720: D3B580F8 07E27980
|
||||
v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004728: D3B580FC 07F28180
|
||||
|
||||
label_toPGR1:
|
||||
final_compute:
|
||||
s_and_b32 s8, s50, 0x3fff // 000000004730: 8608FF32 00003FFF
|
||||
s_and_b32 s84, 0xff, s24 // 000000004750: 865418FF 000000FF
|
||||
s_add_u32 s85, -1, s14 // 000000004758: 80550EC1
|
||||
@@ -1095,7 +1082,6 @@ label_toPGR1:
|
||||
v_mfma_f32_16x16x32_bf16 a[248:251], v[128:131], v[60:63], a[248:251]// 000000004BFC: D3B580F8 07E27980
|
||||
v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004C04: D3B580FC 07F28180
|
||||
|
||||
label_toPGR1end_OptNLL:
|
||||
v_lshrrev_b32_e32 v4, 6, v134 // 000000004C0C: 20090C86
|
||||
v_lshrrev_b32_e32 v5, 1, v4 // 000000004C10: 200A0881
|
||||
v_mul_lo_u32 v5, 16, v5 // 000000004C14: D2850005 00020A90
|
||||
@@ -1114,7 +1100,6 @@ label_toPGR1end_OptNLL:
|
||||
s_mul_i32 s8, 0x100, s3 // 000000004C64: 920803FF 00000100
|
||||
v_add_u32_e32 v1, s8, v1 // 000000004C6C: 68020208
|
||||
|
||||
label_GW_B0_E0:
|
||||
v_add_lshl_u32 v11, v3, v0, 1 // 000000004C70: D1FE000B 02060103
|
||||
v_accvgpr_read_b32 v16, a0 // 000000004C78: D3D84010 18000100
|
||||
v_accvgpr_read_b32 v17, a4 // 000000004C80: D3D84011 18000104
|
||||
@@ -1633,86 +1618,4 @@ label_GW_B0_E0:
|
||||
s_addc_u32 s17, s17, 0 // 000000005B14: 82118011
|
||||
buffer_store_dwordx4 v[40:43], v11, s[16:19], 0 offen nt // 000000005B18: E07E1000 8004280B
|
||||
s_nop 0 // 000000005B20: BF800000
|
||||
|
||||
end:
|
||||
s_endpgm // 00000001F5D0: BF810000
|
||||
|
||||
.section .rodata,"a",@progbits
|
||||
.p2align 6, 0x0
|
||||
.amdhsa_kernel gemm
|
||||
# ---- basic memory requirements ----
|
||||
.amdhsa_group_segment_fixed_size 133120
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 32
|
||||
|
||||
# ---- register usage (RSRC1) ----
|
||||
.amdhsa_next_free_vgpr 504
|
||||
.amdhsa_next_free_sgpr 96
|
||||
|
||||
# ---- workgroup / workitem IDs (RSRC2) ----
|
||||
.amdhsa_system_sgpr_workgroup_id_x 1
|
||||
.amdhsa_system_sgpr_workgroup_id_y 1
|
||||
.amdhsa_system_sgpr_workgroup_id_z 1
|
||||
|
||||
# ---- user SGPR enables (descriptor bits >448) ----
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_user_sgpr_count 2
|
||||
.amdhsa_user_sgpr_kernarg_preload_length 0
|
||||
.amdhsa_user_sgpr_kernarg_preload_offset 0
|
||||
|
||||
# ---- gfx90a / gfx940 specific (RSRC3) ----
|
||||
.amdhsa_accum_offset 248
|
||||
.amdhsa_uses_dynamic_stack 0
|
||||
.amdhsa_tg_split 0
|
||||
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.kernels:
|
||||
- .args:
|
||||
- .address_space: global
|
||||
.name: C
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .address_space: global
|
||||
.name: B
|
||||
.offset: 8
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .address_space: global
|
||||
.name: A
|
||||
.offset: 16
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: sz
|
||||
.offset: 24
|
||||
.size: 4
|
||||
.value_kind: by_value
|
||||
.value_type: u32
|
||||
- .name: num_wg
|
||||
.offset: 28
|
||||
.size: 4
|
||||
.value_kind: by_value
|
||||
.value_type: u32
|
||||
.group_segment_fixed_size: 133120
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 32
|
||||
.max_flat_workgroup_size: 256
|
||||
.name: gemm
|
||||
.private_segment_fixed_size: 0
|
||||
.sgpr_count: 88
|
||||
.sgpr_spill_count: 0
|
||||
.symbol: gemm.kd
|
||||
.vgpr_count: 248
|
||||
.vgpr_spill_count: 0
|
||||
.wavefront_size: 64
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
|
||||
83
extra/gemm/asm/template.s
Normal file
83
extra/gemm/asm/template.s
Normal file
@@ -0,0 +1,83 @@
|
||||
.text
|
||||
.section .text.
|
||||
.global gemm
|
||||
.p2align 8
|
||||
.type gemm,@function
|
||||
|
||||
gemm:
|
||||
INSTRUCTIONS
|
||||
|
||||
.section .rodata,"a",@progbits
|
||||
.p2align 6, 0x0
|
||||
.amdhsa_kernel gemm
|
||||
# basic memory requirements
|
||||
.amdhsa_group_segment_fixed_size 133120
|
||||
.amdhsa_private_segment_fixed_size 0
|
||||
.amdhsa_kernarg_size 32
|
||||
# register usage (RSRC1)
|
||||
.amdhsa_next_free_vgpr 504
|
||||
.amdhsa_next_free_sgpr 96
|
||||
# workgroup / workitem IDs (RSRC2)
|
||||
.amdhsa_system_sgpr_workgroup_id_x 1
|
||||
.amdhsa_system_sgpr_workgroup_id_y 1
|
||||
.amdhsa_system_sgpr_workgroup_id_z 1
|
||||
# user SGPRs, we only specify the kernel args ptr in s[0:1]
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_user_sgpr_count 2
|
||||
.amdhsa_user_sgpr_kernarg_preload_length 0
|
||||
.amdhsa_user_sgpr_kernarg_preload_offset 0
|
||||
# gfx90a / gfx940 specifics (RSRC3)
|
||||
.amdhsa_accum_offset 248
|
||||
.amdhsa_uses_dynamic_stack 0
|
||||
.amdhsa_tg_split 0
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.kernels:
|
||||
- .name: gemm
|
||||
.symbol: gemm.kd
|
||||
.args:
|
||||
- .name: C
|
||||
.address_space: global
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: B
|
||||
.address_space: global
|
||||
.offset: 8
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: A
|
||||
.address_space: global
|
||||
.offset: 16
|
||||
.size: 8
|
||||
.value_kind: global_buffer
|
||||
.value_type: bf16
|
||||
- .name: sz
|
||||
.offset: 24
|
||||
.size: 4
|
||||
.value_kind: by_value
|
||||
.value_type: u32
|
||||
- .name: num_wg
|
||||
.offset: 28
|
||||
.size: 4
|
||||
.value_kind: by_value
|
||||
.value_type: u32
|
||||
.group_segment_fixed_size: 133120
|
||||
.private_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 32
|
||||
.max_flat_workgroup_size: 256
|
||||
.sgpr_count: 88
|
||||
.sgpr_spill_count: 0
|
||||
.vgpr_count: 248
|
||||
.vgpr_spill_count: 0
|
||||
.wavefront_size: 64
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
@@ -48,7 +48,7 @@ ast = sched[-1].ast
|
||||
# assembly gemm
|
||||
@track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret))
|
||||
def get_asm_prg() -> ProgramSpec:
|
||||
src = fp.read_text()
|
||||
src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
|
||||
lib = Device[Device.DEFAULT].compiler.compile(src)
|
||||
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
|
||||
globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
|
||||
|
||||
Reference in New Issue
Block a user