assembly/amd: cleanups to asm and emu (#13912)

* a bunch of cleanups

* ops are back

* bug fixes

* cleanups

* a lil simpler

* more refactors

* _disasm_vop1

* sops

* more

* continue

* more

* num_srcs

* simpler

* no _is16

* op cleanups

* isinstance
This commit is contained in:
George Hotz
2025-12-31 12:46:11 -05:00
committed by GitHub
parent ba9aa5cd6f
commit 29402034a1
4 changed files with 479 additions and 313 deletions

View File

@@ -1,15 +1,12 @@
# RDNA3 assembler and disassembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, SRC_FIELDS, unwrap
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp)
# VOP3SD opcodes that share VOP3 encoding
VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
"""Check if word matches the encoding pattern of an instruction class."""
@@ -29,7 +26,7 @@ def detect_format(data: bytes) -> type[Inst]:
if (word >> 30) == 0b11:
for cls in _FORMATS_64:
if _matches_encoding(word, cls):
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in VOP3SD_OPS else cls
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
raise ValueError(f"unknown 64-bit format word={word:#010x}")
# 32-bit formats
for cls in _FORMATS_32:
@@ -79,8 +76,6 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32')
def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64')
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
@@ -105,50 +100,43 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
# DISASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
_VOP1_F64 = {VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64, VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64, VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64}
def _disasm_vop1(inst: VOP1) -> str:
op, name = VOP1Op(inst.op), VOP1Op(inst.op).name.lower()
if op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
if op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
parts, is_f64_d = name.split('_'), op in _VOP1_F64 or op in (VOP1Op.V_CVT_F64_F32, VOP1Op.V_CVT_F64_I32, VOP1Op.V_CVT_F64_U32)
is_f64_s = op in _VOP1_F64 or op in (VOP1Op.V_CVT_F32_F64, VOP1Op.V_CVT_I32_F64, VOP1Op.V_CVT_U32_F64, VOP1Op.V_FREXP_EXP_I32_F64)
name = inst.op_name.lower()
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
parts = name.split('_')
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
is_16s = parts[-1] in ('f16','i16','u16','b16') and 'sat_pk' not in name
dst = _vreg(inst.vdst, 2) if is_f64_d else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
src = _fmt_src(inst.src0, 2) if is_f64_s else _src16(inst, inst.src0) if is_16s else inst.lit(inst.src0)
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
return f"{name}_e32 {dst}, {src}"
def _disasm_vop2(inst: VOP2) -> str:
op, name = VOP2Op(inst.op), VOP2Op(inst.op).name.lower()
suf, is16 = "" if op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32", _is16(name) and 'pk_' not in name
name = inst.op_name.lower()
suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
if op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
if op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if op == VOP2Op.V_CNDMASK_B32 else "")
VOPC_CLASS = {VOPCOp.V_CMP_CLASS_F16, VOPCOp.V_CMP_CLASS_F32, VOPCOp.V_CMP_CLASS_F64,
VOPCOp.V_CMPX_CLASS_F16, VOPCOp.V_CMPX_CLASS_F32, VOPCOp.V_CMPX_CLASS_F64}
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "")
def _disasm_vopc(inst: VOPC) -> str:
op, name = VOPCOp(inst.op), VOPCOp(inst.op).name.lower()
is64, is16 = _is64(name), _is16(name)
s0 = _fmt_src(inst.src0, 2) if is64 else _src16(inst, inst.src0) if is16 else inst.lit(inst.src0)
s1 = _vreg(inst.vsrc1, 2) if is64 and op not in VOPC_CLASS else _fmt_v16(inst.vsrc1, 0, 128) if is16 else f"v{inst.vsrc1}"
return f"{name}_e32 {s0}, {s1}" if op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
name = inst.op_name.lower()
s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE}
def _disasm_sopp(inst: SOPP) -> str:
op, name = SOPPOp(inst.op), SOPPOp(inst.op).name.lower()
if op in NO_ARG_SOPP: return name
if op == SOPPOp.S_WAITCNT:
name = inst.op_name.lower()
if inst.op in NO_ARG_SOPP: return name
if inst.op == SOPPOp.S_WAITCNT:
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
if op == SOPPOp.S_DELAY_ALU:
if inst.op == SOPPOp.S_DELAY_ALU:
deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
@@ -157,25 +145,20 @@ def _disasm_sopp(inst: SOPP) -> str:
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
def _disasm_smem(inst: SMEM) -> str:
op = SMEMOp(inst.op)
name = op.name.lower()
if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
name = inst.op_name.lower()
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1)
return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
def _disasm_flat(inst: FLAT) -> str:
name = FLATOp(inst.op).name.lower()
name = inst.op_name.lower()
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
suffix = name.split('_')[-1]
w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
if 'cmpswap' in name: w *= 2
if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2)
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
# saddr
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
@@ -195,11 +178,11 @@ def _disasm_flat(inst: FLAT) -> str:
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
op, name = DSOp(inst.op), DSOp(inst.op).name.lower()
op, name = inst.op, inst.op_name.lower()
gds = " gds" if inst.gds else ""
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1
w = inst.dst_regs()
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
if op == DSOp.DS_NOP: return name
@@ -223,49 +206,34 @@ def _disasm_ds(inst: DS) -> str:
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
def _disasm_vop3(inst: VOP3) -> str:
op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op)
name = op.name.lower()
op, name = inst.op, inst.op_name.lower()
# VOP3SD (shared encoding)
if inst.op in VOP3SD_OPS:
if isinstance(op, VOP3SDOp):
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32')
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}"
if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}"
if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}"
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod)
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
# Detect operand sizes
is64 = _is64(name)
is64_src, is64_dst = False, False
# Detect 16-bit operand sizes (for .h/.l suffix handling)
is16_d = is16_s = is16_s2 = False
if 'cvt_pk' in name: is16_s = name.endswith('16')
elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16')
is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1)
is16_s2, is64 = is16_s, False
is16_s2 = is16_s
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
elif 'pack_b32' in name: is16_s = is16_s2 = True
else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad')
# Source counts
shift64 = 'rev' in name and '64' in name and name.startswith('v_')
ldexp64 = op == VOP3Op.V_LDEXP_F64
trig = op == VOP3Op.V_TRIG_PREOP_F64
sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name
s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1
s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1
s2n = 4 if mqsad else 2 if (is64 or sad64) else 1
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
any_hi = inst.opsel != 0
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi)
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi)
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
# Destination
dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1
dn = inst.dst_regs()
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
elif dn > 1: dst = _vreg(inst.vdst, dn)
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
@@ -278,24 +246,24 @@ def _disasm_vop3(inst: VOP3) -> str:
if inst.op < 256: # VOPC
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
if inst.op < 384: # VOP2
os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d)
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
n = inst.num_srcs()
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
if inst.op < 512: # VOP1
return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}"
# Native VOP3
is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi',
'perm_b32', 'permlane', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit')
os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d)
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
n = inst.num_srcs()
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
def _disasm_vop3sd(inst: VOP3SD) -> str:
op, name = VOP3SDOp(inst.op), VOP3SDOp(inst.op).name.lower()
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32')
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
dst, is2src = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}", op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32)
name = inst.op_name.lower()
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
suffix = "_e64" if name.startswith('v_') and 'co_' in name else ""
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{'' if is2src else f', {s2}'}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
def _disasm_vopd(inst: VOPD) -> str:
lit = inst._literal or inst.literal
@@ -304,26 +272,25 @@ def _disasm_vopd(inst: VOPD) -> str:
return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
def _disasm_vop3p(inst: VOP3P) -> str:
name = VOP3POp(inst.op).name.lower()
is_wmma, is_3src, is_fma_mix = 'wmma' in name, _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name
name = inst.op_name.lower()
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
if is_wmma:
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2)
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
if is_fma_mix:
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
else:
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op)
name = op.name.lower()
if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
name = inst.op_name.lower()
if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
@@ -351,7 +318,7 @@ def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
def _disasm_mimg(inst: MIMG) -> str:
name = MIMGOp(inst.op).name.lower()
name = inst.op_name.lower()
srsrc_base = inst.srsrc * 4
srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
# BVH intersect ray: special case with 4 SGPR srsrc
@@ -379,66 +346,38 @@ def _disasm_mimg(inst: MIMG) -> str:
ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4)
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _sop_widths(name: str) -> tuple[int, int, int]:
"""Return (dst_width, src0_width, src1_width) in register count for SOP instructions."""
if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1
if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1
if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1
if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1
if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz
return 1, 1, 1
def _disasm_sop1(inst: SOP1) -> str:
op, name = SOP1Op(inst.op), SOP1Op(inst.op).name.lower()
op, name = inst.op, inst.op_name.lower()
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
dn, s0n, _ = _sop_widths(name)
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if s0n == 1 else _fmt_src(inst.ssrc0, s0n)}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
def _disasm_sop2(inst: SOP2) -> str:
name = SOP2Op(inst.op).name.lower()
dn, s0n, s1n = _sop_widths(name)
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)}"
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
def _disasm_sopc(inst: SOPC) -> str:
name = SOPCOp(inst.op).name.lower()
_, s0n, s1n = _sop_widths(name)
return f"{name} {_fmt_src(inst.ssrc0, s0n)}, {_fmt_src(inst.ssrc1, s1n)}"
return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}"
def _disasm_sopk(inst: SOPK) -> str:
op, name = SOPKOp(inst.op), SOPKOp(inst.op).name.lower()
op, name = inst.op, inst.op_name.lower()
if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
dn, _, _ = _sop_widths(name)
return f"{name} {_fmt_sdst(inst.sdst, dn)}, 0x{inst.simm16:x}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
name = VINTERPOp(inst.op).name.lower()
src0 = f"-{inst.lit(inst.src0)}" if inst.neg & 1 else inst.lit(inst.src0)
src1 = f"-{inst.lit(inst.src1)}" if inst.neg & 2 else inst.lit(inst.src1)
src2 = f"-{inst.lit(inst.src2)}" if inst.neg & 4 else inst.lit(inst.src2)
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
return f"{name} v{inst.vdst}, {src0}, {src1}, {src2}" + (" " + mods if mods else "")
def _disasm_generic(inst: Inst) -> str:
name = f"op_{inst.op}"
def format_field(field_name, val):
val = unwrap(val)
if field_name in SRC_FIELDS: return inst.lit(val) if val != 255 else "0xff"
return f"{'s' if field_name == 'sdst' else 'v'}{val}" if field_name in ('sdst', 'vdst') else f"v{val}" if field_name == 'vsrc1' else f"0x{val:x}" if field_name == 'simm16' else str(val)
operands = [format_field(field_name, inst._values.get(field_name, 0)) for field_name in inst._fields if field_name not in ('encoding', 'op')]
return f"{name} {', '.join(operands)}" if operands else name
return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf,
MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk}
def disasm(inst: Inst) -> str: return DISASM_HANDLERS.get(type(inst), _disasm_generic)(inst)
def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst)
# ═══════════════════════════════════════════════════════════════════════════════
# ASSEMBLER

View File

@@ -4,6 +4,8 @@ from __future__ import annotations
import struct, math
from enum import IntEnum
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
# Common masks and bit conversion functions
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
@@ -28,6 +30,79 @@ def _i64(f):
try: return struct.unpack("<Q", struct.pack("<d", f))[0]
except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
# Instruction spec - register counts and dtypes derived from instruction names
import re
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
def _suffix(name: str) -> tuple[str | None, str | None]:
name = name.upper()
if m := re.search(r'CVT_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'PACK_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
# Generic dst_src pattern: S_BCNT0_I32_B64, S_BITREPLICATE_B64_B32, V_FREXP_EXP_I32_F64, etc.
if m := re.search(r'_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'_([FIUB](?:32|64|16|8|96|128|256|512))$', name): return m.group(1), m.group(1)
return None, None
_SPECIAL_REGS = {
'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1),
'S_LSHL_B64': (2, 2, 1, 1), 'S_LSHR_B64': (2, 2, 1, 1), 'S_ASHR_I64': (2, 2, 1, 1),
'S_BFE_U64': (2, 2, 1, 1), 'S_BFE_I64': (2, 2, 1, 1), 'S_BFM_B64': (2, 1, 1, 1),
'S_BITSET0_B64': (2, 1, 1, 1), 'S_BITSET1_B64': (2, 1, 1, 1),
'S_BITCMP0_B64': (1, 2, 1, 1), 'S_BITCMP1_B64': (1, 2, 1, 1),
'V_LDEXP_F64': (2, 2, 1, 1), 'V_TRIG_PREOP_F64': (2, 2, 1, 1),
'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1),
'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1),
'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1),
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2),
'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4),
}
_SPECIAL_DTYPE = {
'V_LSHLREV_B64': ('B64', 'U32', 'B64', None), 'V_LSHRREV_B64': ('B64', 'U32', 'B64', None), 'V_ASHRREV_I64': ('I64', 'U32', 'I64', None),
'S_LSHL_B64': ('B64', 'B64', 'U32', None), 'S_LSHR_B64': ('B64', 'B64', 'U32', None), 'S_ASHR_I64': ('I64', 'I64', 'U32', None),
'S_BFE_U64': ('U64', 'U64', 'U32', None), 'S_BFE_I64': ('I64', 'I64', 'U32', None),
'S_BFM_B64': ('B64', 'U32', 'U32', None), 'S_BITSET0_B64': ('B64', 'U32', None, None), 'S_BITSET1_B64': ('B64', 'U32', None, None),
'S_BITCMP0_B64': ('SCC', 'B64', 'U32', None), 'S_BITCMP1_B64': ('SCC', 'B64', 'U32', None),
'V_LDEXP_F64': ('F64', 'F64', 'I32', None), 'V_TRIG_PREOP_F64': ('F64', 'F64', 'U32', None),
'V_CMP_CLASS_F64': ('VCC', 'F64', 'U32', None), 'V_CMPX_CLASS_F64': ('EXEC', 'F64', 'U32', None),
'V_CMP_CLASS_F32': ('VCC', 'F32', 'U32', None), 'V_CMPX_CLASS_F32': ('EXEC', 'F32', 'U32', None),
'V_CMP_CLASS_F16': ('VCC', 'F16', 'U32', None), 'V_CMPX_CLASS_F16': ('EXEC', 'F16', 'U32', None),
'V_MAD_U64_U32': ('U64', 'U32', 'U32', 'U64'), 'V_MAD_I64_I32': ('I64', 'I32', 'I32', 'I64'),
'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'),
'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'),
}
def spec_regs(name: str) -> tuple[int, int, int, int]:
name = name.upper()
if name in _SPECIAL_REGS: return _SPECIAL_REGS[name]
if 'SAD' in name and 'U8' in name and 'QSAD' not in name and 'MQSAD' not in name: return 1, 1, 1, 1
dst_suf, src_suf = _suffix(name)
return _REGS.get(dst_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1)
def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]:
name = name.upper()
if name in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[name]
if 'SAD' in name and ('U8' in name or 'U16' in name) and 'QSAD' not in name and 'MQSAD' not in name: return 'U32', 'U32', 'U32', 'U32'
if '_CMP_' in name or '_CMPX_' in name:
dst_suf, src_suf = _suffix(name)
return 'EXEC' if '_CMPX_' in name else 'VCC', src_suf, src_suf, None
dst_suf, src_suf = _suffix(name)
return dst_suf, src_suf, src_suf, src_suf
def spec_is_16bit(name: str) -> bool:
name = name.upper()
if 'SAD' in name or 'PACK' in name or '_PK_' in name or 'SAT_PK' in name or 'DOT2' in name: return False
if '_F32' in name or '_I32' in name or '_U32' in name or '_B32' in name: return False # mixed ops like V_DOT2ACC_F32_F16
return bool(re.search(r'_[FIUB]16(?:_|$)', name))
def spec_is_64bit(name: str) -> bool: return bool(re.search(r'_[FIUB]64(?:_|$)', name.upper()))
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
_2SRC = {'FMAC'} # FMAC uses dst as implicit accumulator, so only 2 explicit sources
def spec_num_srcs(name: str) -> int:
name = name.upper()
if any(k in name for k in _2SRC): return 2
return 3 if any(k in name for k in _3SRC) else 2
def is_dtype_16(dt: str | None) -> bool: return dt is not None and '16' in dt
def is_dtype_64(dt: str | None) -> bool: return dt is not None and '64' in dt
# Bit field DSL
class BitField:
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name, self._marker = hi, lo, name, None
@@ -54,6 +129,10 @@ class BitField:
val = unwrap(obj._values.get(self.name, 0))
# Convert to IntEnum if marker is an IntEnum subclass
if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum):
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if self.marker is VOP3Op:
if val < 256: return VOPCOp(val)
if val in Inst._VOP3SD_OPS: return VOP3SDOp(val)
try: return self.marker(val)
except ValueError: pass
return val
@@ -201,10 +280,13 @@ class Inst:
# Track literal value if needed
if encoded == 255 and self._literal is None:
import struct
is_64 = self._is_64bit_op()
# Check if THIS source uses 64-bit encoding (not just src0)
src_idx = {'src0': 0, 'src1': 1, 'src2': 2, 'ssrc0': 0, 'ssrc1': 1}.get(name, 0)
src_regs = self.src_regs(src_idx)
is_64 = src_regs == 2
if isinstance(val, SrcMod) and not isinstance(val, Reg): lit32 = val.val & MASK32
elif isinstance(val, int) and not isinstance(val, IntEnum): lit32 = val & MASK32
elif isinstance(val, float): lit32 = _i32(val)
elif isinstance(val, float): lit32 = (_i64(val) >> 32) if is_64 else _i32(val) # f64: high 32 bits of f64 repr
else: return
self._literal = (lit32 << 32) if is_64 else lit32
@@ -235,11 +317,21 @@ class Inst:
raise ValueError(f"SOP1 {orig_args['op'].name} expects {expected} register(s) for {fld}, got {orig_args[fld].count}")
def __init__(self, *args, literal: int | None = None, **kwargs):
self._values, self._literal = dict(self._defaults), literal
self._values, self._literal = dict(self._defaults), None
field_names = [n for n in self._fields if n != 'encoding']
orig_args = dict(zip(field_names, args)) | kwargs
self._values.update(orig_args)
self._validate(orig_args)
# Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user)
if literal is not None:
# Find which source uses the literal (255) and check its register count
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
v = orig_args.get(n)
if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255):
self._literal = (literal << 32) if self.src_regs(idx) == 2 else literal
break
else:
self._literal = literal # fallback if no literal source found
cls_name = self.__class__.__name__
# Format-specific setup
@@ -297,11 +389,9 @@ class Inst:
op_name = op.name if hasattr(op, 'name') else None
# Look up op name from int if needed (happens in from_bytes path)
if op_name is None and self.__class__.__name__ == 'VOP3':
from extra.assembly.amd.autogen.rdna3.ins import VOP3Op
try: op_name = VOP3Op(op).name
except ValueError: pass
if op_name is None and self.__class__.__name__ == 'VOPC':
from extra.assembly.amd.autogen.rdna3.ins import VOPCOp
try: op_name = VOPCOp(op).name
except ValueError: pass
if op_name is None: return False
@@ -312,8 +402,16 @@ class Inst:
result = self.to_int().to_bytes(self._size(), 'little')
lit = self._get_literal() or getattr(self, '_literal', None)
if lit is None: return result
# For 64-bit ops, literal is stored in high 32 bits internally, but encoded as 4 bytes
lit32 = (lit >> 32) if self._is_64bit_op() else lit
# For 64-bit sources, literal is stored in high 32 bits internally, but encoded as 4 bytes
# Find which source uses the literal (255) and check its register count
lit_src_is_64 = False
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
if n not in self._values: continue
v = self._values[n]
if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255):
lit_src_is_64 = self.is_src_64(idx)
break
lit32 = (lit >> 32) if lit_src_is_64 else lit
return result + (lit32 & MASK32).to_bytes(4, 'little')
@classmethod
@@ -343,9 +441,16 @@ class Inst:
if has_literal:
# For 64-bit ops, the literal is 32 bits placed in the HIGH 32 bits of the 64-bit value
# (low 32 bits are zero). This is how AMD hardware interprets 32-bit literals for 64-bit ops.
# Check which source uses the literal and whether THAT source is 64-bit
if len(data) >= cls._size() + 4:
lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little')
inst._literal = (lit32 << 32) if inst._is_64bit_op() else lit32
# Find which source has literal (255) and check its register count
lit_src_is_64 = False
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]:
if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255:
lit_src_is_64 = inst.src_regs(idx) == 2
break
inst._literal = (lit32 << 32) if lit_src_is_64 else lit32
return inst
def __repr__(self):
@@ -360,7 +465,9 @@ class Inst:
if name.startswith('_'): raise AttributeError(name)
return unwrap(self._values.get(name, 0))
def lit(self, v: int) -> str: return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
def lit(self, v: int, neg: bool = False) -> str:
s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
return f"-{s}" if neg else s
def __eq__(self, other):
if not isinstance(other, Inst): return NotImplemented
@@ -372,5 +479,37 @@ class Inst:
from extra.assembly.amd.asm import disasm
return disasm(self)
_enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp,
'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp,
'SMEM': SMEMOp, 'DS': DSOp, 'FLAT': FLATOp, 'MUBUF': MUBUFOp, 'MTBUF': MTBUFOp, 'MIMG': MIMGOp,
'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
@property
def op(self):
"""Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
val = self._values.get('op')
if val is None: return None
if hasattr(val, 'name'): return val # already an enum
cls_name = self.__class__.__name__
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
return self._enum_map[cls_name](val)
@property
def op_name(self) -> str:
op = self.op
return op.name if hasattr(op, 'name') else ''
def dst_regs(self) -> int: return spec_regs(self.op_name)[0]
def src_regs(self, n: int) -> int: return spec_regs(self.op_name)[n + 1]
def num_srcs(self) -> int: return spec_num_srcs(self.op_name)
def dst_dtype(self) -> str | None: return spec_dtype(self.op_name)[0]
def src_dtype(self, n: int) -> str | None: return spec_dtype(self.op_name)[n + 1]
def is_src_16(self, n: int) -> bool: return self.src_regs(n) == 1 and is_dtype_16(self.src_dtype(n))
def is_src_64(self, n: int) -> bool: return self.src_regs(n) == 2
def is_16bit(self) -> bool: return spec_is_16bit(self.op_name)
def is_64bit(self) -> bool: return spec_is_64bit(self.op_name)
def is_dst_16(self) -> bool: return self.dst_regs() == 1 and is_dtype_16(self.dst_dtype())
class Inst32(Inst): pass
class Inst64(Inst): pass

View File

@@ -12,34 +12,6 @@ Program = dict[int, Inst]
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC
# Op classification helpers - build sets from op name patterns
def _ops_matching(enum, *patterns, exclude=()): return {op for op in enum if any(p in op.name for p in patterns) and not any(e in op.name for e in exclude)}
def _ops_ending(enum, *suffixes): return {op for op in enum if op.name.endswith(suffixes)}
# 64-bit ops (for literal handling)
_VOP3_64BIT_OPS = {op.value for op in _ops_ending(VOP3Op, '_F64', '_B64', '_I64', '_U64')}
_VOPC_64BIT_OPS = {op.value for op in _ops_ending(VOPCOp, '_F64', '_B64', '_I64', '_U64')}
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value} # src1 is 32-bit exponent
# 16-bit ops (SAD/MSAD excluded - they use 32-bit packed sources)
_VOP3_16BIT_OPS = _ops_matching(VOP3Op, '_F16', '_B16', '_I16', '_U16', exclude=('SAD',))
_VOP1_16BIT_OPS = _ops_matching(VOP1Op, '_F16', '_B16', '_I16', '_U16')
_VOP2_16BIT_OPS = _ops_matching(VOP2Op, '_F16', '_B16', '_I16', '_U16')
_VOPC_16BIT_OPS = _ops_matching(VOPCOp, '_F16', '_B16', '_I16', '_U16')
# CVT ops with 32/64-bit source (despite 16-bit in name) - must end with the type suffix
_CVT_32_64_SRC_OPS = {op for op in _ops_ending(VOP3Op, '_F32', '_I32', '_U32', '_F64', '_I64', '_U64') if op.name.startswith('V_CVT_')} | \
{op for op in _ops_ending(VOP1Op, '_F32', '_I32', '_U32', '_F64', '_I64', '_U64') if op.name.startswith('V_CVT_')}
# CVT ops with 32-bit destination (FROM 16-bit TO 32-bit) - match patterns like F32_F16 in name
_CVT_32_DST_OPS = _ops_matching(VOP3Op, 'F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16') | \
_ops_matching(VOP1Op, 'F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16')
# 16-bit dst ops (PACK has 32-bit dst, CVT to 32-bit has 32-bit dst)
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
_VOP1_16BIT_SRC_OPS = _VOP1_16BIT_OPS - _CVT_32_64_SRC_OPS
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
_FLOAT_CONSTS = {v: k for k, v in FLOAT_ENC.items()} | {248: 0.15915494309189535} # INV_2PI
def _build_inline_consts(mask, to_bits):
@@ -134,7 +106,7 @@ class WaveState:
def rsrc_f16(self, v: int, lane: int) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS_F16)
def rsrc64(self, v: int, lane: int) -> int:
if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128]
if v == 255: return self.literal
if v == 255: return self.literal # literal is already shifted in from_bytes for 64-bit ops
return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
def pend_sgpr_lane(self, reg: int, lane: int, val: int):
@@ -155,20 +127,8 @@ def decode_program(data: bytes) -> Program:
base_size = inst_class._size()
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
inst = inst_class.from_bytes(data[i:i+base_size+8])
for name, val in inst._values.items(): setattr(inst, name, unwrap(val))
# from_bytes already handles literal reading - only need fallback for cases it doesn't handle
if inst._literal is None:
has_literal = any(getattr(inst, fld, None) == 255 for fld in ('src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'srcx0', 'srcy0')) or \
(inst_class == VOP2 and inst.op in (44, 45, 55, 56)) or \
(inst_class == VOPD and (inst.opx in (1, 2) or inst.opy in (1, 2))) or \
(inst_class == SOP2 and inst.op in (69, 70))
if has_literal:
# For 64-bit ops, the 32-bit literal is placed in HIGH 32 bits (low 32 bits = 0)
op_val = getattr(inst._values.get('op'), 'value', inst._values.get('op'))
is_64bit = ((inst_class is VOP3 and op_val in _VOP3_64BIT_OPS) or (inst_class is VOPC and op_val in _VOPC_64BIT_OPS)) and \
not (op_val in _VOP3_64BIT_OPS_32BIT_SRC1 and getattr(inst, 'src1', None) == 255)
lit32 = int.from_bytes(data[i+base_size:i+base_size+4], 'little')
inst._literal = (lit32 << 32) if is_64bit else lit32
for name, val in inst._values.items():
if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access
inst._words = inst.size() // 4
result[i // 4] = inst
i += inst._words * 4
@@ -181,16 +141,14 @@ def decode_program(data: bytes) -> Program:
def exec_scalar(st: WaveState, inst: Inst) -> int:
"""Execute scalar instruction. Returns PC delta or negative for special cases."""
compiled = _get_compiled()
inst_type = type(inst)
# SOPP: special cases for control flow that has no pseudocode
if inst_type is SOPP:
op = inst.op
if op == SOPPOp.S_ENDPGM: return -1
if op == SOPPOp.S_BARRIER: return -2
if isinstance(inst, SOPP):
if inst.op == SOPPOp.S_ENDPGM: return -1
if inst.op == SOPPOp.S_BARRIER: return -2
# SMEM: memory loads (not ALU)
if inst_type is SMEM:
if isinstance(inst, SMEM):
addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21)
if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0)
if (cnt := SMEM_LOAD.get(inst.op)) is None: raise NotImplementedError(f"SMEM op {inst.op}")
@@ -198,34 +156,30 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
return 0
# Get op enum and lookup compiled function
if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst
elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst
elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None
elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst # sdst is both src and dst
elif inst_type is SOPP: op_cls, ssrc0, sdst = SOPPOp, None, None
else: raise NotImplementedError(f"Unknown scalar type {inst_type}")
if isinstance(inst, SOP1): ssrc0, sdst = inst.ssrc0, inst.sdst
elif isinstance(inst, SOP2): ssrc0, sdst = inst.ssrc0, inst.sdst
elif isinstance(inst, SOPC): ssrc0, sdst = inst.ssrc0, None
elif isinstance(inst, SOPK): ssrc0, sdst = inst.sdst, inst.sdst # sdst is both src and dst
elif isinstance(inst, SOPP): ssrc0, sdst = None, None
else: raise NotImplementedError(f"Unknown scalar type {type(inst)}")
# SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops
try: op = op_cls(inst.op)
try: op = inst.op
except ValueError:
if inst_type is SOPP: return 0
if isinstance(inst, SOPP): return 0
raise
fn = compiled.get(op_cls, {}).get(op)
fn = compiled.get(type(op), {}).get(op)
if fn is None:
# SOPP instructions without pseudocode (waits, hints, nops) are no-ops
if inst_type is SOPP: return 0
if isinstance(inst, SOPP): return 0
raise NotImplementedError(f"{op.name} not in pseudocode")
# Build context - handle 64-bit ops that need 64-bit source reads
# 64-bit source ops: name ends with _B64, _I64, _U64 or contains _U64, _I64 before last underscore
is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type not in (SOPK, SOPP) else (st.rsgpr(inst.sdst) if inst_type is SOPK else 0))
is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
# Build context - use inst methods to determine operand sizes
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
exec_mask = st.exec_mask
literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
# Execute compiled function - pass PC in bytes for instructions that need it
# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
@@ -248,10 +202,10 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
"""Execute vector instruction for one lane."""
compiled = _get_compiled()
inst_type, V = type(inst), st.vgpr[lane]
V = st.vgpr[lane]
# Memory ops (not ALU pseudocode)
if inst_type is FLAT:
if isinstance(inst, FLAT):
op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr
addr = V[addr_reg] | (V[addr_reg+1] << 32)
addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64
@@ -272,7 +226,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
else: raise NotImplementedError(f"FLAT op {op}")
return
if inst_type is DS:
if isinstance(inst, DS):
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
@@ -302,7 +256,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
return
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
if inst_type is VOPD:
if isinstance(inst, VOPD):
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
@@ -312,18 +266,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
return
# VOP3SD: has extra scalar dest for carry output
if inst_type is VOP3SD:
op = VOP3SDOp(inst.op)
fn = compiled.get(VOP3SDOp, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
# Source sizes vary: DIV_SCALE=all 64-bit, MAD64=32/32/64, others=32-bit
r64 = op == VOP3SDOp.V_DIV_SCALE_F64
s0, s1 = (st.rsrc64 if r64 else st.rsrc)(inst.src0, lane), (st.rsrc64 if r64 else st.rsrc)(inst.src1, lane)
mad64 = op in (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32)
s2 = st.rsrc64(inst.src2, lane) if r64 else ((V[inst.src2-256]|(V[inst.src2-255]<<32)) if inst.src2>=256 else st.rsgpr64(inst.src2)) if mad64 else st.rsrc(inst.src2, lane)
if isinstance(inst, VOP3SD):
fn = compiled.get(VOP3SDOp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
# Read sources based on register counts from inst properties
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
# Carry-in ops use src2 as carry bitmask instead of VCC
carry_ops = (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32)
result = fn(s0, s1, s2, V[inst.vdst], st.scc, st.rsgpr64(inst.src2) if op in carry_ops else st.vcc, lane, st.exec_mask, st.literal, None, {})
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {})
V[inst.vdst] = result['d0'] & MASK32
if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32
if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
@@ -332,35 +283,31 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
# Get op enum and sources (None means "no source" for that operand)
# dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination
dst_hi = False
if inst_type is VOP1:
if isinstance(inst, VOP1):
if inst.op == VOP1Op.V_NOP: return
op_cls, op, src0, src1, src2 = VOP1Op, VOP1Op(inst.op), inst.src0, None, None
dst_hi, vdst = (inst.vdst & 0x80) != 0 and op in _VOP1_16BIT_DST_OPS, inst.vdst & 0x7f if op in _VOP1_16BIT_DST_OPS else inst.vdst
elif inst_type is VOP2:
op_cls, op, src0, src1, src2 = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None
dst_hi, vdst = (inst.vdst & 0x80) != 0 and op in _VOP2_16BIT_OPS, inst.vdst & 0x7f if op in _VOP2_16BIT_OPS else inst.vdst
elif inst_type is VOP3:
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode)
if inst.op < 256:
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
else:
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
elif inst_type is VOPC:
op = VOPCOp(inst.op)
src0, src1, src2 = inst.src0, None, None
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16()
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
elif isinstance(inst, VOP2):
src0, src1, src2 = inst.src0, inst.vsrc1 + 256, None
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16()
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
elif isinstance(inst, VOP3):
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 - inst.op returns VOPCOp for these
src0, src1, src2, vdst = inst.src0, inst.src1, (None if inst.op.value < 256 else inst.src2), inst.vdst
elif isinstance(inst, VOPC):
# For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half
# vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag
src1 = inst.vsrc1 + 256 # convert to standard VGPR encoding (256 + vgpr_idx)
op_cls, src0, src2, vdst = VOPCOp, inst.src0, None, VCC_LO
elif inst_type is VOP3P:
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, None, VCC_LO
elif isinstance(inst, VOP3P):
# VOP3P: Packed 16-bit operations using compiled functions
op = VOP3POp(inst.op)
# WMMA: wave-level matrix multiply-accumulate (special handling - needs cross-lane access)
if op in (VOP3POp.V_WMMA_F32_16X16X16_F16, VOP3POp.V_WMMA_F32_16X16X16_BF16, VOP3POp.V_WMMA_F16_16X16X16_F16):
if 'WMMA' in inst.op_name:
if lane == 0: # Only execute once per wave, write results for all lanes
exec_wmma(st, inst, op)
exec_wmma(st, inst, inst.op)
return
# V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half
if op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
if 'FMA_MIX' in inst.op_name:
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0)
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs
raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0]
@@ -371,7 +318,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2]
V = st.vgpr[lane]
V[inst.vdst] = _i32(result) if op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), op == VOP3POp.V_FMA_MIXHI_F16)
V[inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
@@ -380,80 +327,52 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
fn = compiled.get(VOP3POp, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
fn = compiled.get(VOP3POp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
st.vgpr[lane][inst.vdst] = fn(srcs[0], srcs[1], srcs[2], 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'] & MASK32
return
else: raise NotImplementedError(f"Unknown vector type {inst_type}")
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
fn = compiled.get(op_cls, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
op_cls = type(inst.op)
fn = compiled.get(op_cls, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
# Read sources (with VOP3 modifiers if applicable)
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if inst_type is VOP3 else (0, 0)
opsel = getattr(inst, 'opsel', 0) if inst_type is VOP3 else 0
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
opsel = getattr(inst, 'opsel', 0) if isinstance(inst, VOP3) else 0
def mod_src(val: int, idx: int, is64=False) -> int:
to_f, to_i = (_f64, _i64) if is64 else (_f32, _i32)
if (abs_ >> idx) & 1: val = to_i(abs(to_f(val)))
if (neg >> idx) & 1: val = to_i(-to_f(val))
return val
# Determine if sources are 64-bit based on instruction type
# For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift)
# For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit
is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64'))
# V_LDEXP_F64, V_TRIG_PREOP_F64, V_CMP_CLASS_F64, V_CMPX_CLASS_F64: src0 is 64-bit, src1 is 32-bit
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64, VOP3Op.V_TRIG_PREOP_F64, VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64,
VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
# 16-bit source ops: use precomputed sets instead of string checks
# Note: must check op_cls to avoid cross-enum value collisions
# VOP3-encoded VOPC 16-bit ops also use opsel (not VGPR bit 7 like non-VOP3 VOPC)
is_16bit_src = (op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS) or \
(inst_type is VOP3 and op_cls is VOPCOp and op in _VOPC_16BIT_OPS)
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
# Use inst methods to determine operand sizes (inst.is_src_16, inst.is_src_64, etc.)
is_vop2_16bit = isinstance(inst, VOP2) and inst.is_16bit()
if is_shift_64:
s0, s1 = mod_src(st.rsrc(src0, lane), 0), st.rsrc64(src1, lane) if src1 else 0
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_ldexp_64:
s0 = mod_src(st.rsrc64(src0, lane), 0, is64=True)
s1_raw = st.rsrc(src1, lane) if src1 is not None else 0
is_class_op = op in (VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64, VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
s1, s2 = mod_src((s1_raw >> 32) if src1 == 255 and is_class_op else s1_raw, 1), mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_64bit_op:
s0, s1 = mod_src(st.rsrc64(src0, lane), 0, is64=True), mod_src(st.rsrc64(src1, lane), 1, is64=True) if src1 is not None else 0
s2 = mod_src(st.rsrc64(src2, lane), 2, is64=True) if src2 is not None else 0
elif is_16bit_src:
# VOP3 16-bit ops: opsel bits select which half, abs/neg as f16 bit ops
def rsrc_16bit(src, idx):
if src is None: return 0
# Read sources based on register counts and dtypes from inst properties
def read_src(src, idx, regs, is_src_16):
if src is None: return 0
if regs == 2: return mod_src(st.rsrc64(src, lane), idx, is64=True)
if is_src_16 and isinstance(inst, VOP3):
raw = st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane)
val = _src16(raw, bool(opsel & (1 << idx)))
if abs_ & (1 << idx): val &= 0x7fff
if neg & (1 << idx): val ^= 0x8000
return val
s0, s1, s2 = rsrc_16bit(src0, 0), rsrc_16bit(src1, 1), rsrc_16bit(src2, 2)
elif is_vop2_16bit or (op_cls is VOP1Op and op in _VOP1_16BIT_SRC_OPS) or (op_cls is VOPCOp and op in _VOPC_16BIT_OPS):
# VOP1/VOP2/VOPC 16-bit ops: VGPRs use bit 7 for hi/lo, non-VGPRs use f16 inline consts
# Special case: VOPC V_CMP_CLASS uses full 32-bit mask for src1 when non-VGPR
def rsrc16_vgpr(src, idx, full32=False):
if src is None: return 0
if is_src_16 and isinstance(inst, (VOP1, VOP2, VOPC)):
if src >= 256: return _src16(mod_src(st.rsrc(_vgpr_masked(src), lane), idx), _vgpr_hi(src))
return mod_src(st.rsrc(src, lane) if full32 else st.rsrc_f16(src, lane), idx) & (0xffffffff if full32 else 0xffff)
s0, s1 = rsrc16_vgpr(src0, 0), rsrc16_vgpr(src1, 1, full32=op_cls is VOPCOp)
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
else:
s0 = mod_src(st.rsrc(src0, lane), 0)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
return mod_src(st.rsrc_f16(src, lane), idx) & 0xffff
return mod_src(st.rsrc(src, lane), idx)
s0 = read_src(src0, 0, inst.src_regs(0), inst.is_src_16(0))
s1 = read_src(src1, 1, inst.src_regs(1), inst.is_src_16(1)) if src1 is not None else 0
s2 = read_src(src2, 2, inst.src_regs(2), inst.is_src_16(2)) if src2 is not None else 0
# Read destination (accumulator for VOP2 f16, 64-bit for 64-bit ops)
d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if is_64bit_op else V[vdst]
d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if inst.dst_regs() == 2 else V[vdst]
# V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
# Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
vcc_for_fn = st.rsgpr64(src2) if inst.op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and isinstance(inst, VOP3) and src2 is not None and src2 < 256 else st.vcc
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
@@ -467,17 +386,16 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
st.vgpr[wr_lane][wr_idx] = wr_val
if 'vcc_lane' in result:
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
st.pend_sgpr_lane(VCC_LO if op_cls is VOP2Op and 'CO_CI' in op.name else vdst, lane, result['vcc_lane'])
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane'])
if 'exec_lane' in result:
# V_CMPX instructions write to EXEC per-lane
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
if 'd0' in result and op_cls not in (VOPCOp,) and 'vgpr_write' not in result:
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or (op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
is_16bit_dst = (op_cls is VOP3Op and op in _VOP3_16BIT_DST_OPS) or (op_cls is VOP1Op and op in _VOP1_16BIT_DST_OPS) or is_vop2_16bit
if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result:
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
d0_val = result['d0']
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif is_16bit_dst: V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if inst_type is VOP3 else dst_hi)
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
else: V[vdst] = d0_val & MASK32
# ═══════════════════════════════════════════════════════════════════════════════
@@ -506,22 +424,19 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
# MAIN EXECUTION LOOP
# ═══════════════════════════════════════════════════════════════════════════════
SCALAR_TYPES = {SOP1, SOP2, SOPC, SOPK, SOPP, SMEM}
VECTOR_TYPES = {VOP1, VOP2, VOP3, VOP3SD, VOPC, FLAT, DS, VOPD, VOP3P}
def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
inst = program.get(st.pc)
if inst is None: return 1
inst_words, st.literal, inst_type = inst._words, getattr(inst, '_literal', None) or 0, type(inst)
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
if inst_type in SCALAR_TYPES:
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
delta = exec_scalar(st, inst)
if delta == -1: return -1 # endpgm
if delta == -2: st.pc += inst_words; return -2 # barrier
st.pc += inst_words + delta
else:
# V_READFIRSTLANE/V_READLANE write to SGPR, execute once; others execute per-lane with exec_mask
is_readlane = inst_type in (VOP1, VOP3) and hasattr(inst.op, 'name') and 'READLANE' in inst.op.name
is_readlane = isinstance(inst, (VOP1, VOP3)) and ('READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name)
exec_mask = 1 if is_readlane else st.exec_mask
for lane in range(1 if is_readlane else n_lanes):
if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds)

View File

@@ -3982,6 +3982,179 @@ class TestVOP3VOPC16Bit(unittest.TestCase):
self.assertEqual(st.sgpr[0] & 1, 1, "hi>hi should be true: 0x9999>0x1234")
class Test64BitLiteralSources(unittest.TestCase):
  """Regression tests for 64-bit instruction literal source handling.

  For f64 operations, a 32-bit literal in the instruction stream represents the
  HIGH 32 bits of the 64-bit value (low 32 bits are implicitly 0).

  Bug: rsrc64() was returning the 32-bit literal as-is instead of shifting it
  left by 32 bits. This caused V_FMA_F64 and V_LDEXP_F64 to use wrong values
  when their source is a literal, breaking the f64->i64 conversion sequence.

  The f64->i64 conversion sequence is:
    v_trunc_f64 -> v_ldexp_f64 (by -32) -> v_floor_f64 -> v_fma_f64 (by -2^32)
    -> v_cvt_u32_f64 (low bits) -> v_cvt_i32_f64 (high bits)

  The V_FMA_F64 uses literal 0xC1F00000 which is the high 32 bits of f64 -2^32.
  """
  # High 32 bits of the f64 encoding of -2^32 (-4294967296.0); the compiler
  # emits only these bits as the 32-bit literal operand of V_FMA_F64.
  LIT_NEG_2POW32 = 0xC1F00000
  # -32 as a 32-bit two's-complement integer; V_LDEXP_F64's exponent operand
  # is 32-bit, so no high/low split applies to it.
  LIT_NEG_32 = 0xFFFFFFE0

  def _cvt_f64_to_i64(self, fval: float) -> int:
    """Run the compiler's f64->i64 conversion sequence on fval and return the i64.

    This is the exact instruction sequence the compiler generates for
    (long)fval; the result is reassembled from v4 (low 32) and v5 (high 32).
    """
    val = f2i64(fval)
    instructions = [
      s_mov_b32(s[0], val & 0xffffffff),
      s_mov_b32(s[1], (val >> 32) & 0xffffffff),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      v_trunc_f64_e32(v[0:2], v[0:2]),
      v_ldexp_f64(v[2:4], v[0:2], self.LIT_NEG_32),   # scale by 2^-32
      v_floor_f64_e32(v[2:4], v[2:4]),
      # v[0:1] = (-2^32) * floor(x * 2^-32) + trunc(x); RawImm(255) selects the
      # literal-constant source encoding, exercising the rsrc64() bug path.
      VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=self.LIT_NEG_2POW32),
      v_cvt_u32_f64_e32(v[4], v[0:2]),  # low 32 bits of the i64 result
      v_cvt_i32_f64_e32(v[5], v[2:4]),  # high 32 bits of the i64 result
    ]
    st = run_program(instructions, n_lanes=1)
    lo, hi = st.vgpr[0][4], st.vgpr[0][5]
    return struct.unpack('<q', struct.pack('<II', lo, hi))[0]

  def test_v_fma_f64_literal_neg_2pow32(self):
    """V_FMA_F64 with literal encoding of -2^32.

    The f64 value -2^32 (-4294967296.0) has bits 0xC1F0000000000000.
    The compiler encodes only the high 32 bits (0xC1F00000) as a literal.
    The emulator must interpret this as 0xC1F00000_00000000.
    """
    # v[0:1] = -41.0 (trunc), v[2:3] = -1.0 (floor of -41/2^32)
    # FMA: result = (-2^32) * (-1.0) + (-41.0) = 4294967296 - 41 = 4294967255.0
    val_41 = f2i64(-41.0)
    val_m1 = f2i64(-1.0)
    instructions = [
      s_mov_b32(s[0], val_41 & 0xffffffff),
      s_mov_b32(s[1], (val_41 >> 32) & 0xffffffff),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      s_mov_b32(s[2], val_m1 & 0xffffffff),
      s_mov_b32(s[3], (val_m1 >> 32) & 0xffffffff),
      v_mov_b32_e32(v[2], s[2]),
      v_mov_b32_e32(v[3], s[3]),
      # V_FMA_F64 v[4:5], literal, v[2:3], v[0:1]
      # = (-2^32) * (-1.0) + (-41.0) = 4294967255.0
      VOP3(VOP3Op.V_FMA_F64, vdst=v[4], src0=RawImm(255), src1=v[2], src2=v[0], literal=self.LIT_NEG_2POW32),
    ]
    st = run_program(instructions, n_lanes=1)
    result = i642f(st.vgpr[0][4] | (st.vgpr[0][5] << 32))
    expected = 4294967255.0  # 2^32 - 41
    self.assertAlmostEqual(result, expected, places=0, msg=f"Expected {expected}, got {result}")

  def test_v_ldexp_f64_literal_neg32(self):
    """V_LDEXP_F64 with literal -32 for exponent.

    V_LDEXP_F64 computes src0 * 2^src1 where src1 is an integer exponent.
    The literal 0xFFFFFFE0 represents -32 as a 32-bit signed integer.
    For V_LDEXP_F64, src1 is 32-bit (not 64-bit), so this is correct as-is.
    """
    val = f2i64(-41.0)
    expected = -41.0 * (2.0 ** -32)  # -9.5367431640625e-09
    instructions = [
      s_mov_b32(s[0], val & 0xffffffff),
      s_mov_b32(s[1], (val >> 32) & 0xffffffff),
      v_mov_b32_e32(v[0], s[0]),
      v_mov_b32_e32(v[1], s[1]),
      # V_LDEXP_F64 v[2:3], v[0:1], -32
      v_ldexp_f64(v[2:4], v[0:2], self.LIT_NEG_32),
    ]
    st = run_program(instructions, n_lanes=1)
    result = i642f(st.vgpr[0][2] | (st.vgpr[0][3] << 32))
    self.assertAlmostEqual(result, expected, places=15, msg=f"Expected {expected}, got {result}")

  def test_f64_to_i64_full_sequence(self):
    """Full f64->i64 conversion sequence with negative value.

    This is the exact sequence generated by the compiler for (long)(-41.0):
      v_trunc_f64 v[0:1], v[0:1]
      v_ldexp_f64 v[2:3], v[0:1], -32
      v_floor_f64 v[2:3], v[2:3]
      v_fma_f64 v[0:1], 0xc1f00000, v[2:3], v[0:1]  # -2^32
      v_cvt_u32_f64 v0, v[0:1]
      v_cvt_i32_f64 v1, v[2:3]
    Result: v1:v0 = 0xFFFFFFFF:0xFFFFFFD7 = -41 as i64
    """
    result = self._cvt_f64_to_i64(-41.0)
    self.assertEqual(result, -41, f"Expected -41, got {result}")

  def test_f64_to_i64_large_negative(self):
    """f64->i64 conversion with larger negative value (-1000000).

    Tests that the conversion sequence works for values that span both
    high and low 32-bit parts of the result.
    """
    result = self._cvt_f64_to_i64(-1000000.0)
    self.assertEqual(result, -1000000, f"Expected -1000000, got {result}")

  def test_f64_to_i64_positive(self):
    """f64->i64 conversion with positive value (1000000)."""
    result = self._cvt_f64_to_i64(1000000.0)
    self.assertEqual(result, 1000000, f"Expected 1000000, got {result}")

  def test_f64_to_i64_large_positive(self):
    """f64->i64 conversion with value > 2^32 (requires 64-bit result)."""
    result = self._cvt_f64_to_i64(5000000000.0)  # 5 billion, > 2^32
    self.assertEqual(result, 5000000000, f"Expected 5000000000, got {result}")
class TestDS2Addr(unittest.TestCase):
"""Regression tests for DS_LOAD_2ADDR and DS_STORE_2ADDR instructions.
These ops use offset scaling: offset * sizeof(data) for address calculation.