amd asm can still be simpler (#14199)
* amd asm can still be simpler
* simpler
* V_LANE_ID
* simpler
* simpler
* compact vgpr
@@ -6,7 +6,7 @@
# Workgroup: 128 threads (arranged as 32x4 for coalesced memory access)
# Inner loop: 8 iterations per K-block, processing 8 columns of A and 8 rows of B
#
# Accumulators: 128 vgprs (v[2-117], v[120-124], v[126-129], v[131-133])
# Accumulators: 128 vgprs (v[2-129])

import numpy as np
from pathlib import Path
@@ -27,23 +27,20 @@ LDS_B_STRIDE = 0x200 # LDS stride for B tile (512 bytes)
LDS_BASE_OFFSET = 0x1080 # Base LDS offset for tiles
ADDR_MASK = 0x3fffff80 # Address alignment mask

# s_waitcnt encodings: wait for memory operations to complete
WAIT_LGKM = 64519 # wait for LDS/GDS/KMEM (lgkm_cnt=0)
WAIT_ALL = 0 # wait for everything
WAIT_VMEM = 1015 # wait for VMEM only (vm_cnt=0, lgkm_cnt=63)
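
# Illustration (assumption, not from this file): on gfx11 the s_waitcnt immediate
# appears to pack expcnt in bits [2:0], lgkmcnt in bits [9:4] and vmcnt in bits
# [15:10]; pack_waitcnt below is a hypothetical helper that reproduces the three
# constants above under that layout.
def pack_waitcnt(vm=63, lgkm=63, exp=7):
    return (vm << 10) | (lgkm << 4) | exp
assert pack_waitcnt(lgkm=0) == WAIT_LGKM              # 64519
assert pack_waitcnt(vm=0) == WAIT_VMEM                # 1015
assert pack_waitcnt(vm=0, lgkm=0, exp=0) == WAIT_ALL  # 0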

# =============================================================================
# Named register assignments (VGPRs) - COMPACT LAYOUT
# Named register assignments (VGPRs)
# =============================================================================
V_LANE_ID_MOD8 = 182 # lane_id & 7 (column within 8-wide tile chunk)
V_OUTPUT_ROW = 171 # output row coordinate
V_LANE_MOD8_X4 = 174 # V_LANE_ID_MOD8 << 2 (byte offset)
V_LANE_DIV8_X4 = 175 # (lane_id >> 3) << 2
V_ADDR_HI_ZERO = 188 # always 0 (for 64-bit address high bits)
V_LDS_A_BASE = 186 # LDS A-tile base address for inner loop
V_LDS_B_BASE = 170 # LDS B-tile base address for inner loop
V_GLOBAL_A_ADDR = 171 # global memory A prefetch address (reuses V_OUTPUT_ROW slot during main loop)
V_GLOBAL_B_ADDR = 178 # global memory B prefetch address
V_LANE_ID = 0 # lane_id set on startup
# Use tile gaps (v146-159) for named regs to minimize max VGPR
V_LANE_ID_MOD8 = 146 # lane_id & 7
V_LANE_MOD8_X4 = 147 # (lane_id & 7) << 2
V_LANE_DIV8_X4 = 150 # ((lane_id >> 3) & 3) << 2
V_LDS_B_BASE = 151 # LDS B-tile base address for inner loop
V_LDS_A_BASE = 154 # LDS A-tile base address for inner loop
V_GLOBAL_A_ADDR = 155 # global memory A prefetch address
V_GLOBAL_B_ADDR = 158 # global memory B prefetch address
V_LDS_A_ADDR = 159 # single base register for A stores
V_LDS_B_ADDR = 162 # single base register for B stores

# LDS tile register destinations - SEPARATE from DATA to avoid overlap
# A on banks 2-3, B on banks 0-1 to avoid bank conflicts in VOPD
@@ -58,14 +55,11 @@ S_TILE_X = 2 # workgroup_x << 7
S_TILE_Y = 3 # workgroup_y << 7
S_DIM_N = 4 # matrix dimension N
S_LOOP_BOUND = 7 # K-8 (loop termination bound)
S_A_PTR = (8, 9) # A matrix base pointer
S_B_PTR = (10, 11) # B matrix base pointer
S_LOOP_CTR = 12 # loop counter (increments by 8)
S_PREFETCH_FLAG = 13 # prefetch condition flag / row stride in epilogue
S_WORKGROUP_X = 14 # workgroup_id_x
S_WORKGROUP_Y = 15 # workgroup_id_y
# Kernarg load destinations (before copy to working regs)
S_KERNARG_OUT = (16, 17) # output pointer from kernarg
# Kernarg load destinations
S_KERNARG_A = (20, 21) # A pointer from kernarg
S_KERNARG_B = (22, 23) # B pointer from kernarg
# Prefetch base pointers (8 pairs each, 16KB/256KB apart)
@@ -100,8 +94,6 @@ FMAC_PAIR_ORDER = [

def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
    """Generate 64 dual FMAC ops from accumulator grid with optimized iteration order."""
    if a_tile_regs is None: a_tile_regs = V_A_TILE_REGS
    if b_tile_regs is None: b_tile_regs = V_B_TILE_REGS
    pattern = []
    for idx, (a_pair, b_pair) in enumerate(FMAC_PAIR_ORDER):
        a_even, a_odd = a_pair * 2, a_pair * 2 + 1
@@ -120,14 +112,14 @@ def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
    return pattern

# Derived: 64 dual FMAC operations
FMAC_PATTERN = derive_fmac_pattern(ACC_GRID)
FMAC_PATTERN = derive_fmac_pattern(ACC_GRID, V_A_TILE_REGS, V_B_TILE_REGS)
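
# Each FMAC_PATTERN entry drives one VOPD instruction that issues two independent
# FMAs per cycle (see the main loop below). Reference semantics in plain Python,
# assuming an entry unpacks as (vdst_x, vdst_y, ax, bx, ay, by); dual_fmac is an
# illustrative helper, not used by the kernel builder.
def dual_fmac(vgpr, entry):
    vdst_x, vdst_y, ax, bx, ay, by = entry
    vgpr[vdst_x] += vgpr[ax] * vgpr[bx]  # X half: acc_x += a_x * b_x
    vgpr[vdst_y] += vgpr[ay] * vgpr[by]  # Y half: acc_y += a_y * b_y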

def derive_permute_swaps(acc_grid, out_regs):
    """Derive swap sequence to permute accumulators from FMAC layout to output order.

    After FMAC loop: acc_grid[a][b] holds C[a,b]
    Output order: for row_half in 0,1; col_group in 0-3; row_in_group in 0-3; b_off in 0-3
    -> need C[row_half*4 + row_in_group, col_group*4 + b_off] in descending reg order
    -> need C[row_half*4 + row_in_group, col_group*4 + b_off] in specified reg order
    """
    def target_ab(i):
        row_half, col_group = i // 64, (i // 16) % 4
@@ -150,34 +142,27 @@ def derive_permute_swaps(acc_grid, out_regs):
    return swaps
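
# Generic sketch of the swap-derivation idea (the exact body is elided in this
# hunk): walk the target permutation and emit pairwise swaps until every slot
# holds the value the output order wants. swaps_from_permutation is illustrative
# only and not part of this file.
def swaps_from_permutation(perm):
    cur = list(range(len(perm)))             # value currently held in each slot
    pos = {r: i for i, r in enumerate(cur)}  # value -> slot
    swaps = []
    for i, want in enumerate(perm):
        j = pos[want]
        if j != i:
            swaps.append((cur[i], cur[j]))   # one v_swap_b32 between these regs
            pos[cur[i]], pos[cur[j]] = j, i
            cur[i], cur[j] = cur[j], cur[i]
    return swaps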

# Derived: swap sequence to arrange accumulators for output
OUT_REGS = list(range(129, 1, -1))
# Each group of 4 registers is ascending for direct global_store_b128
OUT_REGS = [r for i in range(32) for r in range(126 - i*4, 130 - i*4)]
PERMUTE_SWAPS = derive_permute_swaps(ACC_GRID, OUT_REGS)
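
# For reference, the new OUT_REGS expands to 32 ascending quads walking down from
# v129: [126, 127, 128, 129, 122, 123, 124, 125, ..., 2, 3, 4, 5], so the epilogue
# can pass v[src:src+3] straight to global_store_b128.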

# =============================================================================
# LDS tile staging registers - COMPACT LAYOUT
# LDS tile staging registers
# =============================================================================
# DATA regs receive contiguous global prefetch, then write to LDS
# TILE regs receive scattered LDS loads (ds_load_b64 pairs), then feed FMACs
# These are SEPARATE - DATA lives during prefetch/store, TILE lives during inner loop
V_LDS_A_ADDR = 189 # single base register for A stores (use +512 offsets)
V_LDS_A_DATA = [155, 172, 173, 154, 159, 176, 177, 158] # 8 data registers for A prefetch (mod 4: 3,0,1,2,3,0,1,2)
V_LDS_B_ADDR = 190 # single base register for B stores (use 16-bit offsets)
V_LDS_B_DATA = [163, 180, 181, 162, 167, 184, 185, 166] # 8 data registers for B prefetch (mod 4: 3,0,1,2,3,0,1,2)
# Contiguous layout with mod4=[3,0,1,2,3,0,1,2] for bank conflict avoidance
V_LDS_A_DATA = [163, 164, 165, 166, 167, 168, 169, 170]
V_LDS_B_DATA = [171, 172, 173, 174, 175, 176, 177, 178]

# Initial tile prefetch: (vdst, saddr_lo) - load into A data regs using B prefetch pointers (s[24:31])
INIT_PREFETCH = [(V_LDS_A_DATA[i], S_PREFETCH_B+2*i) for i in range(4)]

# Global memory prefetch schedule: (vdst1, vdst2, addr_vreg, saddr_lo1, saddr_lo2)
# First 2 pairs from B prefetch pointers (s[32:39]), next 4 pairs from A prefetch pointers (s[40:55])
PREFETCH_LOADS = [(V_LDS_A_DATA[4+2*i], V_LDS_A_DATA[4+2*i+1], V_GLOBAL_B_ADDR, S_PREFETCH_B+8+4*i, S_PREFETCH_B+10+4*i) for i in range(2)] + \
    [(V_LDS_B_DATA[2*(i-2)], V_LDS_B_DATA[2*(i-2)+1], V_GLOBAL_A_ADDR, S_PREFETCH_A+4*(i-2), S_PREFETCH_A+2+4*(i-2)) for i in range(2, 6)]
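
# Expanded for readability (assumption: S_PREFETCH_B = 24 and S_PREFETCH_A = 40,
# consistent with the s[24:31] / s[32:39] / s[40:55] comments nearby), with each
# entry being (vdst1, vdst2, addr_vreg, saddr_lo1, saddr_lo2):
#   PREFETCH_LOADS == [(167, 168, 158, 32, 34), (169, 170, 158, 36, 38),
#                      (171, 172, 155, 40, 42), (173, 174, 155, 44, 46),
#                      (175, 176, 155, 48, 50), (177, 178, 155, 52, 54)]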

# Initial tile prefetch: (vdst, saddr_lo) - load into A data regs using B prefetch pointers (s[24:31])
INIT_PREFETCH = [(V_LDS_A_DATA[i], S_PREFETCH_B+2*i) for i in range(4)]

# Initial tile loads: (vdst, addr_lo) pairs - use temp regs in accumulator gaps
INIT_TILE_LOADS = [(23,5),(24,9),(25,7),(26,2),(27,11),(28,13),(29,6),(30,8),(31,10),(12,12),(13,14),(3,2),(4,4),(5,8),(6,6),(7,10)]

# A matrix row offset registers (scattered to avoid accumulator conflicts)
ROW_REGS = [165, 146, 147, 164, 169, 150, 151, 168] # mod 4: 1,2,3,0,1,2,3,0

# =============================================================================
# Kernel class
# =============================================================================
@@ -218,7 +203,7 @@ class Kernel:
('user_sgpr_kernarg_segment_ptr', 1), ('user_sgpr_dispatch_id', 0), ('user_sgpr_private_segment_size', 0),
('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 192),
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 179),
('next_free_sgpr', 16), ('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), ('dx10_clamp', 1), ('ieee_mode', 1),
('fp16_overflow', 0), ('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0),
@@ -236,7 +221,7 @@ class Kernel:
f' .group_segment_fixed_size: {lds_size}', ' .kernarg_segment_align: 8',
' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 128', ' .name: kernel',
' .private_segment_fixed_size: 0', ' .sgpr_count: 60', ' .symbol: kernel.kd',
' .vgpr_count: 192', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
' .vgpr_count: 179', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])


@@ -251,27 +236,20 @@ def build_kernel(arch='gfx1100'):
# PROLOGUE: Load kernel arguments, compute tile coordinates and addresses
# ===========================================================================
k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL))
k.emit(s_load_b64(sdata=s[S_KERNARG_OUT[0]:S_KERNARG_OUT[1]], sbase=s[0:1], offset=0x10, soffset=NULL))
k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL))
k.emit(s_mov_b32(s[S_DIM_N], MATRIX_DIM))
k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups
k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7))
k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7))

# Lane-derived values
k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[0]))
k.emit(v_lshrrev_b32_e32(v[4], 3, v[0]))
k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[0]))
k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[V_LANE_ID]))
k.emit(v_lshrrev_b32_e32(v[4], 3, v[V_LANE_ID]))
k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[V_LANE_ID]))
k.emit(v_or_b32_e32(v[22], s[S_TILE_Y], v[4]))
k.emit(v_lshlrev_b32_e32(v[V_LANE_MOD8_X4], 2, v[V_LANE_ID_MOD8]))
k.emit(v_mov_b32_e32(v[2], 0)) # v[1] always positive, sign extension is 0
k.emit(v_lshlrev_b64(v[5:6], 2, v[1:2]))
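
# Per-lane coordinates set up above, as plain Python (illustration only; v[0] is
# the workitem id, called lane_id in the comments). The ORs act as adds because
# tile_x and tile_y are multiples of 128 while the lane-derived terms stay below 128.
def lane_coords(lane_id, tile_x, tile_y):
    col = tile_x | lane_id         # v[1]: global column used for B loads
    row = tile_y | (lane_id >> 3)  # v[22]: global row used for A loads
    return col, row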
k.waitcnt(lgkm=0)

# Copy pointers to working registers
k.emit(s_mov_b64(s[S_OUT_PTR[0]:S_OUT_PTR[1]], s[S_KERNARG_OUT[0]:S_KERNARG_OUT[1]]))
k.emit(s_mov_b64(s[S_A_PTR[0]:S_A_PTR[1]], s[S_KERNARG_A[0]:S_KERNARG_A[1]]))
k.emit(s_mov_b64(s[S_B_PTR[0]:S_B_PTR[1]], s[S_KERNARG_B[0]:S_KERNARG_B[1]]))

# Compute 8 A and B matrix tile base pointers for prefetch
k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], s[S_KERNARG_B[0]:S_KERNARG_B[1]])) # B[0]: no offset
for i in range(1, 8): # B: 16KB apart
@@ -283,142 +261,59 @@ def build_kernel(arch='gfx1100'):
    k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0))

# Global prefetch addresses: B = (tile_x + lane_id) * 4, A = ((tile_y << 12) + (lane_id/8)*4K + lane_id%8) * 4
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[0]))
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID]))
k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR]))
k.emit(s_lshl_b32(s[19], s[S_TILE_Y], 12))
k.emit(v_lshl_add_u32(v[V_GLOBAL_A_ADDR], v[4], 12, v[V_LANE_ID_MOD8])) # (lane_id/8)*4K + lane_id%8
k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR]))
k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR]))

# ===========================================================================
# Tile address computation for initial A/B matrix loads
# ===========================================================================
k.emit(s_lshl_b32(s[S_LOOP_BOUND], s[S_DIM_N], 4)) # row stride = 16*N
k.emit(v_mul_lo_u32(v[ROW_REGS[0]], v[22], s[S_DIM_N])) # A matrix row offsets
for i in range(1, 8): k.emit(v_add_nc_u32_e32(v[ROW_REGS[i]], s[S_LOOP_BOUND], v[ROW_REGS[i-1]]))

def addr64(dst, base_s): # 64-bit address: v[dst:dst+1] = s[base_s:base_s+1] + v[dst]*4
    k.emit(v_mov_b32_e32(v[dst+1], 0)) # offset always positive, sign ext = 0
    k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[dst:dst+1]))
    k.emit(v_add_co_u32(v[dst], VCC_LO, s[base_s], v[dst]))
    k.emit(v_add_co_ci_u32_e32(v[dst+1], s[base_s+1], v[dst+1]))

def b_addr(dst, mult, tmp=None): # B address for col + mult*N
    tmp = tmp if tmp is not None else dst
    k.emit(v_mad_u32_u24(v[tmp], s[S_DIM_N], mult, v[1]))
    if tmp != dst:
        k.emit(v_mov_b32_e32(v[tmp+1], 0)) # offset always positive
        k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[tmp:tmp+1]))
        k.emit(v_add_co_u32(v[dst], VCC_LO, s[S_B_PTR[0]], v[dst]))
        k.emit(v_add_co_ci_u32_e32(v[dst+1], s[S_B_PTR[1]], v[dst+1]))
    else: addr64(dst, S_B_PTR[0])

def a_addr(dst, row_reg, tmp): # A address for row_reg + lane_id_mod8
    k.emit(v_add_nc_u32_e32(v[tmp], v[row_reg], v[V_LANE_ID_MOD8]))
    k.emit(v_mov_b32_e32(v[tmp+1], 0)) # offset always positive
    k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[tmp:tmp+1]))
    k.emit(v_add_co_u32(v[dst], VCC_LO, s[S_A_PTR[0]], v[dst]))
    k.emit(v_add_co_ci_u32_e32(v[dst+1], s[S_A_PTR[1]], v[dst+1]))

# Batch 1: B addresses (cols 0-5) and loads
k.emit(v_add_co_u32(v[5], VCC_LO, s[S_B_PTR[0]], v[5]))
k.emit(v_add_co_ci_u32_e32(v[6], s[S_B_PTR[1]], v[6]))
for dst, mult in [(9,1), (7,2), (2,3), (11,4), (13,5)]: b_addr(dst, mult)
k.emit(s_clause(simm16=5)) # 6 consecutive global loads
for vdst, addr in INIT_TILE_LOADS[:6]: k.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1], saddr=NULL))

# Batch 2: A addresses (rows 0-4) and loads
for dst, ri in [(6,0), (8,1), (10,2), (12,3), (14,4)]:
    k.emit(v_add_nc_u32_e32(v[dst], v[ROW_REGS[ri]], v[V_LANE_ID_MOD8]))
    addr64(dst, S_A_PTR[0])
k.emit(s_clause(simm16=4)) # 5 consecutive global loads
for vdst, addr in INIT_TILE_LOADS[6:11]: k.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1], saddr=NULL))

# Batch 3: B cols 6-7, A rows 5-7, and loads
for dst, mult, tmp in [(2,6,15), (4,7,4)]: b_addr(dst, mult, tmp)
for dst, ri, tmp in [(8,5,16), (6,6,18), (10,7,20)]: a_addr(dst, ROW_REGS[ri], tmp)
k.emit(s_clause(simm16=4)) # 5 consecutive global loads
for vdst, addr in INIT_TILE_LOADS[11:]: k.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1], saddr=NULL))
# Do initial loads
for vdst, saddr_lo in INIT_PREFETCH:
    k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1]))
for iter in range(6):
    vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter]
    k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
    k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))

# ===========================================================================
# LDS store address computation (bank-conflict-avoiding swizzle)
# ===========================================================================
# This section computes LDS store addresses with a swizzle pattern to avoid bank conflicts.
# Key outputs:
# v[8]: A-tile initial store base (used only for initial stores with stride64)
# V_LDS_B_ADDR (v145): B-tile store base (used for both initial and main loop)
# V_LANE_DIV8_X4 (v135): (lane_id >> 3) << 2 for epilogue
#
# The swizzle ensures that threads in the same wavefront write to different LDS banks.
# Formula: swizzled_addr = base + (lane_id & 7) * LDS_A_STRIDE + swizzle_offset
# where swizzle_offset depends on (lane_id >> 3) to distribute across banks.
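
# Shape of the swizzled store address described above, as plain Python. The
# per-row swizzle_offset is deliberately left abstract here; deriving it is
# exactly what the ADDR_MASK/subtract sequence below computes.
def swizzled_store_addr(lane_id, base, stride, swizzle_offset):
    return base + (lane_id & 7) * stride + swizzle_offset(lane_id >> 3)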

# v[22] = tile_y | (lane_id >> 3) from prologue, used as base for row offsets
# Compute 7 row offsets for B-tile rows 1-7 (row 0 computed separately in v[9])
k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base (S_LOOP_CTR=0)
for i in range(7): k.emit(v_or_b32_e32(v[10 + i if i < 2 else 12 + i], 16 * (i + 1), v[22])) # rows 1-7

# Extract sign bit of workgroup_x (always 0 for valid workgroups, used for masking)
k.emit(s_bfe_i32(s[S_LOOP_BOUND], s[S_WORKGROUP_X], 0x10018))
k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base
k.emit(v_and_b32_e32(v[9], ADDR_MASK, v[9]))
k.emit(s_lshr_b32(s[S_LOOP_BOUND], s[S_LOOP_BOUND], 25))

# Compute masked row offsets for bank conflict avoidance pattern
# Pattern: v[row] = row_val - (row_val & ADDR_MASK) extracts lower bits
k.emit(v_add_nc_u32_e32(v[19], s[S_LOOP_CTR], v[10]))
k.emit(v_add_nc_u32_e32(v[8], s[S_LOOP_BOUND], v[1])) # A-tile base computation
for d, r in zip([20, 21, 32, 33, 34, 35], [11, 14, 15, 16, 17, 18]):
    k.emit(v_add_nc_u32_e32(v[d], s[S_LOOP_CTR], v[r]))
k.emit(v_and_b32_e32(v[8], ADDR_MASK, v[8]))
k.emit(v_sub_nc_u32_e32(v[9], v[22], v[9])) # row 0 swizzle offset
for d, s_ in zip([19, 20, 21, 22, 32, 33, 34], [20, 21, 22, 32, 33, 34, 35]):
    k.emit(v_and_b32_e32(v[d], ADDR_MASK, v[s_]))
k.emit(v_sub_nc_u32_e32(v[8], v[1], v[8])) # A-tile swizzle

# Apply swizzle offsets and scale to byte offsets
k.emit(v_lshlrev_b32_e32(v[9], 2, v[9])) # row 0 offset * 4
for r, t in zip([10, 11, 14, 15, 16, 17, 18], [19, 20, 21, 22, 32, 33, 34]):
    k.emit(v_sub_nc_u32_e32(v[r], v[r], v[t])) # rows 1-7 swizzle
k.emit(v_bfe_u32(v[2], v[0], 3, 2)) # v[2] = (lane_id >> 3) & 3
k.emit(v_lshlrev_b32_e32(v[8], 2, v[8])) # A-tile base * 4

# Compute B-tile base address: LDS_A_STRIDE * (lane_id % 8) + row0_offset
k.emit(v_lshlrev_b32_e32(v[9], 2, v[9])) # * 4
k.emit(v_mad_u32_u24(v[V_LDS_B_ADDR], LDS_A_STRIDE, v[V_LANE_ID_MOD8], v[9]))
# Scale row offsets 1-7 to byte offsets (row 0 already in v[9])
for d, r in zip([9, 10, 11, 14, 15, 16, 17], [10, 11, 14, 15, 16, 17, 18]):
    k.emit(v_lshlrev_b32_e32(v[d], 2, v[r]))

# For V_LDS_A_BASE and epilogue
k.emit(v_bfe_u32(v[2], v[V_LANE_ID], 3, 2)) # v[2] = (lane_id >> 3) & 3
k.emit(v_lshlrev_b32_e32(v[V_LANE_DIV8_X4], 2, v[2]))
k.emit(v_add_nc_u32_e32(v[8], 0x80, v[8])) # A-tile initial store base + 128

# Store initial tile data to LDS
k.waitcnt(vm=0)
for i, (d0, d1) in enumerate([(0,1), (2,3), (4,5), (11,12)]):
    k.emit(ds_store_2addr_stride64_b32(addr=v[8], data0=v[INIT_TILE_LOADS[d0][0]], data1=v[INIT_TILE_LOADS[d1][0]], offset0=16+i*4, offset1=18+i*4))
# B stores: single base with offsets 0,64,128,192,256,320,384,448
for i, idx in enumerate([6,7,8,9,10,13,14,15]):
    offset = i * 64
    k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[INIT_TILE_LOADS[idx][0]], offset0=offset & 0xFF, offset1=offset >> 8))

# ===========================================================================
# INIT: Compute LDS base addresses, then zero accumulators
# ===========================================================================
# v[3] = v[1] & 0x7F (lower 7 bits) since S_LOOP_BOUND=0 for valid workgroups
# Compute LDS load/store base addresses for inner loop
k.emit(v_lshlrev_b32_e32(v[2], 4, v[2]))
k.emit(v_add_nc_u32_e32(v[3], s[S_LOOP_BOUND], v[1]))
k.emit(v_and_b32_e32(v[3], ADDR_MASK, v[3]))
k.emit(v_sub_nc_u32_e32(v[3], v[1], v[3]))
k.emit(v_and_b32_e32(v[3], 0x7F, v[1])) # simplified from 3 lines
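
# Why the single AND works (sketch): with S_LOOP_BOUND == 0 for valid workgroups,
# the three replaced instructions computed v3 = v1 - (v1 & ADDR_MASK), and for
# values below 2**30 that difference is exactly the low 7 bits, e.g.:
#   assert all(x - (x & 0x3fffff80) == (x & 0x7F) for x in range(1 << 16))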
k.emit(v_lshl_or_b32(v[V_LDS_B_BASE], v[V_LANE_ID_MOD8], 4, LDS_BASE_OFFSET))
k.emit(v_lshl_add_u32(v[V_LDS_A_ADDR], v[3], 2, LDS_BASE_OFFSET))
k.emit(v_lshlrev_b32_e32(v[3], 2, v[0]))
k.emit(v_lshlrev_b32_e32(v[3], 2, v[V_LANE_ID]))
k.emit(v_and_or_b32(v[V_LDS_A_BASE], 0x180, v[3], v[2]))

# Do initial stores
k.waitcnt(vm=0)
for i in range(4): # A tile: 8 values via 4 stride64 stores
    k.emit(ds_store_2addr_stride64_b32(addr=v[V_LDS_A_ADDR], data0=v[V_LDS_A_DATA[i*2]], data1=v[V_LDS_A_DATA[i*2+1]], offset0=i*4, offset1=i*4+2))
for i in range(8): # B tile: 8 values via 8 scalar stores with 64-byte spacing
    offset = i * 64
    k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))

# Zero all 128 accumulators using VOPD dual moves (64 instructions instead of 128)
for i in range(0, len(OUT_REGS), 2):
    k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[OUT_REGS[i]], vdsty=v[OUT_REGS[i+1]], srcx0=0, srcy0=0))

k.emit(s_add_i32(s[S_LOOP_BOUND], s[S_DIM_N], -8))
k.emit(s_add_u32(s[S_A_PTR[0]], s[S_A_PTR[0]], 32))
k.emit(s_addc_u32(s[S_A_PTR[1]], s[S_A_PTR[1]], 0))

# S_LOOP_CTR is already 0 from prologue initialization
k.emit(s_branch(simm16=0)); k.branch_to('LOOP_ENTRY')

@@ -443,7 +338,7 @@ def build_kernel(arch='gfx1100'):
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
#k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))

# Advance prefetch pointers (SGPRs, 64-bit adds)
# Advance prefetch pointers (64-bit adds)
k.emit(s_clause(simm16=31))
for i in range(8):
    k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], 0x20000))
@@ -490,7 +385,7 @@ def build_kernel(arch='gfx1100'):
k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
    vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))

# wait for all global stores to finish
# wait for all global loads to finish
# then sync the warp so it's safe to store local
k.waitcnt(vm=0)
k.emit(s_barrier())
@@ -518,54 +413,47 @@ def build_kernel(arch='gfx1100'):
for a, b in PERMUTE_SWAPS:
    k.emit(v_swap_b32_e32(v[a], v[b]))

# Compute output coordinates: v[V_LANE_ID_MOD8] = col, v[V_OUTPUT_ROW] = row
k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32,
    vdstx=v[149], vdsty=v[150], srcx0=v[V_LANE_MOD8_X4], vsrcx1=v[0], srcy0=v[V_LANE_DIV8_X4], vsrcy1=v[0]))
k.emit(v_and_b32_e32(v[0], 0x60, v[0]))
k.emit(v_or_b32_e32(v[V_LANE_ID_MOD8], s[S_TILE_X], v[149]))
k.emit(v_add_nc_u32_e32(v[0], s[S_TILE_Y], v[0]))
k.emit(v_or_b32_e32(v[V_OUTPUT_ROW], v[0], v[150]))
# Compute output base coordinates
# v[130] = col_base = tile_x + (lane_id & 7) * 4
# v[131] = row_base = tile_y + (lane_id & 0x60) + ((lane_id >> 3) & 3) * 4
# v[132] = 0 (for 64-bit address high part)
k.emit(v_add_nc_u32_e32(v[130], s[S_TILE_X], v[V_LANE_MOD8_X4]))
k.emit(v_and_b32_e32(v[131], 0x60, v[V_LANE_ID]))
k.emit(v_add_nc_u32_e32(v[131], s[S_TILE_Y], v[131]))
k.emit(v_add_nc_u32_e32(v[131], v[V_LANE_DIV8_X4], v[131]))
k.emit(v_mov_b32_e32(v[132], 0))
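
# Thread-to-output mapping established above, as plain Python (illustration only):
def output_base(lane_id, tile_x, tile_y):
    col_base = tile_x + (lane_id & 7) * 4                            # v[130]
    row_base = tile_y + (lane_id & 0x60) + ((lane_id >> 3) & 3) * 4  # v[131]
    return col_base, row_base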

# Precompute row offsets: v[144-147] for rows 0-3, v[148-151] for rows 16-19
for base, row_off in [(144, 0), (148, 16)]:
    if row_off: k.emit(v_or_b32_e32(v[1], row_off, v[V_OUTPUT_ROW]))
    k.emit(v_mul_lo_u32(v[base], v[1] if row_off else v[V_OUTPUT_ROW], s[S_DIM_N]))
    for i in range(3): k.emit(v_add_nc_u32_e32(v[base + 1 + i], s[S_DIM_N], v[base + i]))
# Precompute row offsets: v[133-136] for rows 0-3, v[137-140] for rows 16-19
for base, row_off in [(133, 0), (137, 16)]:
    if row_off: k.emit(v_add_nc_u32_e32(v[141], row_off, v[131]))
    k.emit(v_mul_lo_u32(v[base], v[141] if row_off else v[131], s[S_DIM_N]))
    for j in range(3): k.emit(v_add_nc_u32_e32(v[base + 1 + j], s[S_DIM_N], v[base + j]))
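
# After this loop (illustration): v[133 + j] holds (row_base + j) * N and
# v[137 + j] holds (row_base + 16 + j) * N for j in 0..3, i.e. the element
# offsets of the eight output rows this lane writes.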

k.emit(v_mov_b32_e32(v[V_ADDR_HI_ZERO], 0))
k.emit(s_lshl_b32(s[S_PREFETCH_FLAG], s[S_DIM_N], 2)) # row stride in bytes
# s[S_PREFETCH_FLAG] = row stride in bytes (N * 4)
k.emit(s_lshl_b32(s[S_PREFETCH_FLAG], s[S_DIM_N], 2))

# Store 128 output values as 32 groups of 4 (128-bit stores)
# Layout: 2 row halves (0-3, 16-19) x 4 col groups x 4 rows = 32 stores of 4 floats
epilogue_reserved = {V_LANE_ID_MOD8, V_OUTPUT_ROW, V_LANE_MOD8_X4, V_LANE_DIV8_X4, V_ADDR_HI_ZERO}

for i, (row_half, col_off, row_in_group) in enumerate([(rh, co, ri)
        for rh in range(2) for co in [0, 32, 64, 96] for ri in range(4)]):
    row = row_half * 16 + row_in_group
    srcs = OUT_REGS[i*4:(i+1)*4]
    src = OUT_REGS[i*4] # first reg of ascending group of 4

    # Find temp register for scaled values (must not conflict with reserved regs)
    tmp = max(srcs) + 5
    while any(r in epilogue_reserved for r in range(tmp, tmp + 4)): tmp += 1
    if row_in_group == 0:
        # First row of group: compute full address
        if col_off == 0: k.emit(v_mov_b32_e32(v[141], v[130]))
        else: k.emit(v_add_nc_u32_e32(v[141], col_off, v[130]))
        row_base = 133 + row if row < 4 else 137 + row - 16
        k.emit(v_add_nc_u32_e32(v[141], v[row_base], v[141]))
        k.emit(v_lshlrev_b32_e32(v[141], 2, v[141]))
        k.emit(v_add_co_u32(v[141], VCC_LO, s[S_OUT_PTR[0]], v[141]))
        k.emit(v_add_co_ci_u32_e32(v[142], s[S_OUT_PTR[1]], v[132]))
    else:
        # Subsequent rows: add stride
        k.emit(v_add_co_u32(v[141], VCC_LO, s[S_PREFETCH_FLAG], v[141]))
        k.emit(v_add_co_ci_u32_e32(v[142], v[142], v[132]))

    # Copy values to temp regs for output (alpha=1.0 hardcoded, so just move)
    for j, src in enumerate(srcs):
        k.emit(v_mov_b32_e32(v[tmp + j], v[src]))

    # Compute output address
    if row_in_group == 0: # first row: compute base address for this column group
        if col_off == 0: k.emit(v_mov_b32_e32(v[0], v[V_LANE_ID_MOD8]))
        else: k.emit(v_add_nc_u32_e32(v[0], col_off, v[V_LANE_ID_MOD8]))
        row_base = 144 + row if row < 4 else 148 + row - 16
        k.emit(v_add_nc_u32_e32(v[0], v[row_base], v[0]))
        k.emit(v_lshlrev_b32_e32(v[0], 2, v[0]))
        k.emit(v_add_co_u32(v[0], VCC_LO, s[S_OUT_PTR[0]], v[0]))
        k.emit(v_add_co_ci_u32_e32(v[1], s[S_OUT_PTR[1]], v[V_ADDR_HI_ZERO]))
    else: # subsequent rows: just add stride
        k.emit(v_add_co_u32(v[0], VCC_LO, s[S_PREFETCH_FLAG], v[0]))
        k.emit(v_add_co_ci_u32_e32(v[1], v[1], v[V_ADDR_HI_ZERO]))

    k.emit(global_store_b128(addr=v[0:1], data=v[tmp:tmp+3], saddr=NULL))
    k.emit(global_store_b128(addr=v[141:142], data=v[src:src+3], saddr=NULL))
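
# What each of the 32 stores above writes, as plain Python (illustration only,
# assuming a row-major output matrix of dimension N):
def store_byte_offset(row_base, col_base, row_half, col_group, row_in_group, N):
    row = row_base + row_half * 16 + row_in_group
    col = col_base + col_group * 32   # col_off takes values 0, 32, 64, 96
    return (row * N + col) * 4        # start of one 16-byte global_store_b128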

k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS
k.emit(s_endpgm())