diff --git a/extra/gemm/asm/gemm.s b/extra/gemm/asm/gemm.s index 52e44b9f01..1f1bb22611 100644 --- a/extra/gemm/asm/gemm.s +++ b/extra/gemm/asm/gemm.s @@ -1,10 +1,3 @@ -.text -.section .text. -.global gemm -.p2align 8 -.type gemm,@function - -gemm: // ** global buffers s_load_dwordx2 s[28:29], s[0:1], 0x0 // C s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B @@ -221,11 +214,11 @@ gemm: s_mul_hi_u32 s87, s86, s40 // 000000003288: 96572856 s_mul_i32 s86, s86, s40 // 00000000328C: 92562856 s_and_b32 s84, s50, 0x8000 // 000000003290: 8654FF32 00008000 - s_cbranch_scc1 label_GSUC_A // 000000003298: BF850003 + s_cbranch_scc1 skip_offset_A // 000000003298: BF850003 s_mul_hi_u32 s85, 64, s6 // 00000000329C: 965506C0 s_mul_i32 s84, 64, s6 // 0000000032A0: 925406C0 -label_GSUC_A: +skip_offset_A: s_add_u32 s86, s86, s84 // 000000003330: 80565456 s_addc_u32 s87, s87, s85 // 000000003334: 82575557 s_mov_b64 s[60:61], 1 // 000000003338: BEBC0181 @@ -259,11 +252,11 @@ label_GSUC_A: s_mul_hi_u32 s87, s86, s42 // 0000000033B4: 96572A56 s_mul_i32 s86, s86, s42 // 0000000033B8: 92562A56 s_and_b32 s84, s50, 0x8000 // 0000000033BC: 8654FF32 00008000 - s_cbranch_scc1 label_GSUC_B // 0000000033C4: BF850003 + s_cbranch_scc1 skip_offset_B // 0000000033C4: BF850003 s_mul_hi_u32 s85, 64, s6 // 0000000033C8: 965506C0 s_mul_i32 s84, 64, s6 // 0000000033CC: 925406C0 -label_GSUC_B: +skip_offset_B: s_add_u32 s86, s86, s84 // 00000000345C: 80565456 s_addc_u32 s87, s87, s85 // 000000003460: 82575557 s_mov_b64 s[62:63], 1 // 000000003464: BEBE0181 @@ -308,8 +301,6 @@ label_GSUC_B: s_and_b32 s87, s10, 0xe000 // 0000000035A4: 8657FF0A 0000E000 s_and_b32 s10, s10, 0xff // 0000000035AC: 860AFF0A 000000FF s_mov_b32 s84, s10 // 0000000035B4: BED4000A - -label_beginStaggerUIter: s_lshl_b32 s85, s84, s86 // 0000000035B8: 8E555654 s_cmp_ge_u32 s13, s85 // 0000000035BC: BF09550D s_sub_u32 s85, s84, 1 // 0000000035CC: 80D58154 @@ -344,7 +335,7 @@ label_beginStaggerUIter: s_cselect_b32 s58, s62, -1 // 0000000036A4: 853AC13E s_add_u32 s51, s51, 2 // 0000000036A8: 80338233 s_cmp_eq_u32 s12, 0 // 0000000036AC: BF06800C - s_cbranch_scc1 label_ShadowInitStart // 0000000036B0: BF850092 + s_cbranch_scc1 init_output_buffers // 0000000036B0: BF850092 s_mov_b32 m0, s46 // 0000000036B4: BEFC002E buffer_load_dwordx4 v0, s[52:55], 0 offen lds // 0000000036B8: E05D1000 800D0000 s_add_u32 m0, m0, 0x1040 // 0000000036C0: 807CFF7C 00001040 @@ -431,7 +422,7 @@ label_beginStaggerUIter: s_cmp_eq_u32 s63, 0 // 0000000038F4: BF06803F s_cselect_b32 s58, s62, -1 // 0000000038F8: 853AC13E -label_ShadowInitStart: +init_output_buffers: s_mov_b64 s[16:17], s[28:29] // 0000000038FC: BE90011C s_mov_b32 s18, 0x80000000 // 000000003900: BE9200FF 80000000 s_mov_b32 s19, 0x20000 // 000000003908: BE9300FF 00020000 @@ -476,12 +467,10 @@ label_ShadowInitStart: s_lshl_b64 s[84:85], s[84:85], 2 // 0000000039C4: 8ED48254 s_add_u32 s16, s16, s84 // 0000000039C8: 80105410 s_addc_u32 s17, s17, s85 // 0000000039CC: 82115511 - -label_NoBranch_T8JHFHKM7BO5OHXW: s_xor_b32 s46, s48, s46 // 0000000039F0: 882E2E30 s_xor_b32 s47, s49, s47 // 0000000039F4: 882F2F31 s_cmp_eq_u32 s12, 1 // 0000000039F8: BF06810C - s_cbranch_scc1 label_skipPGR2 // 0000000039FC: BF850040 + s_cbranch_scc1 after_prefetch // 0000000039FC: BF850040 s_mov_b32 m0, s46 // 000000003A00: BEFC002E buffer_load_dwordx4 v0, s[52:55], 0 offen lds // 000000003A04: E05D1000 800D0000 s_add_u32 m0, m0, 0x1040 // 000000003A0C: 807CFF7C 00001040 @@ -517,7 +506,7 @@ label_NoBranch_T8JHFHKM7BO5OHXW: s_xor_b32 s46, s48, s46 // 000000003AF8: 882E2E30 s_xor_b32 s47, s49, s47 // 000000003AFC: 882F2F31 -label_skipPGR2: +after_prefetch: s_waitcnt vmcnt(24) // 000000003B00: BF8C4F78 s_barrier // 000000003B04: BF8A0000 ds_read_b128 v[4:7], v2 // 000000003B08: D9FE0000 04000002 @@ -539,14 +528,12 @@ label_skipPGR2: ds_read_b128 v[92:95], v3 offset:768 // 000000003B80: D9FE0300 5C000003 ds_read_b128 v[96:99], v3 offset:896 // 000000003B88: D9FE0380 60000003 s_waitcnt lgkmcnt(0) // 000000003B90: BF8CC07F - -label_openLoopL: s_cmp_eq_u32 s12, 1 // 000000003B94: BF06810C - s_cbranch_scc1 label_toPGR1 // 000000003B98: BF8502E5 + s_cbranch_scc1 final_compute // 000000003B98: BF8502E5 s_cmp_le_u32 s12, 2 // 000000003B9C: BF0B820C - s_cbranch_scc1 label_LoopEndL // 000000003BA0: BF85019E + s_cbranch_scc1 loop_epilogue // 000000003BA0: BF85019E -label_LoopBeginL: +main_loop: v_mfma_f32_16x16x32_bf16 a[0:3], v[68:71], v[4:7], a[0:3] // 000000003BA4: D3B58000 04020944 ds_read_b128 v[36:39], v2 offset:64 // 000000003BAC: D9FE0040 24000002 v_mfma_f32_16x16x32_bf16 a[4:7], v[68:71], v[8:11], a[4:7] // 000000003BB4: D3B58004 04121144 @@ -770,9 +757,9 @@ label_LoopBeginL: s_cmp_eq_i32 s12, 2 // 000000004208: BF00820C s_waitcnt lgkmcnt(0) // 00000000420C: BF8CC07F v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004210: D3B580FC 07F28180 - s_cbranch_scc0 label_LoopBeginL // 000000004218: BF84FE62 + s_cbranch_scc0 main_loop // 000000004218: BF84FE62 -label_LoopEndL: +loop_epilogue: v_mfma_f32_16x16x32_bf16 a[0:3], v[68:71], v[4:7], a[0:3] // 00000000421C: D3B58000 04020944 ds_read_b128 v[36:39], v2 offset:64 // 000000004224: D9FE0040 24000002 v_mfma_f32_16x16x32_bf16 a[4:7], v[68:71], v[8:11], a[4:7] // 00000000422C: D3B58004 04121144 @@ -939,7 +926,7 @@ label_LoopEndL: v_mfma_f32_16x16x32_bf16 a[248:251], v[128:131], v[60:63], a[248:251]// 000000004720: D3B580F8 07E27980 v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004728: D3B580FC 07F28180 -label_toPGR1: +final_compute: s_and_b32 s8, s50, 0x3fff // 000000004730: 8608FF32 00003FFF s_and_b32 s84, 0xff, s24 // 000000004750: 865418FF 000000FF s_add_u32 s85, -1, s14 // 000000004758: 80550EC1 @@ -1095,7 +1082,6 @@ label_toPGR1: v_mfma_f32_16x16x32_bf16 a[248:251], v[128:131], v[60:63], a[248:251]// 000000004BFC: D3B580F8 07E27980 v_mfma_f32_16x16x32_bf16 a[252:255], v[128:131], v[64:67], a[252:255]// 000000004C04: D3B580FC 07F28180 -label_toPGR1end_OptNLL: v_lshrrev_b32_e32 v4, 6, v134 // 000000004C0C: 20090C86 v_lshrrev_b32_e32 v5, 1, v4 // 000000004C10: 200A0881 v_mul_lo_u32 v5, 16, v5 // 000000004C14: D2850005 00020A90 @@ -1114,7 +1100,6 @@ label_toPGR1end_OptNLL: s_mul_i32 s8, 0x100, s3 // 000000004C64: 920803FF 00000100 v_add_u32_e32 v1, s8, v1 // 000000004C6C: 68020208 -label_GW_B0_E0: v_add_lshl_u32 v11, v3, v0, 1 // 000000004C70: D1FE000B 02060103 v_accvgpr_read_b32 v16, a0 // 000000004C78: D3D84010 18000100 v_accvgpr_read_b32 v17, a4 // 000000004C80: D3D84011 18000104 @@ -1633,86 +1618,4 @@ label_GW_B0_E0: s_addc_u32 s17, s17, 0 // 000000005B14: 82118011 buffer_store_dwordx4 v[40:43], v11, s[16:19], 0 offen nt // 000000005B18: E07E1000 8004280B s_nop 0 // 000000005B20: BF800000 - -end: s_endpgm // 00000001F5D0: BF810000 - -.section .rodata,"a",@progbits -.p2align 6, 0x0 -.amdhsa_kernel gemm - # ---- basic memory requirements ---- - .amdhsa_group_segment_fixed_size 133120 - .amdhsa_private_segment_fixed_size 0 - .amdhsa_kernarg_size 32 - - # ---- register usage (RSRC1) ---- - .amdhsa_next_free_vgpr 504 - .amdhsa_next_free_sgpr 96 - - # ---- workgroup / workitem IDs (RSRC2) ---- - .amdhsa_system_sgpr_workgroup_id_x 1 - .amdhsa_system_sgpr_workgroup_id_y 1 - .amdhsa_system_sgpr_workgroup_id_z 1 - - # ---- user SGPR enables (descriptor bits >448) ---- - .amdhsa_user_sgpr_kernarg_segment_ptr 1 - .amdhsa_user_sgpr_count 2 - .amdhsa_user_sgpr_kernarg_preload_length 0 - .amdhsa_user_sgpr_kernarg_preload_offset 0 - - # ---- gfx90a / gfx940 specific (RSRC3) ---- - .amdhsa_accum_offset 248 - .amdhsa_uses_dynamic_stack 0 - .amdhsa_tg_split 0 - -.end_amdhsa_kernel - -.amdgpu_metadata ---- -amdhsa.kernels: - - .args: - - .address_space: global - .name: C - .offset: 0 - .size: 8 - .value_kind: global_buffer - .value_type: bf16 - - .address_space: global - .name: B - .offset: 8 - .size: 8 - .value_kind: global_buffer - .value_type: bf16 - - .address_space: global - .name: A - .offset: 16 - .size: 8 - .value_kind: global_buffer - .value_type: bf16 - - .name: sz - .offset: 24 - .size: 4 - .value_kind: by_value - .value_type: u32 - - .name: num_wg - .offset: 28 - .size: 4 - .value_kind: by_value - .value_type: u32 - .group_segment_fixed_size: 133120 - .kernarg_segment_align: 8 - .kernarg_segment_size: 32 - .max_flat_workgroup_size: 256 - .name: gemm - .private_segment_fixed_size: 0 - .sgpr_count: 88 - .sgpr_spill_count: 0 - .symbol: gemm.kd - .vgpr_count: 248 - .vgpr_spill_count: 0 - .wavefront_size: 64 -amdhsa.version: - - 1 - - 0 -... - .end_amdgpu_metadata diff --git a/extra/gemm/asm/template.s b/extra/gemm/asm/template.s new file mode 100644 index 0000000000..3bb1292d9e --- /dev/null +++ b/extra/gemm/asm/template.s @@ -0,0 +1,83 @@ +.text +.section .text. +.global gemm +.p2align 8 +.type gemm,@function + +gemm: +INSTRUCTIONS + +.section .rodata,"a",@progbits +.p2align 6, 0x0 +.amdhsa_kernel gemm + # basic memory requirements + .amdhsa_group_segment_fixed_size 133120 + .amdhsa_private_segment_fixed_size 0 + .amdhsa_kernarg_size 32 + # register usage (RSRC1) + .amdhsa_next_free_vgpr 504 + .amdhsa_next_free_sgpr 96 + # workgroup / workitem IDs (RSRC2) + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + # user SGPRs, we only specify the kernel args ptr in s[0:1] + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_user_sgpr_count 2 + .amdhsa_user_sgpr_kernarg_preload_length 0 + .amdhsa_user_sgpr_kernarg_preload_offset 0 + # gfx90a / gfx940 specifics (RSRC3) + .amdhsa_accum_offset 248 + .amdhsa_uses_dynamic_stack 0 + .amdhsa_tg_split 0 +.end_amdhsa_kernel + +.amdgpu_metadata +--- +amdhsa.kernels: + - .name: gemm + .symbol: gemm.kd + .args: + - .name: C + .address_space: global + .offset: 0 + .size: 8 + .value_kind: global_buffer + .value_type: bf16 + - .name: B + .address_space: global + .offset: 8 + .size: 8 + .value_kind: global_buffer + .value_type: bf16 + - .name: A + .address_space: global + .offset: 16 + .size: 8 + .value_kind: global_buffer + .value_type: bf16 + - .name: sz + .offset: 24 + .size: 4 + .value_kind: by_value + .value_type: u32 + - .name: num_wg + .offset: 28 + .size: 4 + .value_kind: by_value + .value_type: u32 + .group_segment_fixed_size: 133120 + .private_segment_fixed_size: 0 + .kernarg_segment_align: 8 + .kernarg_segment_size: 32 + .max_flat_workgroup_size: 256 + .sgpr_count: 88 + .sgpr_spill_count: 0 + .vgpr_count: 248 + .vgpr_spill_count: 0 + .wavefront_size: 64 +amdhsa.version: + - 1 + - 0 +... +.end_amdgpu_metadata diff --git a/extra/gemm/asm/test.py b/extra/gemm/asm/test.py index 817d379859..3b7dc3196e 100644 --- a/extra/gemm/asm/test.py +++ b/extra/gemm/asm/test.py @@ -48,7 +48,7 @@ ast = sched[-1].ast # assembly gemm @track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret)) def get_asm_prg() -> ProgramSpec: - src = fp.read_text() + src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text()) lib = Device[Device.DEFAULT].compiler.compile(src) return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1], globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])