From 79c1559f69972e576e1db6fffec40fe403348d31 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 17 Jan 2026 18:40:10 +0900 Subject: [PATCH] amd asm can still be simpler (#14199) * amd asm can still be simpler * simpler * V_LANE_ID * simpler * simpler * compact vgpr --- extra/gemm/amd_asm_matmul.py | 290 +++++++++++------------------------ 1 file changed, 89 insertions(+), 201 deletions(-) diff --git a/extra/gemm/amd_asm_matmul.py b/extra/gemm/amd_asm_matmul.py index e0b6d5dfe3..0338aba80f 100644 --- a/extra/gemm/amd_asm_matmul.py +++ b/extra/gemm/amd_asm_matmul.py @@ -6,7 +6,7 @@ # Workgroup: 128 threads (arranged as 32x4 for coalesced memory access) # Inner loop: 8 iterations per K-block, processing 8 columns of A and 8 rows of B # -# Accumulators: 128 vgprs (v[2-117], v[120-124], v[126-129], v[131-133]) +# Accumulators: 128 vgprs (v[2-129]) import numpy as np from pathlib import Path @@ -27,23 +27,20 @@ LDS_B_STRIDE = 0x200 # LDS stride for B tile (512 bytes) LDS_BASE_OFFSET = 0x1080 # Base LDS offset for tiles ADDR_MASK = 0x3fffff80 # Address alignment mask -# s_waitcnt encodings: wait for memory operations to complete -WAIT_LGKM = 64519 # wait for LDS/GDS/KMEM (lgkm_cnt=0) -WAIT_ALL = 0 # wait for everything -WAIT_VMEM = 1015 # wait for VMEM only (vm_cnt=0, lgkm_cnt=63) - # ============================================================================= -# Named register assignments (VGPRs) - COMPACT LAYOUT +# Named register assignments (VGPRs) # ============================================================================= -V_LANE_ID_MOD8 = 182 # lane_id & 7 (column within 8-wide tile chunk) -V_OUTPUT_ROW = 171 # output row coordinate -V_LANE_MOD8_X4 = 174 # V_LANE_ID_MOD8 << 2 (byte offset) -V_LANE_DIV8_X4 = 175 # (lane_id >> 3) << 2 -V_ADDR_HI_ZERO = 188 # always 0 (for 64-bit address high bits) -V_LDS_A_BASE = 186 # LDS A-tile base address for inner loop -V_LDS_B_BASE = 170 # LDS B-tile base address for inner loop -V_GLOBAL_A_ADDR = 171 # global memory A prefetch address (reuses V_OUTPUT_ROW slot during main loop) -V_GLOBAL_B_ADDR = 178 # global memory B prefetch address +V_LANE_ID = 0 # lane_id set on startup +# Use tile gaps (v146-159) for named regs to minimize max VGPR +V_LANE_ID_MOD8 = 146 # lane_id & 7 +V_LANE_MOD8_X4 = 147 # (lane_id & 7) << 2 +V_LANE_DIV8_X4 = 150 # ((lane_id >> 3) & 3) << 2 +V_LDS_B_BASE = 151 # LDS B-tile base address for inner loop +V_LDS_A_BASE = 154 # LDS A-tile base address for inner loop +V_GLOBAL_A_ADDR = 155 # global memory A prefetch address +V_GLOBAL_B_ADDR = 158 # global memory B prefetch address +V_LDS_A_ADDR = 159 # single base register for A stores +V_LDS_B_ADDR = 162 # single base register for B stores # LDS tile register destinations - SEPARATE from DATA to avoid overlap # A on banks 2-3, B on banks 0-1 to avoid bank conflicts in VOPD @@ -58,14 +55,11 @@ S_TILE_X = 2 # workgroup_x << 7 S_TILE_Y = 3 # workgroup_y << 7 S_DIM_N = 4 # matrix dimension N S_LOOP_BOUND = 7 # K-8 (loop termination bound) -S_A_PTR = (8, 9) # A matrix base pointer -S_B_PTR = (10, 11) # B matrix base pointer S_LOOP_CTR = 12 # loop counter (increments by 8) S_PREFETCH_FLAG = 13 # prefetch condition flag / row stride in epilogue S_WORKGROUP_X = 14 # workgroup_id_x S_WORKGROUP_Y = 15 # workgroup_id_y -# Kernarg load destinations (before copy to working regs) -S_KERNARG_OUT = (16, 17) # output pointer from kernarg +# Kernarg load destinations S_KERNARG_A = (20, 21) # A pointer from kernarg S_KERNARG_B = (22, 23) # B pointer from 
kernarg # Prefetch base pointers (8 pairs each, 16KB/256KB apart) @@ -100,8 +94,6 @@ FMAC_PAIR_ORDER = [ def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None): """Generate 64 dual FMAC ops from accumulator grid with optimized iteration order.""" - if a_tile_regs is None: a_tile_regs = V_A_TILE_REGS - if b_tile_regs is None: b_tile_regs = V_B_TILE_REGS pattern = [] for idx, (a_pair, b_pair) in enumerate(FMAC_PAIR_ORDER): a_even, a_odd = a_pair * 2, a_pair * 2 + 1 @@ -120,14 +112,14 @@ def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None): return pattern # Derived: 64 dual FMAC operations -FMAC_PATTERN = derive_fmac_pattern(ACC_GRID) +FMAC_PATTERN = derive_fmac_pattern(ACC_GRID, V_A_TILE_REGS, V_B_TILE_REGS) def derive_permute_swaps(acc_grid, out_regs): """Derive swap sequence to permute accumulators from FMAC layout to output order. After FMAC loop: acc_grid[a][b] holds C[a,b] Output order: for row_half in 0,1; col_group in 0-3; row_in_group in 0-3; b_off in 0-3 - -> need C[row_half*4 + row_in_group, col_group*4 + b_off] in descending reg order + -> need C[row_half*4 + row_in_group, col_group*4 + b_off] in specified reg order """ def target_ab(i): row_half, col_group = i // 64, (i // 16) % 4 @@ -150,34 +142,27 @@ def derive_permute_swaps(acc_grid, out_regs): return swaps # Derived: swap sequence to arrange accumulators for output -OUT_REGS = list(range(129, 1, -1)) +# Each group of 4 registers is ascending for direct global_store_b128 +OUT_REGS = [r for i in range(32) for r in range(126 - i*4, 130 - i*4)] PERMUTE_SWAPS = derive_permute_swaps(ACC_GRID, OUT_REGS) # ============================================================================= -# LDS tile staging registers - COMPACT LAYOUT +# LDS tile staging registers # ============================================================================= # DATA regs receive contiguous global prefetch, then write to LDS # TILE regs receive scattered LDS loads (ds_load_b64 pairs), then feed FMACs -# These are SEPARATE - DATA lives during prefetch/store, TILE lives during inner loop -V_LDS_A_ADDR = 189 # single base register for A stores (use +512 offsets) -V_LDS_A_DATA = [155, 172, 173, 154, 159, 176, 177, 158] # 8 data registers for A prefetch (mod 4: 3,0,1,2,3,0,1,2) -V_LDS_B_ADDR = 190 # single base register for B stores (use 16-bit offsets) -V_LDS_B_DATA = [163, 180, 181, 162, 167, 184, 185, 166] # 8 data registers for B prefetch (mod 4: 3,0,1,2,3,0,1,2) +# Contiguous layout with mod4=[3,0,1,2,3,0,1,2] for bank conflict avoidance +V_LDS_A_DATA = [163, 164, 165, 166, 167, 168, 169, 170] +V_LDS_B_DATA = [171, 172, 173, 174, 175, 176, 177, 178] + +# Initial tile prefetch: (vdst, saddr_lo) - load into A data regs using B prefetch pointers (s[24:31]) +INIT_PREFETCH = [(V_LDS_A_DATA[i], S_PREFETCH_B+2*i) for i in range(4)] # Global memory prefetch schedule: (vdst1, vdst2, addr_vreg, saddr_lo1, saddr_lo2) # First 2 pairs from B prefetch pointers (s[32:39]), next 4 pairs from A prefetch pointers (s[40:55]) PREFETCH_LOADS = [(V_LDS_A_DATA[4+2*i], V_LDS_A_DATA[4+2*i+1], V_GLOBAL_B_ADDR, S_PREFETCH_B+8+4*i, S_PREFETCH_B+10+4*i) for i in range(2)] + \ [(V_LDS_B_DATA[2*(i-2)], V_LDS_B_DATA[2*(i-2)+1], V_GLOBAL_A_ADDR, S_PREFETCH_A+4*(i-2), S_PREFETCH_A+2+4*(i-2)) for i in range(2, 6)] -# Initial tile prefetch: (vdst, saddr_lo) - load into A data regs using B prefetch pointers (s[24:31]) -INIT_PREFETCH = [(V_LDS_A_DATA[i], S_PREFETCH_B+2*i) for i in range(4)] - -# Initial tile loads: (vdst, addr_lo) pairs - use temp regs in 
accumulator gaps -INIT_TILE_LOADS = [(23,5),(24,9),(25,7),(26,2),(27,11),(28,13),(29,6),(30,8),(31,10),(12,12),(13,14),(3,2),(4,4),(5,8),(6,6),(7,10)] - -# A matrix row offset registers (scattered to avoid accumulator conflicts) -ROW_REGS = [165, 146, 147, 164, 169, 150, 151, 168] # mod 4: 1,2,3,0,1,2,3,0 - # ============================================================================= # Kernel class # ============================================================================= @@ -218,7 +203,7 @@ class Kernel: ('user_sgpr_kernarg_segment_ptr', 1), ('user_sgpr_dispatch_id', 0), ('user_sgpr_private_segment_size', 0), ('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0), ('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0), - ('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 192), + ('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 179), ('next_free_sgpr', 16), ('float_round_mode_32', 0), ('float_round_mode_16_64', 0), ('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), ('dx10_clamp', 1), ('ieee_mode', 1), ('fp16_overflow', 0), ('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0), @@ -236,7 +221,7 @@ class Kernel: f' .group_segment_fixed_size: {lds_size}', ' .kernarg_segment_align: 8', ' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 128', ' .name: kernel', ' .private_segment_fixed_size: 0', ' .sgpr_count: 60', ' .symbol: kernel.kd', - ' .vgpr_count: 192', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}', + ' .vgpr_count: 179', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}', 'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata']) @@ -251,27 +236,20 @@ def build_kernel(arch='gfx1100'): # PROLOGUE: Load kernel arguments, compute tile coordinates and addresses # =========================================================================== k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL)) - k.emit(s_load_b64(sdata=s[S_KERNARG_OUT[0]:S_KERNARG_OUT[1]], sbase=s[0:1], offset=0x10, soffset=NULL)) + k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL)) k.emit(s_mov_b32(s[S_DIM_N], MATRIX_DIM)) k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7)) k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7)) # Lane-derived values - k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[0])) - k.emit(v_lshrrev_b32_e32(v[4], 3, v[0])) - k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[0])) + k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[V_LANE_ID])) + k.emit(v_lshrrev_b32_e32(v[4], 3, v[V_LANE_ID])) + k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[V_LANE_ID])) k.emit(v_or_b32_e32(v[22], s[S_TILE_Y], v[4])) k.emit(v_lshlrev_b32_e32(v[V_LANE_MOD8_X4], 2, v[V_LANE_ID_MOD8])) - k.emit(v_mov_b32_e32(v[2], 0)) # v[1] always positive, sign extension is 0 - k.emit(v_lshlrev_b64(v[5:6], 2, v[1:2])) k.waitcnt(lgkm=0) - # Copy pointers to working registers - k.emit(s_mov_b64(s[S_OUT_PTR[0]:S_OUT_PTR[1]], s[S_KERNARG_OUT[0]:S_KERNARG_OUT[1]])) - k.emit(s_mov_b64(s[S_A_PTR[0]:S_A_PTR[1]], s[S_KERNARG_A[0]:S_KERNARG_A[1]])) - k.emit(s_mov_b64(s[S_B_PTR[0]:S_B_PTR[1]], s[S_KERNARG_B[0]:S_KERNARG_B[1]])) - # Compute 8 A and B matrix tile base pointers for prefetch k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], 
s[S_KERNARG_B[0]:S_KERNARG_B[1]])) # B[0]: no offset for i in range(1, 8): # B: 16KB apart @@ -283,142 +261,59 @@ def build_kernel(arch='gfx1100'): k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0)) # Global prefetch addresses: B = (tile_x + lane_id) * 4, A = ((tile_y << 12) + (lane_id/8)*4K + lane_id%8) * 4 - k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[0])) + k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID])) k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR])) k.emit(s_lshl_b32(s[19], s[S_TILE_Y], 12)) k.emit(v_lshl_add_u32(v[V_GLOBAL_A_ADDR], v[4], 12, v[V_LANE_ID_MOD8])) # (lane_id/8)*4K + lane_id%8 k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR])) k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR])) - # =========================================================================== - # Tile address computation for initial A/B matrix loads - # =========================================================================== - k.emit(s_lshl_b32(s[S_LOOP_BOUND], s[S_DIM_N], 4)) # row stride = 16*N - k.emit(v_mul_lo_u32(v[ROW_REGS[0]], v[22], s[S_DIM_N])) # A matrix row offsets - for i in range(1, 8): k.emit(v_add_nc_u32_e32(v[ROW_REGS[i]], s[S_LOOP_BOUND], v[ROW_REGS[i-1]])) - - def addr64(dst, base_s): # 64-bit address: v[dst:dst+1] = s[base_s:base_s+1] + v[dst]*4 - k.emit(v_mov_b32_e32(v[dst+1], 0)) # offset always positive, sign ext = 0 - k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[dst:dst+1])) - k.emit(v_add_co_u32(v[dst], VCC_LO, s[base_s], v[dst])) - k.emit(v_add_co_ci_u32_e32(v[dst+1], s[base_s+1], v[dst+1])) - - def b_addr(dst, mult, tmp=None): # B address for col + mult*N - tmp = tmp if tmp is not None else dst - k.emit(v_mad_u32_u24(v[tmp], s[S_DIM_N], mult, v[1])) - if tmp != dst: - k.emit(v_mov_b32_e32(v[tmp+1], 0)) # offset always positive - k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[tmp:tmp+1])) - k.emit(v_add_co_u32(v[dst], VCC_LO, s[S_B_PTR[0]], v[dst])) - k.emit(v_add_co_ci_u32_e32(v[dst+1], s[S_B_PTR[1]], v[dst+1])) - else: addr64(dst, S_B_PTR[0]) - - def a_addr(dst, row_reg, tmp): # A address for row_reg + lane_id_mod8 - k.emit(v_add_nc_u32_e32(v[tmp], v[row_reg], v[V_LANE_ID_MOD8])) - k.emit(v_mov_b32_e32(v[tmp+1], 0)) # offset always positive - k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[tmp:tmp+1])) - k.emit(v_add_co_u32(v[dst], VCC_LO, s[S_A_PTR[0]], v[dst])) - k.emit(v_add_co_ci_u32_e32(v[dst+1], s[S_A_PTR[1]], v[dst+1])) - - # Batch 1: B addresses (cols 0-5) and loads - k.emit(v_add_co_u32(v[5], VCC_LO, s[S_B_PTR[0]], v[5])) - k.emit(v_add_co_ci_u32_e32(v[6], s[S_B_PTR[1]], v[6])) - for dst, mult in [(9,1), (7,2), (2,3), (11,4), (13,5)]: b_addr(dst, mult) - k.emit(s_clause(simm16=5)) # 6 consecutive global loads - for vdst, addr in INIT_TILE_LOADS[:6]: k.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1], saddr=NULL)) - - # Batch 2: A addresses (rows 0-4) and loads - for dst, ri in [(6,0), (8,1), (10,2), (12,3), (14,4)]: - k.emit(v_add_nc_u32_e32(v[dst], v[ROW_REGS[ri]], v[V_LANE_ID_MOD8])) - addr64(dst, S_A_PTR[0]) - k.emit(s_clause(simm16=4)) # 5 consecutive global loads - for vdst, addr in INIT_TILE_LOADS[6:11]: k.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1], saddr=NULL)) - - # Batch 3: B cols 6-7, A rows 5-7, and loads - for dst, mult, tmp in [(2,6,15), (4,7,4)]: b_addr(dst, mult, tmp) - for dst, ri, tmp in [(8,5,16), (6,6,18), (10,7,20)]: a_addr(dst, ROW_REGS[ri], tmp) - k.emit(s_clause(simm16=4)) # 5 consecutive global loads - for vdst, addr in 
INIT_TILE_LOADS[11:]: k.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1], saddr=NULL)) + # Do initial loads + for vdst, saddr_lo in INIT_PREFETCH: + k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1])) + for iter in range(6): + vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter] + k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1])) + k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1])) # =========================================================================== # LDS store address computation (bank-conflict-avoiding swizzle) # =========================================================================== # This section computes LDS store addresses with a swizzle pattern to avoid bank conflicts. - # Key outputs: - # v[8]: A-tile initial store base (used only for initial stores with stride64) - # V_LDS_B_ADDR (v145): B-tile store base (used for both initial and main loop) - # V_LANE_DIV8_X4 (v135): (lane_id >> 3) << 2 for epilogue - # # The swizzle ensures that threads in the same wavefront write to different LDS banks. # Formula: swizzled_addr = base + (lane_id & 7) * LDS_A_STRIDE + swizzle_offset # where swizzle_offset depends on (lane_id >> 3) to distribute across banks. - - # v[22] = tile_y | (lane_id >> 3) from prologue, used as base for row offsets - # Compute 7 row offsets for B-tile rows 1-7 (row 0 computed separately in v[9]) - k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base (S_LOOP_CTR=0) - for i in range(7): k.emit(v_or_b32_e32(v[10 + i if i < 2 else 12 + i], 16 * (i + 1), v[22])) # rows 1-7 - - # Extract sign bit of workgroup_x (always 0 for valid workgroups, used for masking) - k.emit(s_bfe_i32(s[S_LOOP_BOUND], s[S_WORKGROUP_X], 0x10018)) + k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base k.emit(v_and_b32_e32(v[9], ADDR_MASK, v[9])) - k.emit(s_lshr_b32(s[S_LOOP_BOUND], s[S_LOOP_BOUND], 25)) - - # Compute masked row offsets for bank conflict avoidance pattern - # Pattern: v[row] = row_val - (row_val & ADDR_MASK) extracts lower bits - k.emit(v_add_nc_u32_e32(v[19], s[S_LOOP_CTR], v[10])) - k.emit(v_add_nc_u32_e32(v[8], s[S_LOOP_BOUND], v[1])) # A-tile base computation - for d, r in zip([20, 21, 32, 33, 34, 35], [11, 14, 15, 16, 17, 18]): - k.emit(v_add_nc_u32_e32(v[d], s[S_LOOP_CTR], v[r])) - k.emit(v_and_b32_e32(v[8], ADDR_MASK, v[8])) k.emit(v_sub_nc_u32_e32(v[9], v[22], v[9])) # row 0 swizzle offset - for d, s_ in zip([19, 20, 21, 22, 32, 33, 34], [20, 21, 22, 32, 33, 34, 35]): - k.emit(v_and_b32_e32(v[d], ADDR_MASK, v[s_])) - k.emit(v_sub_nc_u32_e32(v[8], v[1], v[8])) # A-tile swizzle - - # Apply swizzle offsets and scale to byte offsets - k.emit(v_lshlrev_b32_e32(v[9], 2, v[9])) # row 0 offset * 4 - for r, t in zip([10, 11, 14, 15, 16, 17, 18], [19, 20, 21, 22, 32, 33, 34]): - k.emit(v_sub_nc_u32_e32(v[r], v[r], v[t])) # rows 1-7 swizzle - k.emit(v_bfe_u32(v[2], v[0], 3, 2)) # v[2] = (lane_id >> 3) & 3 - k.emit(v_lshlrev_b32_e32(v[8], 2, v[8])) # A-tile base * 4 - - # Compute B-tile base address: LDS_A_STRIDE * (lane_id % 8) + row0_offset + k.emit(v_lshlrev_b32_e32(v[9], 2, v[9])) # * 4 k.emit(v_mad_u32_u24(v[V_LDS_B_ADDR], LDS_A_STRIDE, v[V_LANE_ID_MOD8], v[9])) - # Scale row offsets 1-7 to byte offsets (row 0 already in v[9]) - for d, r in zip([9, 10, 11, 14, 15, 16, 17], [10, 11, 14, 15, 16, 17, 18]): - k.emit(v_lshlrev_b32_e32(v[d], 2, v[r])) + + # For V_LDS_A_BASE and epilogue + k.emit(v_bfe_u32(v[2], v[V_LANE_ID], 3, 2)) # v[2] = (lane_id >> 3) & 
3 k.emit(v_lshlrev_b32_e32(v[V_LANE_DIV8_X4], 2, v[2])) - k.emit(v_add_nc_u32_e32(v[8], 0x80, v[8])) # A-tile initial store base + 128 - # Store initial tile data to LDS - k.waitcnt(vm=0) - for i, (d0, d1) in enumerate([(0,1), (2,3), (4,5), (11,12)]): - k.emit(ds_store_2addr_stride64_b32(addr=v[8], data0=v[INIT_TILE_LOADS[d0][0]], data1=v[INIT_TILE_LOADS[d1][0]], offset0=16+i*4, offset1=18+i*4)) - # B stores: single base with offsets 0,64,128,192,256,320,384,448 - for i, idx in enumerate([6,7,8,9,10,13,14,15]): - offset = i * 64 - k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[INIT_TILE_LOADS[idx][0]], offset0=offset & 0xFF, offset1=offset >> 8)) - - # =========================================================================== - # INIT: Compute LDS base addresses, then zero accumulators - # =========================================================================== - # v[3] = v[1] & 0x7F (lower 7 bits) since S_LOOP_BOUND=0 for valid workgroups + # Compute LDS load/store base addresses for inner loop k.emit(v_lshlrev_b32_e32(v[2], 4, v[2])) - k.emit(v_add_nc_u32_e32(v[3], s[S_LOOP_BOUND], v[1])) - k.emit(v_and_b32_e32(v[3], ADDR_MASK, v[3])) - k.emit(v_sub_nc_u32_e32(v[3], v[1], v[3])) + k.emit(v_and_b32_e32(v[3], 0x7F, v[1])) # simplified from 3 lines k.emit(v_lshl_or_b32(v[V_LDS_B_BASE], v[V_LANE_ID_MOD8], 4, LDS_BASE_OFFSET)) k.emit(v_lshl_add_u32(v[V_LDS_A_ADDR], v[3], 2, LDS_BASE_OFFSET)) - k.emit(v_lshlrev_b32_e32(v[3], 2, v[0])) + k.emit(v_lshlrev_b32_e32(v[3], 2, v[V_LANE_ID])) k.emit(v_and_or_b32(v[V_LDS_A_BASE], 0x180, v[3], v[2])) + # Do initial stores + k.waitcnt(vm=0) + for i in range(4): # A tile: 8 values via 4 stride64 stores + k.emit(ds_store_2addr_stride64_b32(addr=v[V_LDS_A_ADDR], data0=v[V_LDS_A_DATA[i*2]], data1=v[V_LDS_A_DATA[i*2+1]], offset0=i*4, offset1=i*4+2)) + for i in range(8): # B tile: 8 values via 8 scalar stores with 64-byte spacing + offset = i * 64 + k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8)) + # Zero all 128 accumulators using VOPD dual moves (64 instructions instead of 128) for i in range(0, len(OUT_REGS), 2): k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[OUT_REGS[i]], vdsty=v[OUT_REGS[i+1]], srcx0=0, srcy0=0)) - k.emit(s_add_i32(s[S_LOOP_BOUND], s[S_DIM_N], -8)) - k.emit(s_add_u32(s[S_A_PTR[0]], s[S_A_PTR[0]], 32)) - k.emit(s_addc_u32(s[S_A_PTR[1]], s[S_A_PTR[1]], 0)) + # S_LOOP_CTR is already 0 from prologue initialization k.emit(s_branch(simm16=0)); k.branch_to('LOOP_ENTRY') @@ -443,7 +338,7 @@ def build_kernel(arch='gfx1100'): #k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR])) #k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR])) - # Advance prefetch pointers (SGPRs, 64-bit adds) + # Advance prefetch pointers (64-bit adds) k.emit(s_clause(simm16=31)) for i in range(8): k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], 0x20000)) @@ -490,7 +385,7 @@ def build_kernel(arch='gfx1100'): k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32, vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by])) - # wait for all global stores to finish + # wait for all global loads to finish # then sync the warp so it's safe to store local k.waitcnt(vm=0) k.emit(s_barrier()) @@ -518,54 +413,47 @@ def build_kernel(arch='gfx1100'): for a, b in PERMUTE_SWAPS: k.emit(v_swap_b32_e32(v[a], v[b])) - # Compute output coordinates: v[V_LANE_ID_MOD8] = col, v[V_OUTPUT_ROW] = row - 
k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, - vdstx=v[149], vdsty=v[150], srcx0=v[V_LANE_MOD8_X4], vsrcx1=v[0], srcy0=v[V_LANE_DIV8_X4], vsrcy1=v[0])) - k.emit(v_and_b32_e32(v[0], 0x60, v[0])) - k.emit(v_or_b32_e32(v[V_LANE_ID_MOD8], s[S_TILE_X], v[149])) - k.emit(v_add_nc_u32_e32(v[0], s[S_TILE_Y], v[0])) - k.emit(v_or_b32_e32(v[V_OUTPUT_ROW], v[0], v[150])) + # Compute output base coordinates + # v[130] = col_base = tile_x + (lane_id & 7) * 4 + # v[131] = row_base = tile_y + (lane_id & 0x60) + ((lane_id >> 3) & 3) * 4 + # v[132] = 0 (for 64-bit address high part) + k.emit(v_add_nc_u32_e32(v[130], s[S_TILE_X], v[V_LANE_MOD8_X4])) + k.emit(v_and_b32_e32(v[131], 0x60, v[V_LANE_ID])) + k.emit(v_add_nc_u32_e32(v[131], s[S_TILE_Y], v[131])) + k.emit(v_add_nc_u32_e32(v[131], v[V_LANE_DIV8_X4], v[131])) + k.emit(v_mov_b32_e32(v[132], 0)) - # Precompute row offsets: v[144-147] for rows 0-3, v[148-151] for rows 16-19 - for base, row_off in [(144, 0), (148, 16)]: - if row_off: k.emit(v_or_b32_e32(v[1], row_off, v[V_OUTPUT_ROW])) - k.emit(v_mul_lo_u32(v[base], v[1] if row_off else v[V_OUTPUT_ROW], s[S_DIM_N])) - for i in range(3): k.emit(v_add_nc_u32_e32(v[base + 1 + i], s[S_DIM_N], v[base + i])) + # Precompute row offsets: v[133-136] for rows 0-3, v[137-140] for rows 16-19 + for base, row_off in [(133, 0), (137, 16)]: + if row_off: k.emit(v_add_nc_u32_e32(v[141], row_off, v[131])) + k.emit(v_mul_lo_u32(v[base], v[141] if row_off else v[131], s[S_DIM_N])) + for j in range(3): k.emit(v_add_nc_u32_e32(v[base + 1 + j], s[S_DIM_N], v[base + j])) - k.emit(v_mov_b32_e32(v[V_ADDR_HI_ZERO], 0)) - k.emit(s_lshl_b32(s[S_PREFETCH_FLAG], s[S_DIM_N], 2)) # row stride in bytes + # s[S_PREFETCH_FLAG] = row stride in bytes (N * 4) + k.emit(s_lshl_b32(s[S_PREFETCH_FLAG], s[S_DIM_N], 2)) # Store 128 output values as 32 groups of 4 (128-bit stores) # Layout: 2 row halves (0-3, 16-19) x 4 col groups x 4 rows = 32 stores of 4 floats - epilogue_reserved = {V_LANE_ID_MOD8, V_OUTPUT_ROW, V_LANE_MOD8_X4, V_LANE_DIV8_X4, V_ADDR_HI_ZERO} - for i, (row_half, col_off, row_in_group) in enumerate([(rh, co, ri) for rh in range(2) for co in [0, 32, 64, 96] for ri in range(4)]): row = row_half * 16 + row_in_group - srcs = OUT_REGS[i*4:(i+1)*4] + src = OUT_REGS[i*4] # first reg of ascending group of 4 - # Find temp register for scaled values (must not conflict with reserved regs) - tmp = max(srcs) + 5 - while any(r in epilogue_reserved for r in range(tmp, tmp + 4)): tmp += 1 + if row_in_group == 0: + # First row of group: compute full address + if col_off == 0: k.emit(v_mov_b32_e32(v[141], v[130])) + else: k.emit(v_add_nc_u32_e32(v[141], col_off, v[130])) + row_base = 133 + row if row < 4 else 137 + row - 16 + k.emit(v_add_nc_u32_e32(v[141], v[row_base], v[141])) + k.emit(v_lshlrev_b32_e32(v[141], 2, v[141])) + k.emit(v_add_co_u32(v[141], VCC_LO, s[S_OUT_PTR[0]], v[141])) + k.emit(v_add_co_ci_u32_e32(v[142], s[S_OUT_PTR[1]], v[132])) + else: + # Subsequent rows: add stride + k.emit(v_add_co_u32(v[141], VCC_LO, s[S_PREFETCH_FLAG], v[141])) + k.emit(v_add_co_ci_u32_e32(v[142], v[142], v[132])) - # Copy values to temp regs for output (alpha=1.0 hardcoded, so just move) - for j, src in enumerate(srcs): - k.emit(v_mov_b32_e32(v[tmp + j], v[src])) - - # Compute output address - if row_in_group == 0: # first row: compute base address for this column group - if col_off == 0: k.emit(v_mov_b32_e32(v[0], v[V_LANE_ID_MOD8])) - else: k.emit(v_add_nc_u32_e32(v[0], col_off, v[V_LANE_ID_MOD8])) - row_base = 144 + row if row < 4 else 
148 + row - 16 - k.emit(v_add_nc_u32_e32(v[0], v[row_base], v[0])) - k.emit(v_lshlrev_b32_e32(v[0], 2, v[0])) - k.emit(v_add_co_u32(v[0], VCC_LO, s[S_OUT_PTR[0]], v[0])) - k.emit(v_add_co_ci_u32_e32(v[1], s[S_OUT_PTR[1]], v[V_ADDR_HI_ZERO])) - else: # subsequent rows: just add stride - k.emit(v_add_co_u32(v[0], VCC_LO, s[S_PREFETCH_FLAG], v[0])) - k.emit(v_add_co_ci_u32_e32(v[1], v[1], v[V_ADDR_HI_ZERO])) - - k.emit(global_store_b128(addr=v[0:1], data=v[tmp:tmp+3], saddr=NULL)) + k.emit(global_store_b128(addr=v[141:142], data=v[src:src+3], saddr=NULL)) k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS k.emit(s_endpgm())
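
For reference, and not part of the patch itself: a minimal standalone Python sketch that checks the compact register layout the new comments describe, namely the 128 accumulators spanning v[2-129] in ascending groups of 4 (so each group can feed a global_store_b128 directly) and the mod-4 = [3,0,1,2,3,0,1,2] bank pattern of the contiguous LDS staging registers. All constants below are copied from the patch; nothing else is assumed.

# Sketch only: re-state OUT_REGS and the LDS staging registers exactly as the
# patch defines them, then assert the properties the comments claim.
OUT_REGS = [r for i in range(32) for r in range(126 - i*4, 130 - i*4)]
assert len(OUT_REGS) == 128 and sorted(OUT_REGS) == list(range(2, 130))  # v[2-129]
for g in range(0, 128, 4):
    grp = OUT_REGS[g:g+4]
    assert grp == list(range(grp[0], grp[0] + 4))  # ascending group of 4 per b128 store

V_LDS_A_DATA = [163, 164, 165, 166, 167, 168, 169, 170]
V_LDS_B_DATA = [171, 172, 173, 174, 175, 176, 177, 178]
assert [r % 4 for r in V_LDS_A_DATA] == [3, 0, 1, 2, 3, 0, 1, 2]
assert [r % 4 for r in V_LDS_B_DATA] == [3, 0, 1, 2, 3, 0, 1, 2]
assert max(V_LDS_B_DATA) + 1 == 179  # matches next_free_vgpr / .vgpr_count
print("register layout checks passed")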
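
A second sketch, also not part of the patch, for the epilogue's output mapping. It assumes v0 (V_LANE_ID) holds the flat thread index 0-127 of the 128-thread workgroup and takes tile_x = tile_y = 0: each thread's column base is (lane & 7) * 4, its row base is (lane & 0x60) + ((lane >> 3) & 3) * 4, and it issues 32 b128 stores over two row halves (rows +0-3 and +16-19) and four column groups (+0, +32, +64, +96). The check confirms that across all 128 threads this covers one 128x128 output block exactly once.

# Sketch only: enumerate the (row, col) cells each thread's 32 global_store_b128
# instructions cover inside one 128x128 output tile (tile_x = tile_y = 0 assumed).
covered = set()
for lane in range(128):                                # assumed flat thread index in v0
    col_base = (lane & 7) * 4                          # V_LANE_MOD8_X4 contribution
    row_base = (lane & 0x60) + ((lane >> 3) & 3) * 4   # V_LANE_DIV8_X4 contribution
    for row_half in (0, 1):
        for col_off in (0, 32, 64, 96):
            for row_in_group in range(4):              # later rows just add the row stride
                row = row_base + row_half * 16 + row_in_group
                col = col_base + col_off
                for j in range(4):                     # one b128 store writes 4 floats
                    cell = (row, col + j)
                    assert cell not in covered, "cell written twice"
                    covered.add(cell)
assert len(covered) == 128 * 128
print("output tiling check passed")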