assembly/amd: amd_asm_matmul (#13989)
* amd_asm_matmul * dsl transform * asm roundtrip * fixed * less * better * more * simpler * simplify * lil * simpler * compact * work * cleanups * simplify * simpler * cleanup * name the regs * simp * big simp * big simp * simp * acc grid * fast * stuff * fast * simpler * owrks * save vgprs * save vgprs * Compact * less VGPRs * after * SQTT support * fastest * faster * lil faster * tile regs * faster * readable * one more * simpler * lil simpler * NO_GLOBAL skips early globals * stock kernel * cleanups * cleanups * one b reg * safe reg changes * acc is compact now * remove confusing stuff * sregs * lds cleanups * vopd
extra/gemm/amd_asm_matmul.py (new file, 655 lines)
@@ -0,0 +1,655 @@
# RDNA3 128x128 tiled GEMM kernel - DSL version
# Computes C = A @ B for 4096x4096 float32 matrices using 128x128 tiles
#
# Architecture: RDNA3 (gfx1100)
# Tile size: 128x128 (each workgroup computes one tile of C)
# Workgroup: 128 threads (arranged as 32x4 for coalesced memory access)
# Inner loop: 8 iterations per K-block, processing 8 columns of A and 8 rows of B
#
# Accumulators: 128 vgprs (v2-v129, scattered across VOPD banks per ACC_GRID)
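#
# Work decomposition (derived from the tables below): a 128x128 tile holds
# 16384 elements of C, so each of the 128 threads owns an 8x16 micro-tile --
# the 128 accumulators in ACC_GRID. Per inner-loop iteration a thread issues
# 8*16 = 128 FMAs as the 64 dual-FMAC VOPD ops listed in FMAC_PATTERN.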

import numpy as np
from pathlib import Path
from tinygrad import Tensor, Device, Context, GlobalCounters
from tinygrad.helpers import getenv, colored
from tinygrad.engine.realize import Runner, Estimates, ExecItem
from extra.assembly.amd.dsl import s, v, VCC_LO, RawImm, EXEC_LO
from extra.assembly.amd.autogen.rdna3.ins import *

# =============================================================================
# Kernel constants
# =============================================================================
LDS_SIZE = 8320 # Local data share size in bytes
MATRIX_DIM = 4096 # Matrix dimension N (assumes square NxN matrices)
LDS_A_STRIDE = 0x210 # LDS stride for A tile (528 bytes)
LDS_B_STRIDE = 0x200 # LDS stride for B tile (512 bytes)
LDS_BASE_OFFSET = 0x1080 # Base LDS offset for tiles
ADDR_MASK = 0x3fffff80 # Address alignment mask
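
# Illustrative sanity check (not used by the kernel): the two 8-row LDS tiles
# exactly fill the declared LDS budget, and the A stride carries 16 bytes of
# padding (0x210 = 0x200 + 16), which is what spreads successive rows across
# banks (see the swizzle section in build_kernel).
assert 8 * LDS_A_STRIDE == LDS_BASE_OFFSET == 4224
assert 8 * LDS_A_STRIDE + 8 * LDS_B_STRIDE == LDS_SIZE == 8320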

# s_waitcnt encodings: wait for memory operations to complete
WAIT_LGKM = 64519 # wait for LDS/GDS/KMEM (lgkm_cnt=0)
WAIT_ALL = 0 # wait for everything
WAIT_VMEM = 1015 # wait for VMEM only (vm_cnt=0, lgkm_cnt=63)
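
# The constants above are pre-packed s_waitcnt immediates. A minimal sketch of
# the assumed RDNA3 SIMM16 layout (vmcnt in bits [15:10], lgkmcnt in [9:4],
# expcnt in [2:0]); the real encoder is extra.assembly.amd.asm.waitcnt, which
# Kernel.waitcnt() below uses for the non-precomputed cases.
def _pack_waitcnt(vmcnt, expcnt, lgkmcnt): return (vmcnt << 10) | (lgkmcnt << 4) | expcnt
assert _pack_waitcnt(vmcnt=63, expcnt=7, lgkmcnt=0) == WAIT_LGKM
assert _pack_waitcnt(vmcnt=0, expcnt=7, lgkmcnt=63) == WAIT_VMEM
assert _pack_waitcnt(vmcnt=0, expcnt=0, lgkmcnt=0) == WAIT_ALL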

# =============================================================================
# Named register assignments (VGPRs) - COMPACT LAYOUT
# =============================================================================
V_LANE_ID_MOD8 = 214 # lane_id & 7 (column within 8-wide tile chunk)
V_OUTPUT_ROW = 131 # output row coordinate
V_LANE_MOD8_X4 = 134 # V_LANE_ID_MOD8 << 2 (byte offset)
V_LANE_DIV8_X4 = 135 # (lane_id >> 3) << 2
V_ADDR_HI_ZERO = 136 # always 0 (for 64-bit address high bits)
V_LDS_A_BASE = 133 # LDS A-tile base address for inner loop (in ACC_RESERVED gap)
V_LDS_B_BASE = 130 # LDS B-tile base address for inner loop (in ACC_RESERVED gap)
V_GLOBAL_A_ADDR = 131 # global memory A prefetch address (reuses V_OUTPUT_ROW slot during main loop)
V_GLOBAL_B_ADDR = 154 # global memory B prefetch address

# LDS tile register destinations - SEPARATE from DATA to avoid overlap
# DATA regs (v155-170) receive global prefetch
# A on banks 2-3, B on banks 0-1 to avoid bank conflicts in VOPD
# This layout matches kernel8's optimization for VGPR cache utilization
V_A_TILE_REGS = [186, 190, 194, 198] # A tile: banks 2,2,2,2 (186%4=2, 190%4=2, etc.)
V_B_TILE_REGS = [184, 188, 192, 196, 200, 204, 208, 212] # B tile: banks 0,0,0,0,0,0,0,0
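
# Illustrative check of the bank layout stated above: each A tile pair (base,
# base+1) occupies VGPR banks 2/3 and each B tile pair banks 0/1.
assert all(r % 4 == 2 for r in V_A_TILE_REGS)
assert all(r % 4 == 0 for r in V_B_TILE_REGS)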

# =============================================================================
# Named register assignments (SGPRs)
# =============================================================================
S_OUT_PTR = (0, 1) # output C matrix base pointer
S_TILE_X = 2 # workgroup_x << 7
S_TILE_Y = 3 # workgroup_y << 7
S_DIM_N = 4 # matrix dimension N
S_LOOP_BOUND = 7 # K-8 (loop termination bound)
S_A_PTR = (8, 9) # A matrix base pointer
S_B_PTR = (10, 11) # B matrix base pointer
S_LOOP_CTR = 12 # loop counter (increments by 8)
S_PREFETCH_FLAG = 13 # prefetch condition flag / row stride in epilogue
S_WORKGROUP_X = 14 # workgroup_id_x
S_WORKGROUP_Y = 15 # workgroup_id_y
# Kernarg load destinations (before copy to working regs)
S_KERNARG_OUT = (16, 17) # output pointer from kernarg
S_KERNARG_A = (20, 21) # A pointer from kernarg
S_KERNARG_B = (22, 23) # B pointer from kernarg
# Prefetch base pointers (8 pairs each, 16KB/256KB apart)
S_PREFETCH_B = 24 # s[24:39] - 8 B tile pointers
S_PREFETCH_A = 40 # s[40:55] - 8 A tile pointers

# =============================================================================
# Data tables
# =============================================================================

# Accumulator grid: ACC_GRID[a_idx][b_idx] = vgpr for C[a,b]
# a_idx: which A value (0-7), b_idx: which B value (0-15)
# Scattered due to VOPD bank constraints (vdst_x % 4 != vdst_y % 4)
# Range is from v2 - v129
ACC_GRID = [
  [ 5, 3, 9, 8, 37, 35, 41, 40, 69, 67, 73, 72, 101, 99,105,104], # a0
  [ 4, 2, 7, 6, 36, 34, 39, 38, 68, 66, 71, 70, 100, 98,103,102], # a1
  [ 17, 16, 13, 11, 49, 48, 45, 43, 81, 80, 77, 75, 113,112,109,107], # a2
  [ 15, 14, 12, 10, 47, 46, 44, 42, 79, 78, 76, 74, 111,110,108,106], # a3
  [ 21, 19, 25, 24, 53, 51, 57, 56, 85, 83, 89, 88, 117,115,121,120], # a4
  [ 20, 18, 23, 22, 52, 50, 55, 54, 84, 82, 87, 86, 116,114,123,122], # a5
  [125,128, 29, 27, 33, 32, 61, 59, 65, 64, 93, 91, 97, 96,129,127], # a6
  [119,118, 28, 26, 31, 30, 60, 58, 63, 62, 92, 90, 95, 94,124,126], # a7
]
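
# Illustrative self-check: the grid uses every VGPR in v2-v129 exactly once,
# which is what lets the epilogue permute the accumulators into OUT_REGS order.
assert sorted(r for row in ACC_GRID for r in row) == list(range(2, 130))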

# Optimized (a_pair, b_pair) iteration order for better GPU scheduling
# Interleaves A and B pairs to maximize instruction-level parallelism
FMAC_PAIR_ORDER = [
  (0,0),(0,1),(1,1),(1,0), (2,0),(2,1),(3,1),(3,2), (0,2),(0,3),(1,3),(1,2), (2,2),(2,3),(3,3),(3,4),
  (0,4),(0,5),(1,5),(1,4), (2,4),(2,5),(3,5),(3,6), (0,6),(0,7),(1,7),(1,6), (2,6),(2,7),(3,7),(3,0),
]

def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
  """Generate 64 dual FMAC ops from accumulator grid with optimized iteration order."""
  if a_tile_regs is None: a_tile_regs = V_A_TILE_REGS
  if b_tile_regs is None: b_tile_regs = V_B_TILE_REGS
  pattern = []
  for idx, (a_pair, b_pair) in enumerate(FMAC_PAIR_ORDER):
    a_even, a_odd = a_pair * 2, a_pair * 2 + 1
    b_even, b_odd = b_pair * 2, b_pair * 2 + 1
    a_base, b_base = a_tile_regs[a_pair], b_tile_regs[b_pair]
    # Op 1: normal order -> C[a_even, b_even] + C[a_odd, b_odd]
    pattern.append((acc_grid[a_even][b_even], acc_grid[a_odd][b_odd],
                    a_base, b_base, a_base+1, b_base+1))
    # Op 2: alternate swapping A vs B to vary register banks
    if idx % 2 == 0: # swap B
      pattern.append((acc_grid[a_even][b_odd], acc_grid[a_odd][b_even],
                      a_base, b_base+1, a_base+1, b_base))
    else: # swap A
      pattern.append((acc_grid[a_odd][b_even], acc_grid[a_even][b_odd],
                      a_base+1, b_base, a_base, b_base+1))
  return pattern

# Derived: 64 dual FMAC operations
FMAC_PATTERN = derive_fmac_pattern(ACC_GRID)
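
# Illustrative self-checks on the derived pattern: every (a_pair, b_pair)
# combination is visited exactly once, and every dual FMAC obeys the VOPD
# destination-bank rule quoted above (vdst_x % 4 != vdst_y % 4).
assert sorted(FMAC_PAIR_ORDER) == [(a, b) for a in range(4) for b in range(8)]
assert len(FMAC_PATTERN) == 64
assert all(dx % 4 != dy % 4 for dx, dy, *_ in FMAC_PATTERN)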

def derive_permute_swaps(acc_grid, out_regs):
  """Derive swap sequence to permute accumulators from FMAC layout to output order.

  After FMAC loop: acc_grid[a][b] holds C[a,b]
  Output order: for row_half in 0,1; col_group in 0-3; row_in_group in 0-3; b_off in 0-3
    -> need C[row_half*4 + row_in_group, col_group*4 + b_off] in descending reg order
  """
  def target_ab(i):
    row_half, col_group = i // 64, (i // 16) % 4
    row_in_group, b_off = (i // 4) % 4, i % 4
    return (row_half * 4 + row_in_group, col_group * 4 + b_off)

  reg_contents = {acc_grid[a][b]: (a, b) for a in range(8) for b in range(16)}
  ab_location = {ab: r for r, ab in reg_contents.items()}

  swaps = []
  for i in range(128):
    target_reg, needed_ab = out_regs[i], target_ab(i)
    current_reg = ab_location[needed_ab]
    if current_reg != target_reg:
      swaps.append((current_reg, target_reg))
      ab_at_target = reg_contents.get(target_reg)
      reg_contents[target_reg], ab_location[needed_ab] = needed_ab, target_reg
      if ab_at_target is not None:
        reg_contents[current_reg], ab_location[ab_at_target] = ab_at_target, current_reg
  return swaps

# Derived: swap sequence to arrange accumulators for output
OUT_REGS = list(range(129, 1, -1))
PERMUTE_SWAPS = derive_permute_swaps(ACC_GRID, OUT_REGS)
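
# Illustrative self-check: replaying the swap list on a model register file
# leaves C[a,b] where the epilogue expects it (OUT_REGS descends, so v129 holds
# the first output element and v2 the last).
_model = {ACC_GRID[a][b]: (a, b) for a in range(8) for b in range(16)}
for _x, _y in PERMUTE_SWAPS: _model[_x], _model[_y] = _model[_y], _model[_x]
for _i, _reg in enumerate(OUT_REGS):
  _rh, _cg, _rg, _bo = _i // 64, (_i // 16) % 4, (_i // 4) % 4, _i % 4
  assert _model[_reg] == (_rh * 4 + _rg, _cg * 4 + _bo)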

# =============================================================================
# LDS tile staging registers - COMPACT LAYOUT
# =============================================================================
# DATA regs receive contiguous global prefetch, then write to LDS
# TILE regs receive scattered LDS loads (ds_load_b64 pairs), then feed FMACs
# These are SEPARATE - DATA lives during prefetch/store, TILE lives during inner loop
V_LDS_A_ADDR = 153 # single base register for A stores (use +512 offsets)
V_LDS_A_DATA = list(range(155, 163)) # 8 data registers for A prefetch (v155-162)
V_LDS_B_ADDR = 145 # single base register for B stores (use 16-bit offsets)
V_LDS_B_DATA = list(range(163, 171)) # 8 data registers for B prefetch (v163-170)

# Global memory prefetch schedule: (vdst1, vdst2, addr_vreg, saddr_lo1, saddr_lo2)
# First 2 pairs from B prefetch pointers (s[32:39]), next 4 pairs from A prefetch pointers (s[40:55])
PREFETCH_LOADS = [(V_LDS_A_DATA[4+2*i], V_LDS_A_DATA[4+2*i+1], V_GLOBAL_B_ADDR, S_PREFETCH_B+8+4*i, S_PREFETCH_B+10+4*i) for i in range(2)] + \
                 [(V_LDS_B_DATA[2*(i-2)], V_LDS_B_DATA[2*(i-2)+1], V_GLOBAL_A_ADDR, S_PREFETCH_A+4*(i-2), S_PREFETCH_A+2+4*(i-2)) for i in range(2, 6)]

# Initial tile prefetch: (vdst, saddr_lo) - load into A data regs using B prefetch pointers (s[24:31])
INIT_PREFETCH = [(V_LDS_A_DATA[i], S_PREFETCH_B+2*i) for i in range(4)]

# Initial tile loads: (vdst, addr_lo) pairs - use temp regs in accumulator gaps
INIT_TILE_LOADS = [(23,5),(24,9),(25,7),(26,2),(27,11),(28,13),(29,6),(30,8),(31,10),(12,12),(13,14),(3,2),(4,4),(5,8),(6,6),(7,10)]

# A matrix row offset registers (scattered to avoid accumulator conflicts)
ROW_REGS = list(range(137, 145)) # v137-v144 (8 regs)

# =============================================================================
# Kernel class
# =============================================================================

class Kernel:
  def __init__(self, arch='gfx1100'):
    self.instructions, self.labels, self.branch_targets, self.arch = [], {}, {}, arch

  def emit(self, inst): self.instructions.append(inst); return inst
  def label(self, name): self.labels[name] = len(self.instructions)
  def branch_to(self, label): self.branch_targets[len(self.instructions) - 1] = label

  def add64(self, dst_lo, dst_hi, src_lo, src_hi, off):
    """s[dst_lo:dst_hi] = s[src_lo:src_hi] + off"""
    if off: self.emit(s_add_u32(s[dst_lo], s[src_lo], off)); self.emit(s_addc_u32(s[dst_hi], s[src_hi], 0))
    elif dst_lo != src_lo: self.emit(s_mov_b64(s[dst_lo:dst_hi], s[src_lo:src_hi]))

  def global_load(self, vdst, addr, saddr=None):
    """Global load b32"""
    self.emit(global_load_b32(vdst=v[vdst], addr=v[addr:addr+1],
                              saddr=s[saddr:saddr+2] if saddr else RawImm(124)))

  def waitcnt(self, lgkm=None, vm=None):
    """Wait for memory operations. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain."""
    from extra.assembly.amd.asm import waitcnt as encode_waitcnt
    if lgkm == 0 and vm is None: self.emit(s_waitcnt(simm16=WAIT_LGKM))
    elif vm == 0 and lgkm is None: self.emit(s_waitcnt(simm16=WAIT_VMEM))
    elif lgkm == 0 and vm == 0: self.emit(s_waitcnt(simm16=WAIT_ALL))
    elif vm is not None and lgkm is None:
      self.emit(s_waitcnt(simm16=encode_waitcnt(vmcnt=vm, expcnt=7, lgkmcnt=63)))
    elif lgkm is not None and vm is None:
      self.emit(s_waitcnt(simm16=encode_waitcnt(vmcnt=63, expcnt=7, lgkmcnt=lgkm)))
    else: raise ValueError(f"unsupported waitcnt: lgkm={lgkm}, vm={vm}")

  def barrier(self): self.emit(s_barrier())

  def to_asm(self):
    import re
    # Instruction stream with labels
    label_at = {pos: name for name, pos in self.labels.items()}
    body = []
    for i, inst in enumerate(self.instructions):
      if i in label_at: body.append(f'.{label_at[i]}:')
      asm = inst.disasm()
      if i in self.branch_targets:
        asm = re.sub(r'(s_cbranch_\w+|s_branch)\s+\S+', rf'\1 .{self.branch_targets[i]}', asm)
      body.append('\t' + asm)

    # HSA kernel descriptor attributes (zeros included for compatibility)
    hsa = [
      ('group_segment_fixed_size', LDS_SIZE), ('private_segment_fixed_size', 0), ('kernarg_size', 36),
      ('user_sgpr_count', 14), ('user_sgpr_dispatch_ptr', 0), ('user_sgpr_queue_ptr', 0),
      ('user_sgpr_kernarg_segment_ptr', 1), ('user_sgpr_dispatch_id', 0), ('user_sgpr_private_segment_size', 0),
      ('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
      ('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
      ('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 214),
      ('next_free_sgpr', 16), ('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
      ('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), ('dx10_clamp', 1), ('ieee_mode', 1),
      ('fp16_overflow', 0), ('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0),
      ('shared_vgpr_count', 0)]

    return '\n'.join([
      '\t.text', f'\t.amdgcn_target "amdgcn-amd-amdhsa--{self.arch}"',
      '\t.protected\tkernel', '\t.globl\tkernel', '\t.p2align\t8', '\t.type\tkernel,@function', 'kernel:',
      *body,
      '\t.section\t.rodata,"a",@progbits', '\t.p2align\t6, 0x0', '\t.amdhsa_kernel kernel',
      *[f'\t\t.amdhsa_{k} {v}' for k, v in hsa],
      '\t.end_amdhsa_kernel', '\t.text', '.Lfunc_end0:', '\t.size\tkernel, .Lfunc_end0-kernel',
      '\t.amdgpu_metadata', '---', 'amdhsa.kernels:', ' - .args:',
      *[f' - .address_space: global\n .offset: {i*8}\n .size: 8\n .value_kind: global_buffer' for i in range(3)],
      f' .group_segment_fixed_size: {LDS_SIZE}', ' .kernarg_segment_align: 8',
      ' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 128', ' .name: kernel',
      ' .private_segment_fixed_size: 0', ' .sgpr_count: 60', ' .symbol: kernel.kd',
      ' .vgpr_count: 214', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
      'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])


# =============================================================================
# Kernel builder
# =============================================================================

def build_kernel(arch='gfx1100'):
  k = Kernel(arch)

  # ===========================================================================
  # PROLOGUE: Load kernel arguments, compute tile coordinates and addresses
  # ===========================================================================
  k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=RawImm(124)))
  k.emit(s_load_b64(sdata=s[S_KERNARG_OUT[0]:S_KERNARG_OUT[1]], sbase=s[0:1], offset=0x10, soffset=RawImm(124)))
  k.emit(s_mov_b32(s[S_DIM_N], MATRIX_DIM))
  k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups
  k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7))
  k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7))

  # Lane-derived values
  k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[0]))
  k.emit(v_lshrrev_b32_e32(v[4], 3, v[0]))
  k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[0]))
  k.emit(v_or_b32_e32(v[22], s[S_TILE_Y], v[4]))
  k.emit(v_lshlrev_b32_e32(v[V_LANE_MOD8_X4], 2, v[V_LANE_ID_MOD8]))
  k.emit(v_mov_b32_e32(v[2], 0)) # v[1] always positive, sign extension is 0
  k.emit(v_lshlrev_b64(v[5:6], 2, v[1:2]))
  k.waitcnt(lgkm=0)

  # Copy pointers to working registers
  k.emit(s_mov_b64(s[S_OUT_PTR[0]:S_OUT_PTR[1]], s[S_KERNARG_OUT[0]:S_KERNARG_OUT[1]]))
  k.emit(s_mov_b64(s[S_A_PTR[0]:S_A_PTR[1]], s[S_KERNARG_A[0]:S_KERNARG_A[1]]))
  k.emit(s_mov_b64(s[S_B_PTR[0]:S_B_PTR[1]], s[S_KERNARG_B[0]:S_KERNARG_B[1]]))

  # Compute 8 A and B matrix tile base pointers for prefetch
  for i in range(8): k.add64(S_PREFETCH_B + i*2, S_PREFETCH_B + i*2 + 1, S_KERNARG_B[0], S_KERNARG_B[1], i * 0x4000) # B: 16KB apart
  for i in range(8): k.add64(S_PREFETCH_A + i*2, S_PREFETCH_A + i*2 + 1, S_KERNARG_A[0], S_KERNARG_A[1], i * 0x40000) # A: 256KB apart

  # Global prefetch addresses: B = (tile_x + lane_id) * 4, A = ((tile_y << 12) + (lane_id/8)*4K + lane_id%8) * 4
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[0]))
  k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR]))
  k.emit(s_lshl_b32(s[19], s[S_TILE_Y], 12))
  k.emit(v_lshl_add_u32(v[V_GLOBAL_A_ADDR], v[4], 12, v[V_LANE_ID_MOD8])) # (lane_id/8)*4K + lane_id%8
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR]))
  k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR]))

  # ===========================================================================
  # Tile address computation for initial A/B matrix loads
  # ===========================================================================
  k.emit(s_lshl_b32(s[S_LOOP_BOUND], s[S_DIM_N], 4)) # row stride = 16*N
  k.emit(v_mul_lo_u32(v[ROW_REGS[0]], v[22], s[S_DIM_N])) # A matrix row offsets
  for i in range(1, 8): k.emit(v_add_nc_u32_e32(v[ROW_REGS[i]], s[S_LOOP_BOUND], v[ROW_REGS[i-1]]))

  def addr64(dst, base_s): # 64-bit address: v[dst:dst+1] = s[base_s:base_s+1] + v[dst]*4
    k.emit(v_mov_b32_e32(v[dst+1], 0)) # offset always positive, sign ext = 0
    k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[dst:dst+1]))
    k.emit(v_add_co_u32(v[dst], VCC_LO, s[base_s], v[dst]))
    k.emit(v_add_co_ci_u32_e32(v[dst+1], s[base_s+1], v[dst+1]))

  def b_addr(dst, mult, tmp=None): # B address for col + mult*N
    tmp = tmp if tmp is not None else dst
    k.emit(v_mad_u32_u24(v[tmp], s[S_DIM_N], mult, v[1]))
    if tmp != dst:
      k.emit(v_mov_b32_e32(v[tmp+1], 0)) # offset always positive
      k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[tmp:tmp+1]))
      k.emit(v_add_co_u32(v[dst], VCC_LO, s[S_B_PTR[0]], v[dst]))
      k.emit(v_add_co_ci_u32_e32(v[dst+1], s[S_B_PTR[1]], v[dst+1]))
    else: addr64(dst, S_B_PTR[0])

  def a_addr(dst, row_reg, tmp): # A address for row_reg + lane_id_mod8
    k.emit(v_add_nc_u32_e32(v[tmp], v[row_reg], v[V_LANE_ID_MOD8]))
    k.emit(v_mov_b32_e32(v[tmp+1], 0)) # offset always positive
    k.emit(v_lshlrev_b64(v[dst:dst+1], 2, v[tmp:tmp+1]))
    k.emit(v_add_co_u32(v[dst], VCC_LO, s[S_A_PTR[0]], v[dst]))
    k.emit(v_add_co_ci_u32_e32(v[dst+1], s[S_A_PTR[1]], v[dst+1]))

  # Batch 1: B addresses (cols 0-5) and loads
  k.emit(v_add_co_u32(v[5], VCC_LO, s[S_B_PTR[0]], v[5]))
  k.emit(v_add_co_ci_u32_e32(v[6], s[S_B_PTR[1]], v[6]))
  for dst, mult in [(9,1), (7,2), (2,3), (11,4), (13,5)]: b_addr(dst, mult)
  k.emit(s_clause(simm16=5)) # 6 consecutive global loads
  for vdst, addr in INIT_TILE_LOADS[:6]: k.global_load(vdst, addr)

  # Batch 2: A addresses (rows 0-4) and loads
  for dst, ri in [(6,0), (8,1), (10,2), (12,3), (14,4)]:
    k.emit(v_add_nc_u32_e32(v[dst], v[ROW_REGS[ri]], v[V_LANE_ID_MOD8]))
    addr64(dst, S_A_PTR[0])
  k.emit(s_clause(simm16=4)) # 5 consecutive global loads
  for vdst, addr in INIT_TILE_LOADS[6:11]: k.global_load(vdst, addr)

  # Batch 3: B cols 6-7, A rows 5-7, and loads
  for dst, mult, tmp in [(2,6,15), (4,7,4)]: b_addr(dst, mult, tmp)
  for dst, ri, tmp in [(8,5,16), (6,6,18), (10,7,20)]: a_addr(dst, ROW_REGS[ri], tmp)
  k.emit(s_clause(simm16=4)) # 5 consecutive global loads
  for vdst, addr in INIT_TILE_LOADS[11:]: k.global_load(vdst, addr)

  # ===========================================================================
  # LDS store address computation (bank-conflict-avoiding swizzle)
  # ===========================================================================
  # This section computes LDS store addresses with a swizzle pattern to avoid bank conflicts.
  # Key outputs:
  #   v[8]: A-tile initial store base (used only for initial stores with stride64)
  #   V_LDS_B_ADDR (v145): B-tile store base (used for both initial and main loop)
  #   V_LANE_DIV8_X4 (v135): (lane_id >> 3) << 2 for epilogue
  #
  # The swizzle ensures that threads in the same wavefront write to different LDS banks.
  # Formula: swizzled_addr = base + (lane_id & 7) * LDS_A_STRIDE + swizzle_offset
  # where swizzle_offset depends on (lane_id >> 3) to distribute across banks.
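  #
  # Worked example (assuming the usual 32-bank, 4-byte-wide LDS): the base
  # computed below advances by LDS_A_STRIDE = 528 B = 132 dwords per step of
  # (lane_id & 7), and 132 % 32 = 4, so lanes 0-7 start on banks 0,4,8,...,28 --
  # eight distinct banks. With the unpadded 512 B stride (128 dwords, a multiple
  # of 32) they would all start on the same bank and the stores would serialize.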

  # v[22] = tile_y | (lane_id >> 3) from prologue, used as base for row offsets
  # Compute 7 row offsets for B-tile rows 1-7 (row 0 computed separately in v[9])
  k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base (S_LOOP_CTR=0)
  for i in range(7): k.emit(v_or_b32_e32(v[10 + i if i < 2 else 12 + i], 16 * (i + 1), v[22])) # rows 1-7

  # Extract sign bit of workgroup_x (always 0 for valid workgroups, used for masking)
  k.emit(s_bfe_i32(s[S_LOOP_BOUND], s[S_WORKGROUP_X], 0x10018))
  k.emit(v_and_b32_e32(v[9], ADDR_MASK, v[9]))
  k.emit(s_lshr_b32(s[S_LOOP_BOUND], s[S_LOOP_BOUND], 25))

  # Compute masked row offsets for bank conflict avoidance pattern
  # Pattern: v[row] = row_val - (row_val & ADDR_MASK) extracts lower bits
  k.emit(v_add_nc_u32_e32(v[19], s[S_LOOP_CTR], v[10]))
  k.emit(v_add_nc_u32_e32(v[8], s[S_LOOP_BOUND], v[1])) # A-tile base computation
  for d, r in zip([20, 21, 32, 33, 34, 35], [11, 14, 15, 16, 17, 18]):
    k.emit(v_add_nc_u32_e32(v[d], s[S_LOOP_CTR], v[r]))
  k.emit(v_and_b32_e32(v[8], ADDR_MASK, v[8]))
  k.emit(v_sub_nc_u32_e32(v[9], v[22], v[9])) # row 0 swizzle offset
  for d, s_ in zip([19, 20, 21, 22, 32, 33, 34], [20, 21, 22, 32, 33, 34, 35]):
    k.emit(v_and_b32_e32(v[d], ADDR_MASK, v[s_]))
  k.emit(v_sub_nc_u32_e32(v[8], v[1], v[8])) # A-tile swizzle

  # Apply swizzle offsets and scale to byte offsets
  k.emit(v_lshlrev_b32_e32(v[9], 2, v[9])) # row 0 offset * 4
  for r, t in zip([10, 11, 14, 15, 16, 17, 18], [19, 20, 21, 22, 32, 33, 34]):
    k.emit(v_sub_nc_u32_e32(v[r], v[r], v[t])) # rows 1-7 swizzle
  k.emit(v_bfe_u32(v[2], v[0], 3, 2)) # v[2] = (lane_id >> 3) & 3
  k.emit(v_lshlrev_b32_e32(v[8], 2, v[8])) # A-tile base * 4

  # Compute B-tile base address: LDS_A_STRIDE * (lane_id % 8) + row0_offset
  k.emit(v_mad_u32_u24(v[V_LDS_B_ADDR], LDS_A_STRIDE, v[V_LANE_ID_MOD8], v[9]))
  # Scale row offsets 1-7 to byte offsets (row 0 already in v[9])
  for d, r in zip([9, 10, 11, 14, 15, 16, 17], [10, 11, 14, 15, 16, 17, 18]):
    k.emit(v_lshlrev_b32_e32(v[d], 2, v[r]))
  k.emit(v_lshlrev_b32_e32(v[V_LANE_DIV8_X4], 2, v[2]))
  k.emit(v_add_nc_u32_e32(v[8], 0x80, v[8])) # A-tile initial store base + 128

  # Store initial tile data to LDS
  k.waitcnt(vm=0)
  for i, (d0, d1) in enumerate([(0,1), (2,3), (4,5), (11,12)]):
    k.emit(ds_store_2addr_stride64_b32(addr=v[8], data0=v[INIT_TILE_LOADS[d0][0]], data1=v[INIT_TILE_LOADS[d1][0]], offset0=16+i*4, offset1=18+i*4))
  # B stores: single base with offsets 0,64,128,192,256,320,384,448
  for i, idx in enumerate([6,7,8,9,10,13,14,15]):
    offset = i * 64
    k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[INIT_TILE_LOADS[idx][0]], offset0=offset & 0xFF, offset1=offset >> 8))
  k.waitcnt(lgkm=0)
  k.barrier()

  # ===========================================================================
  # INIT: Compute LDS base addresses, then zero accumulators
  # ===========================================================================
  # v[3] = v[1] & 0x7F (lower 7 bits) since S_LOOP_BOUND=0 for valid workgroups
  k.emit(v_lshlrev_b32_e32(v[2], 4, v[2]))
  k.emit(v_add_nc_u32_e32(v[3], s[S_LOOP_BOUND], v[1]))
  k.emit(v_and_b32_e32(v[3], ADDR_MASK, v[3]))
  k.emit(v_sub_nc_u32_e32(v[3], v[1], v[3]))
  k.emit(v_lshl_or_b32(v[V_LDS_B_BASE], v[V_LANE_ID_MOD8], 4, LDS_BASE_OFFSET))
  k.emit(v_lshl_add_u32(v[V_LDS_A_ADDR], v[3], 2, LDS_BASE_OFFSET))
  k.emit(v_lshlrev_b32_e32(v[3], 2, v[0]))
  k.emit(v_and_or_b32(v[V_LDS_A_BASE], 0x180, v[3], v[2]))

  # Zero all 128 accumulators using VOPD dual moves (64 instructions instead of 128)
  for i in range(0, len(OUT_REGS), 2):
    k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[OUT_REGS[i]], vdsty=v[OUT_REGS[i+1]], srcx0=0, srcy0=0))

  k.emit(s_add_i32(s[S_LOOP_BOUND], s[S_DIM_N], -8))
  k.emit(s_add_u32(s[S_A_PTR[0]], s[S_A_PTR[0]], 32))
  k.emit(s_addc_u32(s[S_A_PTR[1]], s[S_A_PTR[1]], 0))
  # S_LOOP_CTR is already 0 from prologue initialization
  k.emit(s_branch(simm16=0)); k.branch_to('LOOP_ENTRY')

  # ===========================================================================
  # MAIN GEMM LOOP
  # ===========================================================================

  NO_DS, NO_GLOBAL = getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)

  k.label('LOOP_INC')
  k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
  k.emit(s_cmp_ge_i32(s[S_LOOP_CTR], s[S_DIM_N]))
  k.emit(s_cbranch_scc1(simm16=0)); k.branch_to('EPILOGUE')

  k.label('LOOP_ENTRY')
  k.emit(s_cmp_lt_i32(s[S_LOOP_CTR], s[S_LOOP_BOUND]))
  k.emit(s_cselect_b32(s[S_PREFETCH_FLAG], -1, 0)) # s_cselect doesn't modify SCC
  k.emit(s_cbranch_scc0(simm16=0)); k.branch_to('SKIP_PREFETCH') # branch if loop_ctr >= loop_bound

  # Advance prefetch pointers
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], 0x20000, v[V_GLOBAL_B_ADDR]))
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))

  if not NO_GLOBAL:
    for vdst, saddr_lo in INIT_PREFETCH:
      k.global_load(vdst, V_GLOBAL_B_ADDR, saddr_lo)

  k.label('SKIP_PREFETCH')

  # 8 inner loop iterations
  for iter in range(8):
    # Load A tile (4 pairs) and B tile (8 pairs) from LDS
    if not NO_DS:
      k.emit(s_clause(simm16=11)) # 12 loads total: 4 A + 8 B
      # A tile: 4 ds_load_b64
      for i, vdst in enumerate(V_A_TILE_REGS):
        a_off = (i & 1) * 8 + (i >> 1) * 64 + iter * LDS_A_STRIDE
        k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_A_BASE], offset0=a_off & 0xFF, offset1=a_off >> 8))
      # B tile: 8 ds_load_b64
      for i, vdst in enumerate(V_B_TILE_REGS):
        b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 + iter * LDS_B_STRIDE
        k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_B_BASE], offset0=b_off & 0xFF, offset1=b_off >> 8))
      k.waitcnt(lgkm=0)

    # 64 dual FMACs
    k.emit(s_clause(simm16=63))
    for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
      k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
                  vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))

    # Issue global prefetch AFTER FMACs (first 6 iterations only)
    if iter < 6 and not NO_GLOBAL:
      vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter]
      k.global_load(vdst1, addr, slo1)
      k.global_load(vdst2, addr, slo2)

  k.emit(s_and_not1_b32(VCC_LO, EXEC_LO, s[S_PREFETCH_FLAG]))
  k.waitcnt(vm=0)
  k.barrier()
  k.emit(s_cbranch_vccnz(simm16=0)); k.branch_to('LOOP_INC')

  # Store prefetched data to LDS
  # NOTE: Register naming reflects LDS tile organization, not source matrix:
  #   V_LDS_A_DATA (v155-162) holds data that goes to LDS A-tile region
  #   V_LDS_B_DATA (v163-170) holds data that goes to LDS B-tile region
  # The data sources are swapped: A-tile receives B matrix rows, B-tile receives A matrix columns
  for i in range(4): # A tile: 8 values via 4 stride64 stores
    k.emit(ds_store_2addr_stride64_b32(addr=v[V_LDS_A_ADDR], data0=v[V_LDS_A_DATA[i*2]], data1=v[V_LDS_A_DATA[i*2+1]], offset0=i*4, offset1=i*4+2))
  for i in range(8): # B tile: 8 values via 8 scalar stores with 64-byte spacing
    offset = i * 64
    k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))

  k.waitcnt(lgkm=0)
  k.barrier()
  k.emit(s_branch(simm16=0)); k.branch_to('LOOP_INC')

  # ===========================================================================
  # EPILOGUE: Permute and store results
  # ===========================================================================
  k.label('EPILOGUE')

  # Rearrange accumulators from FMAC layout to contiguous output order
  for a, b in PERMUTE_SWAPS:
    k.emit(v_swap_b32_e32(v[a], v[b]))

  # Compute output coordinates: v[V_LANE_ID_MOD8] = col, v[V_OUTPUT_ROW] = row
  k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32,
              vdstx=v[149], vdsty=v[150], srcx0=v[V_LANE_MOD8_X4], vsrcx1=v[0], srcy0=v[V_LANE_DIV8_X4], vsrcy1=v[0]))
  k.emit(v_and_b32_e32(v[0], 0x60, v[0]))
  k.emit(v_or_b32_e32(v[V_LANE_ID_MOD8], s[S_TILE_X], v[149]))
  k.emit(v_add_nc_u32_e32(v[0], s[S_TILE_Y], v[0]))
  k.emit(v_or_b32_e32(v[V_OUTPUT_ROW], v[0], v[150]))

  # Precompute row offsets: v[144-147] for rows 0-3, v[148-151] for rows 16-19
  for base, row_off in [(144, 0), (148, 16)]:
    if row_off: k.emit(v_or_b32_e32(v[1], row_off, v[V_OUTPUT_ROW]))
    k.emit(v_mul_lo_u32(v[base], v[1] if row_off else v[V_OUTPUT_ROW], s[S_DIM_N]))
    for i in range(3): k.emit(v_add_nc_u32_e32(v[base + 1 + i], s[S_DIM_N], v[base + i]))

  k.emit(v_mov_b32_e32(v[V_ADDR_HI_ZERO], 0))
  k.emit(s_lshl_b32(s[S_PREFETCH_FLAG], s[S_DIM_N], 2)) # row stride in bytes

  # Store 128 output values as 32 groups of 4 (128-bit stores)
  # Layout: 2 row halves (0-3, 16-19) x 4 col groups x 4 rows = 32 stores of 4 floats
  epilogue_reserved = {V_LANE_ID_MOD8, V_OUTPUT_ROW, V_LANE_MOD8_X4, V_LANE_DIV8_X4, V_ADDR_HI_ZERO}

  for i, (row_half, col_off, row_in_group) in enumerate([(rh, co, ri)
      for rh in range(2) for co in [0, 32, 64, 96] for ri in range(4)]):
    row = row_half * 16 + row_in_group
    srcs = OUT_REGS[i*4:(i+1)*4]

    # Find temp register for scaled values (must not conflict with reserved regs)
    tmp = max(srcs) + 5
    while any(r in epilogue_reserved for r in range(tmp, tmp + 4)): tmp += 1

    # Copy values to temp regs for output (alpha=1.0 hardcoded, so just move)
    for j, src in enumerate(srcs):
      k.emit(v_mov_b32_e32(v[tmp + j], v[src]))

    # Compute output address
    if row_in_group == 0: # first row: compute base address for this column group
      if col_off == 0: k.emit(v_mov_b32_e32(v[0], v[V_LANE_ID_MOD8]))
      else: k.emit(v_add_nc_u32_e32(v[0], col_off, v[V_LANE_ID_MOD8]))
      row_base = 144 + row if row < 4 else 148 + row - 16
      k.emit(v_add_nc_u32_e32(v[0], v[row_base], v[0]))
      k.emit(v_lshlrev_b32_e32(v[0], 2, v[0]))
      k.emit(v_add_co_u32(v[0], VCC_LO, s[S_OUT_PTR[0]], v[0]))
      k.emit(v_add_co_ci_u32_e32(v[1], s[S_OUT_PTR[1]], v[V_ADDR_HI_ZERO]))
    else: # subsequent rows: just add stride
      k.emit(v_add_co_u32(v[0], VCC_LO, s[S_PREFETCH_FLAG], v[0]))
      k.emit(v_add_co_ci_u32_e32(v[1], v[1], v[V_ADDR_HI_ZERO]))

    k.emit(global_store_b128(addr=v[0:1], data=v[tmp:tmp+3], saddr=RawImm(124)))

  k.emit(s_sendmsg(simm16=3))
  k.emit(s_endpgm())

  return k.to_asm()

# =============================================================================
# Test harness
# =============================================================================

N = getenv("N", 4096)
BLOCK_M, BLOCK_N = 128, 128
THREADS = 128

def test_matmul():
  dev = Device[Device.DEFAULT]
  print(f"Device arch: {dev.arch}")

  if getenv("STOCK", 0):
    # Load the stock kernel from amd_seb/kernel8_batched_gmem.s
    stock_path = Path(__file__).parent / "amd_seb" / "kernel8_batched_gmem.s"
    asm = stock_path.read_text()
    print(f"Loaded stock kernel from {stock_path}")
  else:
    asm = build_kernel(dev.arch)
  if getenv("PRINT_ASM", 0): print(asm)

  binary = dev.compiler.compile(asm)
  print(f"Compiled! Binary size: {len(binary)} bytes")

  rng = np.random.default_rng(42)
  a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  b = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  c = Tensor.empty(N, N)
  Tensor.realize(a, b, c)

  grid, local = (N // BLOCK_N, N // BLOCK_M, 1), (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")

  _prg = dev.runtime("kernel", binary)
  class AsmRunner(Runner):
    def __init__(self):
      super().__init__(colored("kernel", "cyan"), Device.DEFAULT, Estimates(ops=N*N*N*2, mem=N*N*4*3))
    def __call__(self, rawbufs, var_vals, wait=False):
      c_buf, a_buf, b_buf = [x.ensure_allocated()._buf for x in rawbufs]
      return _prg(a_buf, b_buf, c_buf, global_size=grid, local_size=local, wait=wait)

  ei = ExecItem(None, [c.uop.buffer, a.uop.buffer, b.uop.buffer], prg=AsmRunner())

  ets = []
  with Context(DEBUG=2):
    for _ in range(getenv("CNT", 5)): ets.append(ei.run(wait=True))
  print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")

  if getenv("VERIFY", 1):
    GlobalCounters.reset()
    with Context(DEBUG=2): tc = (a @ b).realize()
    with Context(DEBUG=0): err = (c - tc).square().mean().item()
    print(f"mean squared error {err}")
    if err > 1e-06: raise RuntimeError("matmul is wrong!")

def run_sqtt():
  """Run with SQTT profiling and write trace files."""
  import subprocess, os

  # Run test_matmul in a subprocess with SQTT enabled from the start (no verify)
  env = {**os.environ, "AMD": "1", "SQTT": "1", "CNT": "1", "PROFILE": "1", "PYTHONPATH": ".", "VERIFY": "0"}
  result = subprocess.run(
    ["python", "-c", "from extra.gemm.amd_asm_matmul import test_matmul; test_matmul()"],
    capture_output=True, text=True, env=env, timeout=120
  )
  print(result.stdout)

  # Run roc.py to extract trace data
  result = subprocess.run(
    ["python", "extra/sqtt/roc.py", "--profile", "/tmp/profile.pkl.tiny", "--kernel", "kernel"],
    capture_output=True, text=True, env={**os.environ, "DEBUG": "5"}, timeout=60
  )
  output = result.stdout + result.stderr

  # Write full output to trace file
  with open("/tmp/sqtt_trace.txt", "w") as f:
    f.write(output)
  print(f"Wrote {len(output)} bytes to /tmp/sqtt_trace.txt")

if __name__ == "__main__":
  if getenv("ASM", 0): print(build_kernel(Device[Device.DEFAULT].arch))
  elif getenv("SQTT", 0): run_sqtt()
  else: test_matmul()