# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
# mypy: ignore-errors
from __future__ import annotations
import ctypes, os
from extra.assembly.amd.lib import Inst, RawImm
from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3 import (
SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, SrcEnum,
SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp
)
Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC
# VOP3 ops that use 64-bit operands (and thus 64-bit literals when src is 255)
# Exception: V_LDEXP_F64 has 32-bit integer src1, so literal should NOT be 64-bit when src1=255
_VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
# Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
# Ops with 16-bit types in name (for source/dest handling)
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
# CVT ops with 32/64-bit source (despite 16-bit in name)
_CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \
{op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))}
# 16-bit dst ops (PACK has 32-bit dst despite F16 in name)
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name}
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name}
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
import struct as _struct
_FLOAT_CONSTS = {SrcEnum.POS_HALF: 0.5, SrcEnum.NEG_HALF: -0.5, SrcEnum.POS_ONE: 1.0, SrcEnum.NEG_ONE: -1.0,
SrcEnum.POS_TWO: 2.0, SrcEnum.NEG_TWO: -2.0, SrcEnum.POS_FOUR: 4.0, SrcEnum.NEG_FOUR: -4.0, SrcEnum.INV_2PI: 0.15915494309189535}
def _build_inline_consts(neg_mask, float_to_bits):
  # entries: src 128-192 -> integers 0..64, src 193-208 -> -1..-16, the rest default to 0
  # until the float constants overwrite their slots below
  tbl = list(range(65)) + [((-i) & neg_mask) for i in range(1, 17)] + [0] * (127 - 81)
  for k, v in _FLOAT_CONSTS.items(): tbl[k - 128] = float_to_bits(v)
  return tbl
_INLINE_CONSTS = _build_inline_consts(0xffffffff, lambda f: _struct.unpack('<I', _struct.pack('<f', f))[0])
_INLINE_CONSTS_F16 = _build_inline_consts(0xffff, lambda f: _struct.unpack('<H', _struct.pack('<e', f))[0])
_INLINE_CONSTS_F64 = _build_inline_consts(0xffffffffffffffff, lambda f: _struct.unpack('<Q', _struct.pack('<d', f))[0])
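# e.g. _INLINE_CONSTS[SrcEnum.POS_ONE - 128] == 0x3f800000 (1.0f) and
# _INLINE_CONSTS_F16[SrcEnum.POS_ONE - 128] == 0x3c00 (1.0 in half precision).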
# Memory access
_valid_mem_ranges: list[tuple[int, int]] = []
def set_valid_mem_ranges(ranges: set[tuple[int, int]]) -> None: _valid_mem_ranges.clear(); _valid_mem_ranges.extend(ranges)
def _mem_valid(addr: int, size: int) -> bool:
for s, z in _valid_mem_ranges:
if s <= addr and addr + size <= s + z: return True
return not _valid_mem_ranges
def _ctypes_at(addr: int, size: int): return (ctypes.c_uint8 if size == 1 else ctypes.c_uint16 if size == 2 else ctypes.c_uint32).from_address(addr)
def mem_read(addr: int, size: int) -> int: return _ctypes_at(addr, size).value if _mem_valid(addr, size) else 0
def mem_write(addr: int, size: int, val: int) -> None:
if _mem_valid(addr, size): _ctypes_at(addr, size).value = val
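# Out-of-range accesses are soft-failed: reads return 0 and writes are dropped.
# e.g. set_valid_mem_ranges({(buf, sz)}) confines access to one buffer, while an
# empty set (the default) leaves every address accessible.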
# Memory op tables (not pseudocode - these are format descriptions)
def _mem_ops(ops, suffix_map):
return {getattr(e, f"{p}_{s}"): v for e in ops for s, v in suffix_map.items() for p in [e.__name__.replace("Op", "")]}
_LOAD_MAP = {'LOAD_B32': (1,4,0), 'LOAD_B64': (2,4,0), 'LOAD_B96': (3,4,0), 'LOAD_B128': (4,4,0), 'LOAD_U8': (1,1,0), 'LOAD_I8': (1,1,1), 'LOAD_U16': (1,2,0), 'LOAD_I16': (1,2,1)}
_STORE_MAP = {'STORE_B32': (1,4), 'STORE_B64': (2,4), 'STORE_B96': (3,4), 'STORE_B128': (4,4), 'STORE_B8': (1,1), 'STORE_B16': (1,2)}
FLAT_LOAD, FLAT_STORE = _mem_ops([GLOBALOp, FLATOp], _LOAD_MAP), _mem_ops([GLOBALOp, FLATOp], _STORE_MAP)
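# e.g. FLAT_LOAD[GLOBALOp.GLOBAL_LOAD_B64] == (2, 4, 0): 2 dwords of 4 bytes each, no sign-extend.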
# D16 ops: load/store 16-bit to lower or upper half of VGPR. Format: (size, sign, hi) where hi=1 means upper 16 bits
_D16_LOAD_MAP = {'LOAD_D16_U8': (1,0,0), 'LOAD_D16_I8': (1,1,0), 'LOAD_D16_B16': (2,0,0),
'LOAD_D16_HI_U8': (1,0,1), 'LOAD_D16_HI_I8': (1,1,1), 'LOAD_D16_HI_B16': (2,0,1)}
_D16_STORE_MAP = {'STORE_D16_HI_B8': (1,1), 'STORE_D16_HI_B16': (2,1)} # (size, hi)
FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP2/VOP3 op mapping (VOPD dual-issues VOP1/VOP2 ops; map to VOP3/VOP2 enums for pseudocode lookup)
_VOPD_TO_VOP = {
VOPDOp.V_DUAL_FMAC_F32: VOP3Op.V_FMAC_F32, VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32,
VOPDOp.V_DUAL_MUL_F32: VOP3Op.V_MUL_F32, VOPDOp.V_DUAL_ADD_F32: VOP3Op.V_ADD_F32, VOPDOp.V_DUAL_SUB_F32: VOP3Op.V_SUB_F32,
VOPDOp.V_DUAL_SUBREV_F32: VOP3Op.V_SUBREV_F32, VOPDOp.V_DUAL_MUL_DX9_ZERO_F32: VOP3Op.V_MUL_DX9_ZERO_F32,
VOPDOp.V_DUAL_MOV_B32: VOP3Op.V_MOV_B32, VOPDOp.V_DUAL_CNDMASK_B32: VOP3Op.V_CNDMASK_B32,
VOPDOp.V_DUAL_MAX_F32: VOP3Op.V_MAX_F32, VOPDOp.V_DUAL_MIN_F32: VOP3Op.V_MIN_F32,
VOPDOp.V_DUAL_ADD_NC_U32: VOP3Op.V_ADD_NC_U32, VOPDOp.V_DUAL_LSHLREV_B32: VOP3Op.V_LSHLREV_B32, VOPDOp.V_DUAL_AND_B32: VOP3Op.V_AND_B32,
}
# Compiled pseudocode functions (lazy loaded)
_COMPILED: dict | None = None
def _get_compiled() -> dict:
global _COMPILED
if _COMPILED is None: _COMPILED = get_compiled_functions()
return _COMPILED
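# All compiled pseudocode functions share one calling convention (argument names here
# are descriptive labels; see the call sites below):
#   fn(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, vgpr, scratch, ...) -> dict
# where the result dict may carry 'd0', 'd0_64', 'scc', 'exec', 'pc_delta',
# 'vcc_lane', 'exec_lane', and 'vgpr_write' for the caller to apply.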
class WaveState:
__slots__ = ('sgpr', 'vgpr', 'scc', 'pc', 'literal', '_pend_sgpr')
def __init__(self):
self.sgpr, self.vgpr = [0] * SGPR_COUNT, [[0] * VGPR_COUNT for _ in range(WAVE_SIZE)]
self.sgpr[EXEC_LO], self.scc, self.pc, self.literal, self._pend_sgpr = 0xffffffff, 0, 0, 0, {}
@property
def vcc(self) -> int: return self.sgpr[VCC_LO] | (self.sgpr[VCC_HI] << 32)
@vcc.setter
def vcc(self, v: int): self.sgpr[VCC_LO], self.sgpr[VCC_HI] = v & 0xffffffff, (v >> 32) & 0xffffffff
@property
def exec_mask(self) -> int: return self.sgpr[EXEC_LO] | (self.sgpr[EXEC_HI] << 32)
@exec_mask.setter
def exec_mask(self, v: int): self.sgpr[EXEC_LO], self.sgpr[EXEC_HI] = v & 0xffffffff, (v >> 32) & 0xffffffff
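  # vcc and exec_mask are per-lane bitmasks, e.g. st.vcc = 0b101 marks lanes 0 and 2.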
def rsgpr(self, i: int) -> int: return 0 if i == NULL else self.scc if i == SCC else self.sgpr[i] if i < SGPR_COUNT else 0
def wsgpr(self, i: int, v: int):
if i < SGPR_COUNT and i != NULL: self.sgpr[i] = v & 0xffffffff
def rsgpr64(self, i: int) -> int: return self.rsgpr(i) | (self.rsgpr(i+1) << 32)
def wsgpr64(self, i: int, v: int): self.wsgpr(i, v & 0xffffffff); self.wsgpr(i+1, (v >> 32) & 0xffffffff)
def rsrc(self, v: int, lane: int) -> int:
if v < SGPR_COUNT: return self.sgpr[v]
if v == SCC: return self.scc
if v < 255: return _INLINE_CONSTS[v - 128]
if v == 255: return self.literal
return self.vgpr[lane][v - 256] if v <= 511 else 0
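  # Source encodings: 0-127 scalar regs (incl. VCC/EXEC), 128-254 inline constants,
  # 255 the fetched literal, 256-511 VGPRs.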
def rsrc_f16(self, v: int, lane: int) -> int:
"""Read source operand for VOP3P packed f16 operations. Uses f16 inline constants."""
if v < SGPR_COUNT: return self.sgpr[v]
if v == SCC: return self.scc
if v < 255: return _INLINE_CONSTS_F16[v - 128]
if v == 255: return self.literal
return self.vgpr[lane][v - 256] if v <= 511 else 0
def rsrc64(self, v: int, lane: int) -> int:
"""Read 64-bit source operand. For inline constants, returns 64-bit representation."""
# Inline constants 128-254 need special handling for 64-bit ops
if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128]
if v == 255: return self.literal # 32-bit literal, caller handles extension
return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
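  # Per-lane SGPR results (VCC from V_CMP, EXEC from V_CMPX) accumulate here and are
  # committed only after every lane has executed, so all lanes see pre-instruction values.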
def pend_sgpr_lane(self, reg: int, lane: int, val: int):
if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0
if val: self._pend_sgpr[reg] |= (1 << lane)
def commit_pends(self):
for reg, val in self._pend_sgpr.items(): self.sgpr[reg] = val
self._pend_sgpr.clear()
# Instruction decode
def decode_format(word: int) -> tuple[type[Inst] | None, bool]:
hi2 = (word >> 30) & 0x3
if hi2 == 0b11:
enc = (word >> 26) & 0xf
if enc == 0b1101: return SMEM, True
if enc == 0b0101:
op = (word >> 16) & 0x3ff
return (VOP3SD, True) if op in (288, 289, 290, 764, 765, 766, 767, 768, 769, 770) else (VOP3, True)
return {0b0011: (VOP3P, True), 0b0110: (DS, True), 0b0111: (FLAT, True), 0b0010: (VOPD, True)}.get(enc, (None, True))
if hi2 == 0b10:
enc = (word >> 23) & 0x7f
return {0b1111101: (SOP1, False), 0b1111110: (SOPC, False), 0b1111111: (SOPP, False)}.get(enc, (SOPK, False) if ((word >> 28) & 0xf) == 0b1011 else (SOP2, False))
enc = (word >> 25) & 0x7f
return (VOPC, False) if enc == 0b0111110 else (VOP1, False) if enc == 0b0111111 else (VOP2, False)
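# e.g. decode_format(0xF4000000) -> (SMEM, True): bits [31:30]=0b11 and [29:26]=0b1101.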
def _unwrap(v) -> int: return v.val if isinstance(v, RawImm) else v.value if hasattr(v, 'value') else v
def decode_program(data: bytes) -> Program:
result: Program = {}
i = 0
while i < len(data):
word = int.from_bytes(data[i:i+4], 'little')
inst_class, is_64 = decode_format(word)
if inst_class is None: i += 4; continue
base_size = 8 if is_64 else 4
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
inst = inst_class.from_bytes(data[i:i+base_size+8])
for name, val in inst._values.items(): setattr(inst, name, _unwrap(val))
    # from_bytes already reads trailing literals it recognizes; this fallback covers encodings it misses
if inst._literal is None:
has_literal = any(getattr(inst, fld, None) == 255 for fld in ('src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'srcx0', 'srcy0'))
if inst_class == VOP2 and inst.op in (44, 45, 55, 56): has_literal = True
if inst_class == VOPD and (inst.opx in (1, 2) or inst.opy in (1, 2)): has_literal = True
if inst_class == SOP2 and inst.op in (69, 70): has_literal = True
if has_literal:
# For 64-bit ops, the 32-bit literal is placed in HIGH 32 bits (low 32 bits = 0)
# Exception: some ops have mixed src sizes (e.g., V_LDEXP_F64 has 32-bit src1)
op_val = inst._values.get('op')
if hasattr(op_val, 'value'): op_val = op_val.value
is_64bit = inst_class is VOP3 and op_val in _VOP3_64BIT_OPS
# Don't treat literal as 64-bit if the op has 32-bit src1 and src1 is the literal
if is_64bit and op_val in _VOP3_64BIT_OPS_32BIT_SRC1 and getattr(inst, 'src1', None) == 255:
is_64bit = False
lit32 = int.from_bytes(data[i+base_size:i+base_size+4], 'little')
inst._literal = (lit32 << 32) if is_64bit else lit32
inst._words = inst.size() // 4
result[i // 4] = inst
i += inst._words * 4
return result
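# Program keys are word offsets (byte offset // 4): st.pc indexes this dict directly,
# and branch deltas returned by exec_scalar are likewise measured in words.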
# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION - All ALU ops use pseudocode from PDF
# ═══════════════════════════════════════════════════════════════════════════════
def exec_scalar(st: WaveState, inst: Inst) -> int:
"""Execute scalar instruction. Returns PC delta or negative for special cases."""
compiled = _get_compiled()
inst_type = type(inst)
# SOPP: control flow (not ALU)
if inst_type is SOPP:
op = inst.op
if op == SOPPOp.S_ENDPGM: return -1
if op == SOPPOp.S_BARRIER: return -2
if op == SOPPOp.S_BRANCH: return _sext(inst.simm16, 16)
if op == SOPPOp.S_CBRANCH_SCC0: return _sext(inst.simm16, 16) if st.scc == 0 else 0
if op == SOPPOp.S_CBRANCH_SCC1: return _sext(inst.simm16, 16) if st.scc == 1 else 0
if op == SOPPOp.S_CBRANCH_VCCZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) == 0 else 0
if op == SOPPOp.S_CBRANCH_VCCNZ: return _sext(inst.simm16, 16) if (st.vcc & 0xffffffff) != 0 else 0
if op == SOPPOp.S_CBRANCH_EXECZ: return _sext(inst.simm16, 16) if st.exec_mask == 0 else 0
if op == SOPPOp.S_CBRANCH_EXECNZ: return _sext(inst.simm16, 16) if st.exec_mask != 0 else 0
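    # simm16 is a signed dword offset from the next instruction; step_wave adds
    # inst_words plus the returned delta, i.e. PC = PC + 4 + simm16*4 in ISA terms.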
# Valid SOPP range is 0-61 (max defined opcode); anything above is invalid
if op > 61: raise NotImplementedError(f"Invalid SOPP opcode {op}")
return 0 # waits, hints, nops
# SMEM: memory loads (not ALU)
if inst_type is SMEM:
addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21)
if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0)
if (cnt := SMEM_LOAD.get(inst.op)) is None: raise NotImplementedError(f"SMEM op {inst.op}")
for i in range(cnt): st.wsgpr(inst.sdata + i, mem_read((addr + i * 4) & 0xffffffffffffffff, 4))
return 0
# SOP1: special handling for ops not in pseudocode
if inst_type is SOP1:
op = SOP1Op(inst.op)
    # S_GETPC_B64: get the program counter. PC is tracked internally in words; return it as a byte address.
    if op == SOP1Op.S_GETPC_B64:
      pc_bytes = st.pc * 4  # words -> bytes
st.wsgpr64(inst.sdst, pc_bytes)
return 0
# S_SETPC_B64: Set program counter to source value (indirect jump)
# Returns delta such that st.pc + inst_words + delta = target_words
if op == SOP1Op.S_SETPC_B64:
target_bytes = st.rsrc64(inst.ssrc0, 0)
target_words = target_bytes // 4
inst_words = 1 # SOP1 is always 1 word
return target_words - st.pc - inst_words
# Get op enum and lookup compiled function
if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst
elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst
elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None
elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst # sdst is both src and dst
else: raise NotImplementedError(f"Unknown scalar type {inst_type}")
op = op_cls(inst.op)
fn = compiled.get(op_cls, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
# Build context - handle 64-bit ops that need 64-bit source reads
  # 64-bit source ops: name ends with _B64/_I64/_U64, or contains _U64_ / _I64_ mid-name
is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type != SOPK else st.rsgpr(inst.sdst))
is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
exec_mask = st.exec_mask
literal = inst.simm16 if inst_type is SOPK else st.literal
# Execute compiled function
result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {})
# Apply results
if sdst is not None:
if result.get('d0_64'):
st.wsgpr64(sdst, result['d0'])
else:
st.wsgpr(sdst, result['d0'])
if 'scc' in result: st.scc = result['scc']
if 'exec' in result: st.exec_mask = result['exec']
if 'pc_delta' in result: return result['pc_delta']
return 0
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
"""Execute vector instruction for one lane."""
compiled = _get_compiled()
inst_type, V = type(inst), st.vgpr[lane]
# Memory ops (not ALU pseudocode)
if inst_type is FLAT:
op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr
addr = V[addr_reg] | (V[addr_reg+1] << 32)
addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & 0xffffffffffffffff if saddr not in (NULL, 0x7f) else (addr + offset) & 0xffffffffffffffff
if op in FLAT_LOAD:
cnt, sz, sign = FLAT_LOAD[op]
for i in range(cnt): val = mem_read(addr + i * sz, sz); V[vdst + i] = _sext(val, sz * 8) & 0xffffffff if sign else val
elif op in FLAT_STORE:
cnt, sz = FLAT_STORE[op]
for i in range(cnt): mem_write(addr + i * sz, sz, V[data_reg + i] & ((1 << (sz * 8)) - 1))
elif op in FLAT_D16_LOAD:
sz, sign, hi = FLAT_D16_LOAD[op]
val = mem_read(addr, sz)
if sign: val = _sext(val, sz * 8) & 0xffff
if hi: V[vdst] = (V[vdst] & 0xffff) | (val << 16) # upper 16 bits
else: V[vdst] = (V[vdst] & 0xffff0000) | (val & 0xffff) # lower 16 bits
elif op in FLAT_D16_STORE:
sz, hi = FLAT_D16_STORE[op]
val = (V[data_reg] >> 16) & 0xffff if hi else V[data_reg] & 0xffff
mem_write(addr, sz, val & ((1 << (sz * 8)) - 1))
else: raise NotImplementedError(f"FLAT op {op}")
return
if inst_type is DS:
op, addr, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
for i in range(cnt): val = int.from_bytes(lds[addr+i*sz:addr+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & 0xffffffff if sign else val
elif op in DS_STORE:
cnt, sz = DS_STORE[op]
for i in range(cnt): lds[addr+i*sz:addr+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
else: raise NotImplementedError(f"DS op {op}")
return
# VOPD: dual-issue, execute two ops using VOP2/VOP3 compiled functions
# Both ops execute simultaneously using pre-instruction values, so read all inputs first
if inst_type is VOPD:
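    # the encoded vdsty omits its LSB; hardware defines it as the complement of vdstx's LSB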
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
# Read all source operands BEFORE any writes (dual-issue semantics)
sx0, sx1 = st.rsrc(inst.srcx0, lane), V[inst.vsrcx1]
sy0, sy1 = st.rsrc(inst.srcy0, lane), V[inst.vsrcy1]
dx0, dy0 = V[inst.vdstx], V[vdsty]
# Execute X op
res_x = None
if (op_x := _VOPD_TO_VOP.get(inst.opx)):
if (fn_x := compiled.get(type(op_x), {}).get(op_x)):
res_x = fn_x(sx0, sx1, 0, dx0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
# Execute Y op
res_y = None
if (op_y := _VOPD_TO_VOP.get(inst.opy)):
if (fn_y := compiled.get(type(op_y), {}).get(op_y)):
res_y = fn_y(sy0, sy1, 0, dy0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
# Write results after both ops complete
if res_x is not None: V[inst.vdstx] = res_x['d0']
if res_y is not None: V[vdsty] = res_y['d0']
return
# VOP3SD: has extra scalar dest for carry output
if inst_type is VOP3SD:
op = VOP3SDOp(inst.op)
fn = compiled.get(VOP3SDOp, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
# For 64-bit src2 ops (V_MAD_U64_U32, V_MAD_I64_I32), read from consecutive registers
mad64_ops = (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32)
if op in mad64_ops:
if inst.src2 >= 256: # VGPR
s2 = V[inst.src2 - 256] | (V[inst.src2 - 256 + 1] << 32)
else: # SGPR - read 64-bit from consecutive SGPRs
s2 = st.rsgpr64(inst.src2)
d0 = V[inst.vdst]
# For carry-in operations (V_*_CO_CI_*), src2 register contains the carry bitmask (not VCC).
# The pseudocode uses VCC but in VOP3SD encoding, the actual carry source is inst.src2.
# We pass the src2 register value as 'vcc' to the interpreter so it reads the correct carry.
carry_ops = (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32)
vcc_for_exec = st.rsgpr64(inst.src2) if op in carry_ops else st.vcc
result = fn(s0, s1, s2, d0, st.scc, vcc_for_exec, lane, st.exec_mask, st.literal, None, {})
# Write result - handle 64-bit destinations
if result.get('d0_64'):
V[inst.vdst] = result['d0'] & 0xffffffff
V[inst.vdst + 1] = (result['d0'] >> 32) & 0xffffffff
else:
V[inst.vdst] = result['d0'] & 0xffffffff
if result.get('vcc_lane') is not None:
st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
return
# Get op enum and sources (None means "no source" for that operand)
if inst_type is VOP1:
if inst.op == VOP1Op.V_NOP: return
op_cls, op, src0, src1, src2, vdst = VOP1Op, VOP1Op(inst.op), inst.src0, None, None, inst.vdst
elif inst_type is VOP2:
op_cls, op, src0, src1, src2, vdst = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None, inst.vdst
elif inst_type is VOP3:
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode)
if inst.op < 256:
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
else:
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
# V_PERM_B32: byte permutation - not in pseudocode PDF, implement directly
# D0[byte_i] = selector[byte_i] < 8 ? {src1, src0}[selector[byte_i]] : (selector[byte_i] >= 0xD ? 0xFF : 0x00)
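    # e.g. src2=0x03020100 selects bytes 0-3 (returns src0) and src2=0x07060504
    # selects bytes 4-7 (returns src1)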
if op == VOP3Op.V_PERM_B32:
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
# Combine src0 and src1 into 8-byte value: src0 is bytes 0-3, src1 is bytes 4-7
combined = (s0 & 0xffffffff) | ((s1 & 0xffffffff) << 32)
result = 0
for i in range(4): # 4 result bytes
sel = (s2 >> (i * 8)) & 0xff # byte selector for this position
if sel <= 7: result |= (((combined >> (sel * 8)) & 0xff) << (i * 8)) # select byte from combined
elif sel >= 0xd: result |= (0xff << (i * 8)) # 0xD-0xF: constant 0xFF
# else 0x8-0xC: constant 0x00 (already 0)
V[vdst] = result & 0xffffffff
return
elif inst_type is VOPC:
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.vsrc1 + 256, None, VCC_LO
elif inst_type is VOP3P:
# VOP3P: Packed 16-bit operations using compiled functions
op = VOP3POp(inst.op)
# WMMA: wave-level matrix multiply-accumulate (special handling - needs cross-lane access)
if op in (VOP3POp.V_WMMA_F32_16X16X16_F16, VOP3POp.V_WMMA_F32_16X16X16_BF16, VOP3POp.V_WMMA_F16_16X16X16_F16):
if lane == 0: # Only execute once per wave, write results for all lanes
exec_wmma(st, inst, op)
return
# V_FMA_MIX: Mixed precision FMA - inputs can be f16 or f32 controlled by opsel
if op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
opsel = getattr(inst, 'opsel', 0)
opsel_hi = getattr(inst, 'opsel_hi', 0)
neg = getattr(inst, 'neg', 0)
neg_hi = getattr(inst, 'neg_hi', 0)
vdst = inst.vdst
# Read raw 32-bit values - for V_FMA_MIX, sources can be either f32 or f16
s0_raw = st.rsrc(inst.src0, lane)
s1_raw = st.rsrc(inst.src1, lane)
s2_raw = st.rsrc(inst.src2, lane) if inst.src2 is not None else 0
# opsel[i]=0: use as f32, opsel[i]=1: use hi f16 as f32
# For src0: opsel[0], for src1: opsel[1], for src2: opsel[2]
if opsel & 1: s0 = _f16((s0_raw >> 16) & 0xffff) # hi f16 -> f32
else: s0 = _f32(s0_raw) # use as f32
if opsel & 2: s1 = _f16((s1_raw >> 16) & 0xffff)
else: s1 = _f32(s1_raw)
if opsel & 4: s2 = _f16((s2_raw >> 16) & 0xffff)
else: s2 = _f32(s2_raw)
# Apply neg modifiers (for f32 values)
if neg & 1: s0 = -s0
if neg & 2: s1 = -s1
if neg & 4: s2 = -s2
# Compute FMA: d = s0 * s1 + s2
result = s0 * s1 + s2
if op == VOP3POp.V_FMA_MIX_F32:
V[vdst] = _i32(result)
elif op == VOP3POp.V_FMA_MIXLO_F16:
lo = _i16(result) & 0xffff
V[vdst] = (V[vdst] & 0xffff0000) | lo
else: # V_FMA_MIXHI_F16
hi = _i16(result) & 0xffff
V[vdst] = (V[vdst] & 0x0000ffff) | (hi << 16)
return
# Use rsrc_f16 for VOP3P to get correct f16 inline constants
s0_raw = st.rsrc_f16(inst.src0, lane)
s1_raw = st.rsrc_f16(inst.src1, lane)
s2_raw = st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0
# Handle opsel (which 16-bit halves to use for each source)
opsel = getattr(inst, 'opsel', 0)
opsel_hi = getattr(inst, 'opsel_hi', 3) # Default: use hi for hi result
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) # Default for src2
# Handle neg modifiers for VOP3P
# neg applies to lo result inputs, neg_hi applies to hi result inputs
neg = getattr(inst, 'neg', 0)
neg_hi = getattr(inst, 'neg_hi', 0)
# Build "virtual" sources with halves arranged for pseudocode: lo half goes to [15:0], hi half goes to [31:16]
# opsel bit 0/1/2 selects which half of src0/1/2 goes to the LO result
# opsel_hi bit 0/1 selects which half of src0/1 goes to the HI result
s0_lo = (s0_raw >> 16) & 0xffff if (opsel & 1) else s0_raw & 0xffff
s1_lo = (s1_raw >> 16) & 0xffff if (opsel & 2) else s1_raw & 0xffff
s2_lo = (s2_raw >> 16) & 0xffff if (opsel & 4) else s2_raw & 0xffff
s0_hi = (s0_raw >> 16) & 0xffff if (opsel_hi & 1) else s0_raw & 0xffff
s1_hi = (s1_raw >> 16) & 0xffff if (opsel_hi & 2) else s1_raw & 0xffff
s2_hi = (s2_raw >> 16) & 0xffff if opsel_hi2 else s2_raw & 0xffff
# Apply neg to lo result inputs (toggle f16 sign bit)
if neg & 1: s0_lo ^= 0x8000
if neg & 2: s1_lo ^= 0x8000
if neg & 4: s2_lo ^= 0x8000
# Apply neg_hi to hi result inputs
if neg_hi & 1: s0_hi ^= 0x8000
if neg_hi & 2: s1_hi ^= 0x8000
if neg_hi & 4: s2_hi ^= 0x8000
# Pack into format expected by pseudocode: [31:16] = hi input, [15:0] = lo input
s0 = (s0_hi << 16) | s0_lo
s1 = (s1_hi << 16) | s1_lo
s2 = (s2_hi << 16) | s2_lo
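    # With the defaults (opsel=0, opsel_hi=3, opsel_hi2=1, no neg) this round-trips
    # exactly: s0 == s0_raw, s1 == s1_raw, s2 == s2_raw.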
op_cls, vdst = VOP3POp, inst.vdst
fn = compiled.get(op_cls, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
result = fn(s0, s1, s2, 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})
st.vgpr[lane][vdst] = result['d0'] & 0xffffffff
return
else: raise NotImplementedError(f"Unknown vector type {inst_type}")
fn = compiled.get(op_cls, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
# Read sources (with VOP3 modifiers if applicable)
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if inst_type is VOP3 else (0, 0)
opsel = getattr(inst, 'opsel', 0) if inst_type is VOP3 else 0
def mod_src(val: int, idx: int) -> int:
if (abs_ >> idx) & 1: val = _i32(abs(_f32(val)))
if (neg >> idx) & 1: val = _i32(-_f32(val))
return val
def mod_src64(val: int, idx: int) -> int:
if (abs_ >> idx) & 1: val = _i64(abs(_f64(val)))
if (neg >> idx) & 1: val = _i64(-_f64(val))
return val
# Determine if sources are 64-bit based on instruction type
# For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift)
# For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit
is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64'))
# V_LDEXP_F64: src0 is 64-bit float, src1 is 32-bit integer exponent
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
# 16-bit source ops: use precomputed sets instead of string checks
has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS or op in _VOP2_16BIT_OPS
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
if is_shift_64:
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
s1 = st.rsrc64(src1, lane) if src1 is not None else 0 # value to shift is 64-bit
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_ldexp_64:
s0 = mod_src64(st.rsrc64(src0, lane), 0) # mantissa is 64-bit float
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 # exponent is 32-bit int
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_64bit_op:
# 64-bit ops: apply neg/abs modifiers using f64 interpretation for float ops
s0 = mod_src64(st.rsrc64(src0, lane), 0)
s1 = mod_src64(st.rsrc64(src1, lane), 1) if src1 is not None else 0
s2 = mod_src64(st.rsrc64(src2, lane), 2) if src2 is not None else 0
elif is_16bit_src:
# For 16-bit source ops, opsel bits select which half to use
s0_raw = mod_src(st.rsrc(src0, lane), 0)
s1_raw = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2_raw = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
# opsel[0] selects hi(1) or lo(0) for src0, opsel[1] for src1, opsel[2] for src2
s0 = ((s0_raw >> 16) & 0xffff) if (opsel & 1) else (s0_raw & 0xffff)
s1 = ((s1_raw >> 16) & 0xffff) if (opsel & 2) else (s1_raw & 0xffff)
s2 = ((s2_raw >> 16) & 0xffff) if (opsel & 4) else (s2_raw & 0xffff)
elif is_vop2_16bit:
# VOP2 16-bit ops: src0 can use f16 inline constants, vsrc1 is always a VGPR (no inline constants)
s0 = mod_src(st.rsrc_f16(src0, lane), 0)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
else:
s0 = mod_src(st.rsrc(src0, lane), 0)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32))
# V_CNDMASK_B32: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
# Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32,) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
# Apply results
if 'vgpr_write' in result:
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
wr_lane, wr_idx, wr_val = result['vgpr_write']
st.vgpr[wr_lane][wr_idx] = wr_val
if 'vcc_lane' in result:
# VOP2 carry instructions (V_ADD_CO_CI_U32, V_SUB_CO_CI_U32, V_SUBREV_CO_CI_U32) write carry to VCC implicitly
# VOPC and VOP3-encoded VOPC write to vdst (which is VCC_LO for VOPC, inst.sdst for VOP3)
vcc_dst = VCC_LO if op_cls is VOP2Op and op in (VOP2Op.V_ADD_CO_CI_U32, VOP2Op.V_SUB_CO_CI_U32, VOP2Op.V_SUBREV_CO_CI_U32) else vdst
st.pend_sgpr_lane(vcc_dst, lane, result['vcc_lane'])
if 'exec_lane' in result:
# V_CMPX instructions write to EXEC per-lane
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
if 'd0' in result and op_cls not in (VOPCOp,) and 'vgpr_write' not in result:
# V_READFIRSTLANE_B32 and V_READLANE_B32 write to SGPR, not VGPR
# V_WRITELANE_B32 uses vgpr_write for cross-lane writes, don't overwrite with d0
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or \
(op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
# Check for 16-bit destination ops (opsel[3] controls hi/lo write)
is_16bit_dst = op in _VOP3_16BIT_DST_OPS or op in _VOP1_16BIT_DST_OPS
if writes_to_sgpr:
st.wsgpr(vdst, result['d0'] & 0xffffffff)
elif result.get('d0_64') or is_64bit_op:
V[vdst] = result['d0'] & 0xffffffff
V[vdst + 1] = (result['d0'] >> 32) & 0xffffffff
elif is_16bit_dst and inst_type is VOP3:
# VOP3 16-bit ops: opsel[3] (bit 3 of opsel field) controls hi/lo destination
if opsel & 8: # opsel[3] = 1: write to high 16 bits
V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16)
else: # opsel[3] = 0: write to low 16 bits
V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff)
else:
V[vdst] = result['d0'] & 0xffffffff
# ═══════════════════════════════════════════════════════════════════════════════
# WMMA (Wave Matrix Multiply-Accumulate)
# ═══════════════════════════════════════════════════════════════════════════════
def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
"""Execute WMMA instruction - 16x16x16 matrix multiply across the wave."""
src0, src1, src2, vdst = inst.src0, inst.src1, inst.src2, inst.vdst
# Read matrix A (16x16 f16/bf16) from lanes 0-15, VGPRs src0 to src0+7 (2 f16 per VGPR = 16 values per lane)
# Layout: A[row][k] where row = lane (0-15), k comes from 8 VGPRs × 2 halves
mat_a = []
for lane in range(16):
for reg in range(8):
val = st.vgpr[lane][src0 - 256 + reg] if src0 >= 256 else st.rsgpr(src0 + reg)
mat_a.append(_f16(val & 0xffff))
mat_a.append(_f16((val >> 16) & 0xffff))
# Read matrix B (16x16 f16/bf16) - same layout, B[col][k] where col comes from lane
mat_b = []
for lane in range(16):
for reg in range(8):
val = st.vgpr[lane][src1 - 256 + reg] if src1 >= 256 else st.rsgpr(src1 + reg)
mat_b.append(_f16(val & 0xffff))
mat_b.append(_f16((val >> 16) & 0xffff))
# Read matrix C (16x16 f32) from lanes 0-31, VGPRs src2 to src2+7
# Layout: element i is at lane (i % 32), VGPR (i // 32) + src2
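  # e.g. element i=53 (row 3, col 5) lives in lane 53 % 32 = 21, VGPR src2 + 53 // 32 = src2 + 1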
mat_c = []
for i in range(256):
lane, reg = i % 32, i // 32
val = st.vgpr[lane][src2 - 256 + reg] if src2 >= 256 else st.rsgpr(src2 + reg)
mat_c.append(_f32(val))
# Compute D = A × B + C (16x16 matrix multiply)
mat_d = [0.0] * 256
for row in range(16):
for col in range(16):
acc = 0.0
for k in range(16):
a_val = mat_a[row * 16 + k]
b_val = mat_b[col * 16 + k]
acc += a_val * b_val
mat_d[row * 16 + col] = acc + mat_c[row * 16 + col]
# Write result matrix D back - same layout as C
if op == VOP3POp.V_WMMA_F16_16X16X16_F16:
# Output is f16, pack 2 values per VGPR
for i in range(0, 256, 2):
lane, reg = (i // 2) % 32, (i // 2) // 32
lo = _i16(mat_d[i]) & 0xffff
hi = _i16(mat_d[i + 1]) & 0xffff
st.vgpr[lane][vdst + reg] = (hi << 16) | lo
else:
# Output is f32
for i in range(256):
lane, reg = i % 32, i // 32
st.vgpr[lane][vdst + reg] = _i32(mat_d[i])
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION LOOP
# ═══════════════════════════════════════════════════════════════════════════════
SCALAR_TYPES = {SOP1, SOP2, SOPC, SOPK, SOPP, SMEM}
VECTOR_TYPES = {VOP1, VOP2, VOP3, VOP3SD, VOPC, FLAT, DS, VOPD, VOP3P}
def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
  """Execute one instruction. Returns 0 normally, 1 if pc has no instruction, -1 on endpgm, -2 on barrier."""
  inst = program.get(st.pc)
  if inst is None: return 1
inst_words, st.literal, inst_type = inst._words, getattr(inst, '_literal', None) or 0, type(inst)
if inst_type in SCALAR_TYPES:
delta = exec_scalar(st, inst)
if delta == -1: return -1 # endpgm
if delta == -2: st.pc += inst_words; return -2 # barrier
st.pc += inst_words + delta
else:
# V_READFIRSTLANE_B32 and V_READLANE_B32 write to SGPR, so they should only execute once per wave (lane 0)
is_readlane = (inst_type is VOP1 and inst.op == VOP1Op.V_READFIRSTLANE_B32) or \
(inst_type is VOP3 and inst.op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
if is_readlane:
exec_vector(st, inst, 0, lds) # Execute once with lane 0
else:
exec_mask = st.exec_mask
for lane in range(n_lanes):
if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds)
st.commit_pends()
st.pc += inst_words
return 0
def exec_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
while st.pc in program:
result = step_wave(program, st, lds, n_lanes)
if result == -1: return 0
if result == -2: return -2
return 0
def exec_workgroup(program: Program, workgroup_id: tuple[int, int, int], local_size: tuple[int, int, int], args_ptr: int,
wg_id_sgpr_base: int, wg_id_enables: tuple[bool, bool, bool]) -> None:
lx, ly, lz = local_size
total_threads, lds = lx * ly * lz, bytearray(65536)
waves: list[tuple[WaveState, int, int]] = []
for wave_start in range(0, total_threads, WAVE_SIZE):
n_lanes, st = min(WAVE_SIZE, total_threads - wave_start), WaveState()
st.exec_mask = (1 << n_lanes) - 1
st.wsgpr64(0, args_ptr)
gx, gy, gz = workgroup_id
# Set workgroup IDs in SGPRs based on USER_SGPR_COUNT and enable flags from COMPUTE_PGM_RSRC2
sgpr_idx = wg_id_sgpr_base
if wg_id_enables[0]: st.sgpr[sgpr_idx] = gx; sgpr_idx += 1
if wg_id_enables[1]: st.sgpr[sgpr_idx] = gy; sgpr_idx += 1
if wg_id_enables[2]: st.sgpr[sgpr_idx] = gz
    for i in range(n_lanes):
      tid = wave_start + i
      # v0 gets the packed local id: z in bits [29:20], y in [19:10], x in [9:0]
      st.vgpr[i][0] = tid if (ly, lz) == (1, 1) else ((tid // (lx * ly)) << 20) | (((tid // lx) % ly) << 10) | (tid % lx)
waves.append((st, n_lanes, wave_start))
has_barrier = any(isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER for inst in program.values())
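  # exec_wave parks at a barrier (returns -2) with pc already advanced past it; a second
  # round-robin pass resumes every wave, which suffices for at most one barrier per kernel.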
for _ in range(2 if has_barrier else 1):
for st, n_lanes, _ in waves: exec_wave(program, st, lds, n_lanes)
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c) -> int:
data = (ctypes.c_char * lib_sz).from_address(lib).raw
program = decode_program(data)
if not program: return -1
# Parse COMPUTE_PGM_RSRC2 for SGPR layout
user_sgpr_count = (rsrc2 >> 1) & 0x1f
enable_wg_id_x = bool((rsrc2 >> 7) & 1)
enable_wg_id_y = bool((rsrc2 >> 8) & 1)
enable_wg_id_z = bool((rsrc2 >> 9) & 1)
wg_id_enables = (enable_wg_id_x, enable_wg_id_y, enable_wg_id_z)
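  # e.g. the default rsrc2=0x19c decodes to user_sgpr_count=14, workgroup ids X and Y enabled, Z disabled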
for gidz in range(gz):
for gidy in range(gy):
for gidx in range(gx): exec_workgroup(program, (gidx, gidy, gidz), (lx, ly, lz), args_ptr, user_sgpr_count, wg_id_enables)
return 0