Files
tinygrad/extra/assembly/amd/asm.py
George Hotz f2b11010e8 no skip
2026-01-04 20:51:01 -08:00

1511 lines
98 KiB
Python

# RDNA3/CDNA assembler and disassembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp)
from extra.assembly.amd.autogen.rdna3.enum import BufFmt
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
"""Check if word matches the encoding pattern of an instruction class."""
if cls._encoding is None: return False
bf, val = cls._encoding
return ((word >> bf.lo) & bf.mask()) == val
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
# detect_format() scans these lists in order and returns the first class whose _encoding matches.
# 64-bit RDNA formats: tried when bits [31:30] of the first dword are 0b11.
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
# 32-bit RDNA formats: tried for every other first dword.
_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P,
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS,
FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF, SDWA, DPP)
# CDNA format tables, scanned in order (first match wins) by detect_format(..., arch="cdna").
# VOP3A matches both VOP3A and VOP3B words; detect_format() returns VOP3B for opcodes in _CDNA_VOP3B_OPS.
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM]
# SDWA/DPP are listed first so they take precedence over plain VOP1/VOP2/VOPC; SOP2/VOP2 are last as catch-alls.
_CDNA_FORMATS_32 = [SDWA, DPP, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2]
_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes (first-dword bits [25:16])
# CDNA opcode name aliases for disasm (new name -> old name expected by tests)
# Applied by _disasm_vop2 to CDNA instructions before printing.
_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'}
def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]:
  """Return the instruction class whose encoding matches the first little-endian dword of *data*.

  Bits [31:30] == 0b11 select the 64-bit format table, anything else the 32-bit table. Tables are
  scanned in order and the first matching class wins. A 64-bit VOP3 word whose opcode (bits [25:16])
  is in the arch's carry-out set is reported as VOP3B (CDNA) / VOP3SD (RDNA) instead.

  Args:
    data: machine code, at least 4 bytes.
    arch: "cdna" for CDNA tables; any other value uses the RDNA tables.

  Raises:
    AssertionError: fewer than 4 bytes were given.
    ValueError: no format in the selected table matches.
  """
  assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
  word = int.from_bytes(data[:4], 'little')
  is_64 = (word >> 30) == 0b11
  opcode = (word >> 16) & 0x3ff
  if arch == "cdna":
    table = _CDNA_FORMATS_64 if is_64 else _CDNA_FORMATS_32
    vop3_cls, refined_cls, refined_ops, label = VOP3A, VOP3B, _CDNA_VOP3B_OPS, "CDNA "
  else:
    table = _RDNA_FORMATS_64 if is_64 else _RDNA_FORMATS_32
    vop3_cls, refined_cls, refined_ops, label = VOP3, VOP3SD, Inst._VOP3SD_OPS, ""
  for cls in table:
    if not _matches_encoding(word, cls):
      continue
    # VOP3A/VOP3B (CDNA) and VOP3/VOP3SD (RDNA) share an encoding; the opcode disambiguates
    if is_64 and cls is vop3_cls and opcode in refined_ops:
      return refined_cls
    return cls
  raise ValueError(f"unknown {label}{'64' if is_64 else '32'}-bit format word={word:#010x}")
# ═══════════════════════════════════════════════════════════════════════════════
# CONSTANTS
# ═══════════════════════════════════════════════════════════════════════════════
# Hardware register id -> symbolic name, used to print hwreg(...) operands of s_getreg/s_setreg (see _disasm_sopk).
HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC',
  6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
  19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
  23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
# Reverse lookup: lowercase register name -> id (e.g. 'hw_reg_mode' -> 1).
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup
# Maps BufFmt member names (looked up as 'BUF_FMT_<data>_<num>' by _parse_buf_fmt_combo) to their encoded values.
BUF_FMT = {e.name: e.value for e in BufFmt}
def _parse_buf_fmt_combo(s: str) -> int | None: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]
  """Map the body of ``format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]`` to a unified RDNA buffer-format value.

  Both prefixes are optional. Returns None when *s* does not split into exactly two comma-separated
  parts, or when ``BUF_FMT_X_Y`` is not a known BufFmt name.
  """
  parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')]
  if len(parts) != 2: return None
  return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}')
# s_sendmsg_rtn_* message ids -> symbolic names; _disasm_sop1 prints unlisted ids as plain integers.
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
  131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
# ═══════════════════════════════════════════════════════════════════════════════
# HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]"
def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n)
def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n)
def _areg(b: int, n: int = 1) -> str: return _reg("a", b, n) # accumulator registers for GFX90a
def _ttmp(b: int, n: int = 1) -> str: return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
def _sreg_or_ttmp(b: int, n: int = 1) -> str: return _ttmp(b, n) or _sreg(b, n)
def _fmt_sdst(v: int, n: int = 1, cdna: bool = False) -> str:
  """Format a scalar destination operand encoding *v* spanning *n* registers.

  The ttmp range takes priority. Single registers use the special-GPR names (vcc_lo, m0, ...) and fall
  back to "sN". Multi-register destinations try the special pair names, then the special GPR names
  (for null/m0), then an "s[a:b]" range. *cdna* selects the CDNA name tables.
  """
  from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA, SPECIAL_GPRS_CDNA
  ttmp_name = _ttmp(v, n)
  if ttmp_name:
    return ttmp_name
  if cdna:
    gpr_names, pair_names = SPECIAL_GPRS_CDNA, SPECIAL_PAIRS_CDNA
  else:
    gpr_names, pair_names = SPECIAL_GPRS, SPECIAL_PAIRS
  if n <= 1:
    return gpr_names.get(v, f"s{v}")
  return pair_names.get(v) or gpr_names.get(v) or _sreg(v, n)
def _fmt_src(v: int, n: int = 1, cdna: bool = False) -> str:
  """Format a source operand encoding *v* that spans *n* registers.

  Single registers are delegated to decode_src. For ranges:
  - VGPRs (>= 256) print as v[a:b].
  - s0..s101 print as s[a:b].
  - s102..s105 print as a special pair name (2-register case only) or s[a:b].
  - Higher encodings try the special pairs, then ttmp ranges, then decode_src.
  """
  if n == 1:
    return decode_src(v, cdna)
  if v >= 256:
    return _vreg(v - 256, n)
  pair_names = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
  # s0-s101 are always plain SGPR ranges; only 102+ can be a named special pair
  if v > 101 and n == 2 and v in pair_names:
    return pair_names[v]
  if v <= 105:
    return _sreg(v, n)
  return _ttmp(v, n) or decode_src(v, cdna)
def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str:
return f"v{(v - base) & 0x7f}.{'h' if v >= hi_thresh else 'l'}"
def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
  """Pack s_waitcnt counters into a simm16 value.

  Layout: expcnt in bits [2:0], lgkmcnt in [9:4], vmcnt in [15:10]. Each counter is masked to its
  field width; the defaults are the all-ones values of each field.
  """
  fields = ((expcnt, 0x7, 0), (lgkmcnt, 0x3f, 4), (vmcnt, 0x3f, 10))
  return sum((value & mask) << shift for value, mask, shift in fields)
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> str:
"""Format VOP3 source operand with modifiers."""
if v == 255: s = inst.lit(v) # literal constant takes priority
elif n > 1: s = _fmt_src(v, n)
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else f"v{v - 256}.l"
else: s = inst.lit(v)
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
"""Format op_sel modifier string."""
if not need: return ""
# For VOP1 (n=1): op_sel:[src0_hi, dst_hi], for VOP2 (n=2): op_sel:[src0_hi, src1_hi, dst_hi], for VOP3 (n=3): op_sel:[src0_hi, src1_hi, src2_hi, dst_hi]
dst_hi = (opsel >> 3) & 1
if n == 1: return f" op_sel:[{opsel & 1},{dst_hi}]"
if n == 2: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{dst_hi}]"
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{dst_hi}]"
# ═══════════════════════════════════════════════════════════════════════════════
# DISASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
def _disasm_vop1(inst: VOP1) -> str:
  """Disassemble a VOP1 (single-source VALU) instruction.

  RDNA mnemonics get an "_e32" suffix and CDNA mnemonics are printed bare; the no-operand ops and
  readfirstlane are printed without a suffix. 16-bit RDNA operands use vN.l/vN.h, and multi-register
  operands use v[a:b] ranges.
  """
  name, cdna = inst.op_name.lower() or f'vop1_op_{inst.op}', _is_cdna(inst)
  suf = "" if cdna else "_e32"
  if name in ('v_nop', 'v_pipeflush', 'v_clrexcp'): return name # no operands
  if 'readfirstlane' in name:
    # scalar destination; a VGPR source (encoding >= 256) is printed as vN
    src = f"v{inst.src0 - 256}" if inst.src0 >= 256 else decode_src(inst.src0, cdna)
    return f"{name} {_fmt_sdst(inst.vdst, 1, cdna)}, {src}"
  # 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN)
  # Mnemonic heuristic: either the second-to-last '_' component is a 16-bit type, or the last one is
  # (the latter is not applied to cvt ops)
  parts = name.split('_')
  is_16d = not cdna and (any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name))
  # vdst 0..127 -> vN.l, 128.. -> v(N-128).h via _fmt_v16(base=0, hi_thresh=128)
  dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
  # src0: literal (255), register range, 16-bit half register (RDNA, not sat_pk), otherwise inst.lit()
  src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
  return f"{name}{suf} {dst}, {src}"
# CDNA VOP2 carry ops: _disasm_vop2 prints "vcc" after vdst for both sets, and the carry-in set also gets a trailing "vcc".
_VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only
_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out
def _disasm_vop2(inst: VOP2) -> str:
  """Disassemble a VOP2 (two-source VALU) instruction.

  Covers:
  - the literal-carrying fmaak/madak and fmamk/madmk forms;
  - 16-bit RDNA operands (vN.l/vN.h);
  - CDNA carry ops, which take explicit vcc operands;
  - v_cndmask_b32, which takes a trailing vcc select.
  RDNA mnemonics get "_e32", except v_dot2acc_f32_f16.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases
  suf = "" if cdna or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16) else "_e32"
  # NOTE(review): the fmaak/fmamk branches format lit with ':x', which raises if _literal is absent (None)
  lit = getattr(inst, '_literal', None)
  is16 = not cdna and inst.is_16bit()
  # fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1
  if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
    if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
    return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
  if 'fmamk' in name or 'madmk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
    if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
    return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
  if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
  vcc = "vcc" if cdna else "vcc_lo"
  # CDNA carry ops output vcc after vdst
  if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}"
  if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}"
  return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
def _disasm_vopc(inst: VOPC) -> str:
  """Disassemble a VOPC (vector compare) instruction.

  CDNA prints "<name> vcc, src0, vsrc1" with no suffix. RDNA appends "_e32": v_cmp_* writes vcc_lo,
  while v_cmpx_* has no explicit destination. 16-bit RDNA operands use vN.l/vN.h.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  if cdna:
    s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna)
    s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else f"v{inst.vsrc1}"
    return f"{name} vcc, {s0}, {s1}" # CDNA VOPC always outputs vcc
  # RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
  has_vcc = 'cmpx' not in name
  # src0: literal, register range, 16-bit half, or inst.lit(); vsrc1: range, 16-bit half (base 0, .h from 128) or vN
  s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
  s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
  return f"{name}_e32 vcc_lo, {s0}, {s1}" if has_vcc else f"{name}_e32 {s0}, {s1}"
# RDNA SOPP opcodes printed as a bare mnemonic (no simm16 operand) by _disasm_sopp.
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
  SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
# CDNA uses name-based matching since opcode values differ from RDNA
# NOTE(review): not referenced in the visible part of this file; _disasm_sopp hard-codes its own CDNA name checks.
_CDNA_NO_ARG_SOPP = {'s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_nop', 's_sethalt', 's_sleep',
  's_setprio', 's_trap', 's_incperflevel', 's_decperflevel', 's_sendmsg', 's_sendmsghalt'}
def _disasm_sopp(inst: SOPP) -> str:
  """Disassemble a SOPP (scalar, simm16-only) instruction.

  CDNA ops are matched by mnemonic because their opcode values differ from RDNA. s_waitcnt decodes simm16
  with the per-generation counter layout:
    CDNA (GFX9):   vmcnt = [3:0] | [15:14] << 4 (6 bits), expcnt = [6:4], lgkmcnt = [11:8] (4 bits)
    RDNA3 (GFX11): expcnt = [2:0], lgkmcnt = [9:4], vmcnt = [15:10] (same layout as waitcnt())
  Counters at their all-ones "no wait" value are omitted; "0" is printed if every counter is omitted.
  Branch offsets print in decimal, other immediates in hex.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  def waitcnt_str(vm: int, vm_max: int, exp: int, lgkm: int, lgkm_max: int) -> str:
    p = [f"vmcnt({vm})" if vm != vm_max else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != lgkm_max else ""]
    return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
  if cdna:
    # CDNA: use name-based matching
    if name == 's_endpgm': return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
    if name in ('s_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata'): return name
    if name == 's_waitcnt':
      # GFX9 splits the 6-bit vmcnt across bits [3:0] and [15:14]; lgkmcnt is only 4 bits wide
      vm = (inst.simm16 & 0xf) | (((inst.simm16 >> 14) & 0x3) << 4)
      exp, lgkm = (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 8) & 0xf
      return waitcnt_str(vm, 0x3f, exp, lgkm, 0xf)
    if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
    return f"{name} 0x{inst.simm16:x}" if inst.simm16 else name
  # RDNA
  if inst.op in NO_ARG_SOPP: return name
  if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
  if inst.op == SOPPOp.S_WAITCNT:
    # expcnt is 3 bits wide (bit 3 is unused), matching the waitcnt() encoder
    vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0x7, (inst.simm16 >> 4) & 0x3f
    return waitcnt_str(vm, 0x3f, exp, lgkm, 0x3f)
  if inst.op == SOPPOp.S_DELAY_ALU:
    deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
    # instid0 in bits [3:0], instskip in [6:4], instid1 in [10:7]
    id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
    dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
    # the 3-bit skip field can hold 6/7, which have no name; print those numerically instead of raising IndexError
    skip_s = skips[skip] if skip < len(skips) else str(skip)
    p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skip_s})" if skip else "", f"instid1({dep(id1)})" if id1 else ""]
    return f"s_delay_alu {' | '.join(x for x in p if x) or '0'}"
  return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
def _disasm_smem(inst: SMEM) -> str:
  """Disassemble a scalar memory (SMEM) instruction.

  Output is "<name> <sdata>, <sbase>, <offset>" followed by optional glc/dlc flags separated by single
  spaces. s_buffer_* ops use a 4-SGPR base and all other ops a 2-SGPR base. s_atc_probe* print sdata as a
  raw integer and take no flags.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
  # GFX9 SMEM: soe and imm bits determine offset interpretation
  # soe=1, imm=1: soffset is SGPR, offset is immediate (both used)
  # soe=0, imm=1: offset is immediate
  # soe=0, imm=0: offset field is SGPR encoding (0-255)
  soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1)
  if cdna:
    if soe and imm:
      off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" # SGPR + immediate
    elif imm:
      off_s = f"0x{inst.offset:x}" # Immediate offset only
    elif inst.offset < 256:
      off_s = decode_src(inst.offset, cdna) # SGPR encoding in offset field
    else:
      off_s = decode_src(inst.soffset, cdna)
  # RDNA: soffset 124 is treated as "no SGPR offset"
  elif inst.offset and inst.soffset != 124:
    off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}"
  elif inst.offset:
    off_s = f"0x{inst.offset:x}"
  else:
    off_s = decode_src(inst.soffset, cdna)
  # s_buffer_* instructions use 4 SGPRs for sbase (buffer descriptor)
  is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
  # the sbase field counts SGPR pairs, so the first register index is sbase * 2
  sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2
  sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
  if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
  # bare flag names so _mods joins them with a single space ("glc dlc")
  mods = _mods((inst.glc, "glc"), (getattr(inst, 'dlc', 0), "dlc"))
  return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}" + (f" {mods}" if mods else "")
def _disasm_flat(inst: FLAT) -> str:
  """Disassemble a FLAT-encoded memory instruction as flat_*, scratch_* or global_* (chosen by inst.seg).

  The printed mnemonic replaces the first '_'-separated component of the op name with the segment name.
  Atomics print a returned vdst only when glc (RDNA) / sc0 (CDNA) is set.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
  reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
  seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
  instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
  # scratch/global: offsets >= 4096 are reinterpreted as negative (13-bit two's complement); flat prints the raw value
  off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
  # data width in registers: dst_regs(), doubled for _x2 names and again for cmpswap (vdst below uses half of it)
  w = inst.dst_regs() * (2 if '_x2' in name else 1) * (2 if 'cmpswap' in name else 1)
  off_s = f" offset:{off_val}" if off_val else "" # Omit offset:0
  if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}" # GFX9: sc0->glc, nt->slc
  else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
  # saddr: nothing for flat or 0x7F, ", off" for 124, otherwise a scalar operand / special pair / ttmp pair / SGPR pair
  if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
  elif inst.saddr == 124: saddr_s = ", off"
  elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
  elif inst.saddr in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[inst.saddr]}"
  elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
  else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr, cdna)}"
  # addtid: no addr
  if 'addtid' in name: return f"{instr} {'a' if acc else 'v'}{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
  # addr width: CDNA flat always uses 2 VGPRs (64-bit), scratch uses 1, RDNA uses 2 only when no saddr
  if cdna:
    addr_w = 1 if seg == 'scratch' else 2 # CDNA: flat/global always 64-bit addr
  else:
    addr_w = 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2
  # scratch without sve has no VGPR address operand
  addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
  data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
  glc_or_sc0 = inst.sc0 if cdna else inst.glc
  if 'atomic' in name:
    # returning atomics (glc/sc0 set) print vdst first; plain flat atomics never print saddr
    return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
  if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
  return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
  """Disassemble an LDS/GDS (DS) instruction.

  The operand layout is chosen from the mnemonic. 2addr/read2/write2 ops print offset0 and offset1
  separately; all other ops print a single offset equal to offset0 | offset1 << 8. '_rtn' variants print
  the returned vdst first, and " gds" is appended when the gds bit is set.
  """
  op, name = inst.op, inst.op_name.lower()
  acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
  reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
  rp = 'a' if acc else 'v' # register prefix for single regs
  gds = " gds" if inst.gds else ""
  # combined 16-bit offset for single-address ops, split offsets for the two-address (2addr/read2/write2) forms
  off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
  off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "")
  w = inst.dst_regs()
  d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), f"v{inst.addr}"
  if op == DSOp.DS_NOP: return name
  if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
  if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
  if 'gws_' in name: return f"{name} {addr}{off}{gds}"
  if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{off}{gds}"
  if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {rp}{inst.data0}{off}{gds}"
  if '2addr' in name:
    # two-address ops read/write two w-sized elements, so vdst spans w*2 registers
    if 'load' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
    if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
    return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
  if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
  if 'read2' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
  if 'load' in name: return f"{name} {rp}{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
  if 'store' in name and not _has(name, 'cmp', 'xchg'):
    return f"{name} {rp}{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
  if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} {rp}{inst.vdst}, {addr}{off}{gds}"
  if 'permute' in name: return f"{name} {rp}{inst.vdst}, {addr}, {rp}{inst.data0}{off}{gds}"
  if 'condxchg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {addr}, {reg_fn(inst.data0, 2)}{off}{gds}"
  # three-operand atomics (two data operands)
  if _has(name, 'cmpstore', 'mskor', 'wrap'):
    return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}"
  # remaining ops: one data operand, optional returned value
  return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
def _disasm_vop3(inst: VOP3) -> str:
  """Disassemble an RDNA VOP3 (64-bit VALU) instruction.

  The opcode range selects the printed form, each with an "_e64" suffix for the promoted encodings:
  - op < 256: promoted VOPC;
  - op < 384: promoted VOP2;
  - op < 512: promoted VOP1;
  - otherwise: a native VOP3 op.
  Ops whose enum type is VOP3SDOp are printed with a scalar destination operand.
  """
  op, name = inst.op, inst.op_name.lower()
  # VOP3SD (shared encoding)
  if isinstance(op, VOP3SDOp):
    # NOTE(review): presumably the VOP3SD sdst field overlaps the VOP3 clmp/opsel/abs fields, so it is rebuilt from them
    sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
    def src(v, neg, n):
      s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v))
      return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
    s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
    dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
    srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
    return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
  # Detect 16-bit operand sizes (for .h/.l suffix handling)
  is16_d = is16_s = is16_s2 = False
  if 'cvt_pk' in name: is16_s = name.endswith('16')
  elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
    # group(1) names the destination type and group(2) the source type
    is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16')
    is16_s2 = is16_s
  elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
  elif 'pack_b32' in name: is16_s = is16_s2 = True
  elif 'sat_pk' in name: is16_d = True # v_sat_pk_* writes to 16-bit dest but takes 32-bit src
  else: is16_d = is16_s = is16_s2 = inst.is_16bit()
  # opsel bits 0..2 pick the .h half of 16-bit VGPR sources src0..src2
  s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s)
  s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s)
  s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2)
  # Destination
  dn = inst.dst_regs()
  if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
  elif dn > 1: dst = _vreg(inst.vdst, dn)
  elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" # opsel bit 3 picks the dst half
  else: dst = f"v{inst.vdst}"
  cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
  # op_sel is printed explicitly if it is set on a non-VGPR source, or set on an op without 16-bit sources
  nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4))
  need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
  if inst.op < 256: # VOPC
    return f"{name}_e64 {s0}, {s1}{cl}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}{cl}"
  if inst.op < 384: # VOP2
    n = inst.num_srcs()
    os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
    return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
  if inst.op < 512: # VOP1
    return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}"
  # Native VOP3
  n = inst.num_srcs()
  os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
  return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
def _disasm_vop3sd(inst: VOP3SD) -> str:
  """Disassemble a VOP3SD instruction (VALU op with an extra scalar destination, e.g. carry-out).

  Output is "<name>[_e64] vdst, sdst, src0, src1[, src2][ clamp][omod]". "_e64" is added only for v_*co_*
  ops. Negated sources print as -x, except the literal (255), which prints as neg(x).
  """
  name = inst.op_name.lower()
  def render(raw, negated, nregs):
    text = _fmt_src(raw, nregs) if raw != 255 and nregs > 1 else inst.lit(raw)
    if not negated:
      return text
    return f"neg({text})" if raw == 255 else f"-{text}"
  operands = [render(getattr(inst, f"src{i}"), inst.neg & (1 << i), inst.src_regs(i)) for i in range(3)]
  dst_count = inst.dst_regs()
  dst = _vreg(inst.vdst, dst_count) if dst_count > 1 else f"v{inst.vdst}"
  srcs = ", ".join(operands if inst.num_srcs() == 3 else operands[:2])
  suffix = "_e64" if name.startswith('v_') and 'co_' in name else ""
  clamp = ' clamp' if inst.clmp else ''
  return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{clamp}{_omod(inst.omod)}"
def _disasm_vopd(inst: VOPD) -> str:
  """Disassemble a dual-issue VOPD instruction as "<X half> :: <Y half>".

  The Y destination register is (vdsty << 1) with the inverted low bit of vdstx. fmamk halves print the
  shared literal K as src0 * K + vsrc1, and fmaak halves as src0 * vsrc1 + K. mov halves take only src0.
  """
  k = inst._literal or inst.literal
  vdst_y = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
  x_name = VOPDOp(inst.opx).name.lower()
  y_name = VOPDOp(inst.opy).name.lower()
  def render(op_name, vdst, src0, vsrc1):
    head = f"{op_name} v{vdst}, {inst.lit(src0)}"
    if 'mov' in op_name:
      return head
    if 'fmamk' in op_name and k:
      return f"{head}, 0x{k:x}, v{vsrc1}"
    if 'fmaak' in op_name and k:
      return f"{head}, v{vsrc1}, 0x{k:x}"
    return f"{head}, v{vsrc1}"
  x_text = render(x_name, inst.vdstx, inst.srcx0, inst.vsrcx1)
  y_text = render(y_name, vdst_y, inst.srcy0, inst.vsrcy1)
  return f"{x_text} :: {y_text}"
def _disasm_vop3p(inst: VOP3P) -> str:
  """Disassemble a VOP3P (packed math) instruction, including WMMA and v_fma_mix* ops.

  WMMA operands are register ranges: src0/src1 are 2 (iu4), 4 (iu8) or 8 registers wide; src2 and vdst
  are 8 wide. fma_mix ops print neg bits as "-x" and neg_hi bits as "|x|". Other ops print op_sel,
  op_sel_hi (only when it differs from all-ones), neg_lo and neg_hi as bit lists.
  """
  name = inst.op_name.lower()
  is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
  def get_src(v, sc): return inst.lit(v) if v == 255 else _fmt_src(v, sc)
  if is_wmma:
    sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
    src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 8), _vreg(inst.vdst, 8)
  else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
  # opsel_hi bits 0-1 come from opsel_hi, bit 2 from opsel_hi2
  opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
  if is_fma_mix:
    # fma_mix reuses neg_hi as per-source abs
    def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
    src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
    mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
  else:
    # op_sel_hi defaults to all ones (7 for 3 sources, 3 for 2) and is omitted at the default
    mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
      ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
  return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
  """Disassemble a MUBUF/MTBUF buffer instruction (RDNA or CDNA).

  Output is "<name> vdata, vaddr, srsrc, soffset" plus flag modifiers. The MTBUF format is rendered in
  one of three styles:
  - GFX90a acc: dfmt/nfmt numbers;
  - RDNA: a combined format number;
  - CDNA: BUF_DATA_FORMAT_/BUF_NUM_FORMAT_ names.
  vaddr is a 2-register range when both idxen and offen are set, a single VGPR when one is set, and "off"
  when neither is.
  """
  name, cdna = inst.op_name.lower(), _is_cdna(inst)
  acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
  reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
  if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
  if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
  # vdata width in registers:
  # - d16: 2 for xyz/xyzw, else 1
  # - atomics: 1 or 2 (64-bit), doubled for cmpswap
  # - otherwise: from the mnemonic's last component
  w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
    ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
    {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
  if hasattr(inst, 'tfe') and inst.tfe: w += 1
  vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
  # srsrc counts groups of 4 SGPRs
  srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
  is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF)
  if is_mtbuf:
    dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7
    if acc: # GFX90a accumulator style: show dfmt/nfmt as numbers
      fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt}," # placed between "srsrc," and soffset by the final return
    elif not cdna: # RDNA style: show combined format number
      fmt_s = f" format:{inst.format}" if inst.format else ""
    else: # CDNA: show format:[BUF_DATA_FORMAT_X] or format:[BUF_NUM_FORMAT_X]
      dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15']
      nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT']
      if dfmt == 1 and nfmt == 0: fmt_s = "" # default, no format shown
      elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]" # only dfmt differs
      elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # only nfmt differs
      else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # both differ
  else:
    fmt_s = ""
  # CDNA cache bits print under GFX9 names: sc0 -> glc, nt -> slc
  if cdna: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")] if c]
  else: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
  soffset_s = decode_src(inst.soffset, cdna)
  # CDNA (non-acc) MTBUF prints the format after soffset; all other forms print it between srsrc and soffset
  if cdna and not acc and is_mtbuf: return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{' ' + ' '.join(mods) if mods else ''}"
  return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{' ' + ' '.join(mods) if mods else ''}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
"""Calculate vaddr register count for MIMG sample/gather operations."""
# 1d,2d,3d,cube,1d_arr,2d_arr,2d_msaa,2d_msaa_arr
base = [1, 2, 3, 3, 2, 3, 3, 4][dim] # address coords
grad = [1, 2, 3, 2, 1, 2, 2, 2][dim] # gradient coords (for derivatives)
if 'get_resinfo' in name: return 1 # only mip level
packed, unpacked = 0, 0
if '_mip' in name: packed += 1
elif 'sample' in name or 'gather' in name:
if '_o' in name: unpacked += 1 # offset
if re.search(r'_c(_|$)', name): unpacked += 1 # compare (not _cl)
if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2 # derivatives
if '_b' in name: unpacked += 1 # bias
if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1 # LOD
if '_cl' in name: packed += 1 # clamp
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
def _disasm_mimg(inst: MIMG) -> str:
  """Disassemble an RDNA MIMG (image) instruction.

  Output is "<name> vdata, vaddr, srsrc[, ssamp] <mods>".
  - srsrc is 8 SGPRs wide; BVH ops use 4 and a fixed operand layout.
  - vdata width comes from dmask (always 4 for gather4/msaa_load), halved and rounded up for d16, plus 1 for tfe.
  - vaddr width comes from _mimg_vaddr_width.
  """
  name = inst.op_name.lower()
  srsrc_base = inst.srsrc * 4
  srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
  # BVH intersect ray: special case with 4 SGPR srsrc
  if 'bvh' in name:
    # vaddr width: 64-bit variants take one more register; a16 shrinks the operand
    vaddr = (9 if '64' in name else 8) if inst.a16 else (12 if '64' in name else 11)
    return f"{name} {_vreg(inst.vdata, 4)}, {_vreg(inst.vaddr, vaddr)}, {_sreg_or_ttmp(srsrc_base, 4)}{' a16' if inst.a16 else ''}"
  # vdata width from dmask (gather4/msaa_load always 4), d16 packs, tfe adds 1
  vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1)
  if inst.d16: vdata = (vdata + 1) // 2
  if inst.tfe: vdata += 1
  # vaddr width
  dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array']
  dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}"
  vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16)
  vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr)
  # modifiers: dmask is omitted when 0 or 0xf, except for atomics, which always print it
  mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else []
  mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}")
  for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),
    (inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]:
    if flag: mods.append(mod)
  # ssamp for sample/gather/get_lod
  ssamp_str = ""
  if 'sample' in name or 'gather' in name or 'get_lod' in name:
    ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4)
  return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _disasm_sop1(inst: SOP1) -> str:
  """Disassemble a SOP1 (single-source SALU) instruction.

  RDNA special cases:
  - s_getpc_b64: 64-bit dst only.
  - s_setpc_b64 / s_rfe_b64: src only.
  - s_swappc_b64: 64-bit dst and src.
  - s_sendmsg_rtn_*: src printed as sendmsg(<MSG name or id>).
  Everything else, including every CDNA op, prints "<name> <sdst>, <src0>".
  """
  op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
  raw = inst.ssrc0
  src = inst.lit(raw) if raw == 255 else _fmt_src(raw, inst.src_regs(0), cdna)
  def generic() -> str:
    return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {src}"
  if cdna:
    return generic()
  if op == SOP1Op.S_GETPC_B64:
    return f"{name} {_fmt_sdst(inst.sdst, 2)}"
  if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64):
    return f"{name} {src}"
  if op == SOP1Op.S_SWAPPC_B64:
    return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
  if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64):
    msg = MSG.get(raw, str(raw))
    return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({msg})"
  return generic()
def _disasm_sop2(inst: SOP2) -> str:
  """Disassemble a SOP2 (two-source SALU) instruction as "<name> <sdst>, <ssrc0>, <ssrc1>".

  A source encoding of 255 is printed via inst.lit() (literal constant).
  """
  cdna = _is_cdna(inst)
  def operand(idx: int) -> str:
    raw = getattr(inst, f"ssrc{idx}")
    return inst.lit(raw) if raw == 255 else _fmt_src(raw, inst.src_regs(idx), cdna)
  dst = _fmt_sdst(inst.sdst, inst.dst_regs(), cdna)
  return f"{inst.op_name.lower()} {dst}, {operand(0)}, {operand(1)}"
def _disasm_sopc(inst: SOPC) -> str:
  """Disassemble a SOPC (scalar compare) instruction as "<name> <ssrc0>, <ssrc1>"; encoding 255 is a literal."""
  cdna = _is_cdna(inst)
  operands = []
  for idx, raw in enumerate((inst.ssrc0, inst.ssrc1)):
    operands.append(inst.lit(raw) if raw == 255 else _fmt_src(raw, inst.src_regs(idx), cdna))
  return f"{inst.op_name.lower()} {operands[0]}, {operands[1]}"
def _sopk_hwreg(simm16: int) -> str:
  """Render the hwreg operand packed into a SOPK simm16.

  Field layout: id in bits [5:0], offset in bits [10:6], size-1 in bits [15:11].
  Ids 16 and 17 are printed as the raw simm16 in hex instead of `hwreg(name, offset, size)`.
  NOTE(review): presumably this mirrors the reference disassembler's output for those ids — confirm.
  """
  hid, hoff, hsz = simm16 & 0x3f, (simm16 >> 6) & 0x1f, ((simm16 >> 11) & 0x1f) + 1
  if hid in (16, 17): return f"0x{simm16:x}"
  return f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"

def _disasm_sopk(inst: SOPK) -> str:
  """Format a SOPK instruction, generically as `name sdst, 0x<simm16>`.

  Special forms:
  - s_setreg_imm32_b32: `hwreg, 0x<literal>`.
  - s_setreg_b32: `hwreg, sdst`.
  - s_getreg_b32: `sdst, hwreg`.
  - On RDNA3, s_version: `0x<simm16>` only.
  - On RDNA3, s_subvector_loop_begin/end: single-register sdst.
  """
  op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
  # s_setreg_imm32_b32 has a 32-bit literal value following the hwreg operand
  if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32):
    return f"{name} {_sopk_hwreg(inst.simm16)}, 0x{inst._literal:x}"
  if not cdna and op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
  if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
    hs, sdst = _sopk_hwreg(inst.simm16), _fmt_sdst(inst.sdst, 1, cdna)
    # setreg prints the hwreg operand first, getreg prints the scalar register first
    return f"{name} {hs}, {sdst}" if 'setreg' in name else f"{name} {sdst}, {hs}"
  if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
    return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
  return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
  """Format a VINTERP instruction as `name vN, src0, src1, src2` plus optional wait_exp/clamp modifiers.

  Bit 0/1/2 of inst.neg is passed to inst.lit together with src0/src1/src2 respectively.
  """
  s0, s1, s2 = (inst.lit(src, inst.neg & bit) for src, bit in ((inst.src0, 1), (inst.src1, 2), (inst.src2, 4)))
  head = f"{inst.op_name.lower()} v{inst.vdst}, {s0}, {s1}, {s2}"
  mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
  return f"{head} {mods}" if mods else head
# Exact instruction class -> formatter used by disasm()
DISASM_HANDLERS = {
  VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd,
  VOPD: _disasm_vopd, VOP3P: _disasm_vop3p, VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem,
  DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf, MIMG: _disasm_mimg,
  SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk,
}
def disasm(inst: Inst) -> str:
  """Disassemble `inst` to text via DISASM_HANDLERS.

  Lookup is by the exact class (`type(inst)`, subclasses are not matched). Raises KeyError for unregistered classes.
  """
  handler = DISASM_HANDLERS[type(inst)]
  return handler(inst)
# ═══════════════════════════════════════════════════════════════════════════════
# ASSEMBLER
# ═══════════════════════════════════════════════════════════════════════════════
# Special operand names -> RawImm source-operand encodings.
# The 64-bit names ('vcc', 'exec') share the encoding of their low half; 'null' and 'off' both encode as 124.
SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125),
  'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)}
FLOATS = {str(k): k for k in FLOAT_ENC}  # Valid float literal strings: '0.5', '-0.5', '1.0', etc.
# Register-name prefix -> DSL register factory ('t' is shorthand for ttmp)
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
# Mnemonics that get_dsl routes to its SMEM encoding path (for CDNA, s_load_dword*/s_buffer_load_dword* prefixes are
# accepted there as well). Mixes RDNA3-style b32..b512 loads with CDNA-style dword scratch/store/atc_probe names.
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
  's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512',
  's_scratch_load_dword', 's_scratch_load_dwordx2', 's_scratch_load_dwordx4',
  's_scratch_store_dword', 's_scratch_store_dwordx2', 's_scratch_store_dwordx4',
  's_store_dword', 's_store_dwordx2', 's_store_dwordx4',
  's_buffer_store_dword', 's_buffer_store_dwordx2', 's_buffer_store_dwordx4',
  's_atc_probe', 's_atc_probe_buffer'}
# Lowercase special-register operand -> DSL identifier text emitted by _op2dsl for RDNA3.
# NOTE(review): the identifiers presumably resolve to the register constants imported from the dsl module where the
# generated DSL text is evaluated — confirm.
SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
  'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'}
# CDNA variant used by _op2dsl when arch == "cdna". It adds flat_scratch/xnack_mask/vccz/execz/lds_direct names,
# and maps 'src_scc' to 'SRC_SCC' rather than 'SCC'.
SPEC_DSL_CDNA = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
  'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SRC_SCC',
  'flat_scratch_lo': 'FLAT_SCRATCH_LO', 'flat_scratch_hi': 'FLAT_SCRATCH_HI', 'flat_scratch': 'FLAT_SCRATCH',
  'xnack_mask_lo': 'XNACK_MASK_LO', 'xnack_mask_hi': 'XNACK_MASK_HI', 'xnack_mask': 'XNACK_MASK',
  'src_vccz': 'SRC_VCCZ', 'src_execz': 'SRC_EXECZ', 'vccz': 'SRC_VCCZ', 'execz': 'SRC_EXECZ',
  'src_lds_direct': 'SRC_LDS_DIRECT', 'lds_direct': 'SRC_LDS_DIRECT'}
def _op2dsl(op: str, arch: str = "rdna3") -> str:
  """Translate one assembly operand into the DSL expression text used by get_dsl.

  Modifiers are peeled off in this order:
  1. A leading '-' (negation). A '-' directly followed by a digit is kept as part of a numeric literal.
  2. '|x|' or 'abs(x)' (absolute value).
  3. A '.h'/'.l' half suffix.

  The remaining operand is then tried against, in order:
  - the arch's special-register table (SPEC_DSL_CDNA for "cdna", otherwise SPEC_DSL);
  - the FLOATS literal strings;
  - register ranges such as 'v[0:3]';
  - single registers such as 's5' -> 's[5]' ('t' is spelled 'ttmp');
  - integer/hex literals, emitted bare, or as SrcMod(...) when a modifier applies.

  Anything else is passed through with the same modifiers applied.
  """
  operand = op.strip()
  # only a '-' before a non-digit is a modifier; '-5' and '-0x10' stay literals
  negate = operand.startswith('-') and not operand[1:2].isdigit()
  if negate: operand = operand[1:]
  if operand.startswith('|') and operand.endswith('|'): absolute, operand = True, operand[1:-1]
  elif operand.startswith('abs(') and operand.endswith(')'): absolute, operand = True, operand[4:-1]
  else: absolute = False
  half = operand[-2:] if operand.endswith(('.h', '.l')) else ""
  if half: operand = operand[:-2]
  lowered = operand.lower()
  specials = SPEC_DSL_CDNA if arch == "cdna" else SPEC_DSL
  sign = '-' if negate else ''
  def wrap(base: str) -> str:
    return f"{sign}abs({base}){half}" if absolute else f"{sign}{base}{half}"
  if lowered in specials: return wrap(specials[lowered])
  if operand in FLOATS: return wrap(operand)
  prefixes = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}
  if rng := re.match(r'^([asvt](?:tmp)?)\[(\d+):(\d+)\]$', lowered):
    kind, first, last = rng.groups()
    return wrap(f"{prefixes.get(kind, kind)}[{first}:{last}]")
  if single := re.match(r'^([asvt](?:tmp)?)(\d+)$', lowered):
    kind, idx = single.groups()
    return wrap(f"{prefixes.get(kind, kind)}[{idx}]")
  if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', operand):
    return f"SrcMod({operand}, neg={negate}, abs_={absolute})" if negate or absolute else operand
  return wrap(operand)
def _parse_ops(s: str) -> list[str]:
  """Split an operand list on top-level commas.

  Commas nested inside [...] or (...), or between a pair of '|' bars, do not split. Every piece is stripped. A
  trailing empty piece is dropped, but empty pieces between two commas are kept.
  """
  pieces: list[str] = []
  buf: list[str] = []
  nesting, in_bars = 0, False
  for ch in s:
    if ch in '[(': nesting += 1
    elif ch in '])': nesting -= 1
    elif ch == '|': in_bars = not in_bars
    if ch == ',' and nesting == 0 and not in_bars:
      pieces.append(''.join(buf).strip())
      buf.clear()
      continue
    buf.append(ch)
  tail = ''.join(buf).strip()
  if tail: pieces.append(tail)
  return pieces
def _extract(text: str, pat: str, flags=re.I):
  """Search `text` for regex `pat` and cut the first match out of it.

  Returns `(match, remaining_text)`, with the matched span replaced by a single space, or `(None, text)` when the
  pattern does not occur. Matching is case-insensitive by default (`flags=re.I`).
  """
  found = re.search(pat, text, flags)
  if found is None: return None, text
  return found, " ".join((text[:found.start()], text[found.end():]))
# Instruction aliases: LLVM uses different names for some instructions.
# Keys are lowercase mnemonics as they may appear in input text (_apply_alias lowercases before lookup). Values are
# the spellings the rest of get_dsl works with (presumably the autogen instruction names — confirm).
# _apply_alias uses this table directly for RDNA3. For CDNA it is only a fallback after _CDNA_ALIASES, and keys in
# _RDNA3_ONLY_ALIASES are skipped.
_ALIASES = {
  'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
  'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64',
  'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
  'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32',
  'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32',
  # SMEM aliases (dword -> b32, dwordx2 -> b64, etc.)
  's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128',
  's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512',
  's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64',
  's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256',
  's_buffer_load_dwordx16': 's_buffer_load_b512',
  # VOP3 aliases
  'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16',
  'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32',
  # VINTERP aliases
  'v_interp_p2_new_f32': 'v_interp_p2_f32',
  # SOP1 aliases
  's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64',
  's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64',
  's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64',
  's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64',
  's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64',
  's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64',
  's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64',
  's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64',
  # SOP2 aliases
  's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64',
  's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64',
  # VOP2 aliases
  'v_dot2c_f32_f16': 'v_dot2acc_f32_f16',
  # More VOP3 aliases
  'v_fma_legacy_f32': 'v_fma_dx9_zero_f32',
}
# RDNA3-only aliases (should NOT be applied to CDNA).
# _apply_alias skips these keys in its _ALIASES fallback when arch == "cdna".
# NOTE(review): every other _ALIASES key IS applied to CDNA input by that fallback. For example, 'v_cmp_tru_f32'
# becomes 'v_cmp_t_f32', although _CDNA_ALIASES maps 'v_cmp_t_f32' to 'v_cmp_tru_f32', and only one alias is applied
# per instruction. Confirm whether the v_cmp(x)_tru_* keys and other RDNA3 renames (s_andn2_*, v_ffbh_*, ...) belong
# in this set.
_RDNA3_ONLY_ALIASES = {'v_mul_legacy_f32', 'v_fmac_legacy_f32', 'v_fma_legacy_f32',
  # SMEM: RDNA3 uses b32/b64, CDNA uses dword/dwordx2
  's_load_dword', 's_load_dwordx2', 's_load_dwordx4', 's_load_dwordx8', 's_load_dwordx16',
  's_buffer_load_dword', 's_buffer_load_dwordx2', 's_buffer_load_dwordx4', 's_buffer_load_dwordx8', 's_buffer_load_dwordx16'}
# CDNA-specific aliases (GFX9 uses different names for some instructions). There are no SMEM dword -> b32 entries
# here because CDNA keeps the dword naming.
# _apply_alias consults this table first when arch == "cdna", then falls back to _ALIASES.
_CDNA_ALIASES = {
  # VOP aliases (inverse of _CDNA_DISASM_ALIASES)
  # NOTE(review): some targets name a different operation than the key (e.g. v_mul_legacy_f32 -> v_fmac_f64,
  # v_mac_f32 -> v_dot2c_f32_bf16). Presumably the CDNA instruction tables reuse those opcode slots under these
  # names — confirm against the CDNA ISA tables.
  'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32',
  'v_mul_legacy_f32': 'v_fmac_f64', 'v_mac_f32': 'v_dot2c_f32_bf16', 'v_madmk_f32': 'v_fmamk_f32', 'v_madak_f32': 'v_fmaak_f32',
  # VOPC: v_cmp_t_fXX -> v_cmp_tru_fXX for CDNA
  'v_cmp_t_f16': 'v_cmp_tru_f16', 'v_cmp_t_f32': 'v_cmp_tru_f32', 'v_cmp_t_f64': 'v_cmp_tru_f64',
  'v_cmpx_t_f16': 'v_cmpx_tru_f16', 'v_cmpx_t_f32': 'v_cmpx_tru_f32', 'v_cmpx_t_f64': 'v_cmpx_tru_f64',
  # VOP1: flr/rpi -> floor/nearest for CDNA
  'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
}
def _apply_alias(text: str, arch: str = "rdna3") -> str:
  """Replace the leading mnemonic of `text` with its canonical name when an alias exists for `arch`.

  Lookup order:
  - The arch table (_CDNA_ALIASES for "cdna", otherwise _ALIASES) is tried first.
  - _ALIASES is the fallback; for CDNA, keys in _RDNA3_ONLY_ALIASES are skipped.
  - Each candidate is tried as written, then with an `_e32` or `_e64` suffix removed. The suffix and the operands are
    kept after the replaced name.

  Any whitespace (space, tab, newline) ends the mnemonic. Leading whitespace is dropped when an alias is applied.
  Returns `text` unchanged when no alias matches.
  """
  body = text.lstrip()
  words = body.split()
  # Previously only a literal ' ' was recognized, so "mnemonic\toperands" was never aliased and leading
  # whitespace made the slice below start at the wrong offset.
  if words and any(c.isspace() for c in body): mn = words[0].lower()
  else: mn = body.lower().rstrip('_')
  aliases = _CDNA_ALIASES if arch == "cdna" else _ALIASES
  # Try exact match first, then strip _e32/_e64 suffix
  for m in (mn, mn.removesuffix('_e32'), mn.removesuffix('_e64')):
    if m in aliases: return aliases[m] + body[len(m):]
    # Also check common aliases, but skip RDNA3-only ones for CDNA
    if m in _ALIASES and not (arch == "cdna" and m in _RDNA3_ONLY_ALIASES): return _ALIASES[m] + body[len(m):]
  return text
def get_dsl(text: str, arch: str = "rdna3") -> str:
text, kw = _apply_alias(text.strip(), arch), []
# Extract modifiers
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: kw.append('clmp=1'); text = m[1]
opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]')
if m:
bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower()
is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot', 'v_mad_mix', 'v_fma_mix'))
opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \
(bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits))
m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None
m, text = _extract(text, r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)'); off_val = m.group(1) if m else None
m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None
m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None
m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None
m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]')
opsel_hi, opsel_hi_count = (sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))), len(m.group(1).split(','))) if m else (None, 0)
m, text = _extract(text, r'\s+gds(?:\s|$)'); gds = 1 if m else None
m, text = _extract(text, r'\s+offset0:(\d+)'); offset0 = m.group(1) if m else None
m, text = _extract(text, r'\s+offset1:(\d+)'); offset1 = m.group(1) if m else None
m, text = _extract(text, r'\s+lds(?:\s|$)'); lds = 1 if m else None
# SDWA modifiers
_SDWA_SEL = {'BYTE_0': 0, 'BYTE_1': 1, 'BYTE_2': 2, 'BYTE_3': 3, 'WORD_0': 4, 'WORD_1': 5, 'DWORD': 6}
_SDWA_DST_UNUSED = {'UNUSED_PAD': 0, 'UNUSED_SEXT': 1, 'UNUSED_PRESERVE': 2}
m, text = _extract(text, r'\s+dst_sel:(\w+)'); sdwa_dst_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
m, text = _extract(text, r'\s+dst_unused:(\w+)'); sdwa_dst_unused = _SDWA_DST_UNUSED.get(m.group(1), 0) if m else None
m, text = _extract(text, r'\s+src0_sel:(\w+)'); sdwa_src0_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
m, text = _extract(text, r'\s+src1_sel:(\w+)'); sdwa_src1_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
m, text = _extract(text, r'\s+sext\(src0\)'); sdwa_src0_sext = 1 if m else None
m, text = _extract(text, r'\s+sext\(src1\)'); sdwa_src1_sext = 1 if m else None
# DPP modifiers
m, text = _extract(text, r'\s+quad_perm:\[(\d+),(\d+),(\d+),(\d+)\]')
dpp_ctrl = int(m.group(1)) | (int(m.group(2)) << 2) | (int(m.group(3)) << 4) | (int(m.group(4)) << 6) if m else None
m, text = _extract(text, r'\s+row_shl:(\d+)'); dpp_ctrl = 0x100 | int(m.group(1)) if m else dpp_ctrl
m, text = _extract(text, r'\s+row_shr:(\d+)'); dpp_ctrl = 0x110 | int(m.group(1)) if m else dpp_ctrl
m, text = _extract(text, r'\s+row_ror:(\d+)'); dpp_ctrl = 0x120 | int(m.group(1)) if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_shl:1'); dpp_ctrl = 0x130 if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_rol:1'); dpp_ctrl = 0x134 if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_shr:1'); dpp_ctrl = 0x138 if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_ror:1'); dpp_ctrl = 0x13c if m else dpp_ctrl
m, text = _extract(text, r'\s+row_mirror(?:\s|$)'); dpp_ctrl = 0x140 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_half_mirror(?:\s|$)'); dpp_ctrl = 0x141 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_bcast:15(?:\s|$)'); dpp_ctrl = 0x142 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_bcast:31(?:\s|$)'); dpp_ctrl = 0x143 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_row_mask = int(m.group(1), 0) if m else None; dpp_row_mask_specified = m is not None
m, text = _extract(text, r'\s+bank_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_bank_mask = int(m.group(1), 0) if m else None; dpp_bank_mask_specified = m is not None
m, text = _extract(text, r'\s+bound_ctrl:([01])'); dpp_bound_ctrl = 1 if m else None # bound_ctrl:0 or bound_ctrl:1 both set bit to 1
if waitexp: kw.append(f'waitexp={waitexp}')
parts = text.replace(',', ' ').split()
if not parts: raise ValueError("empty instruction")
mn, op_str = parts[0].lower(), text[len(parts[0]):].strip()
ops, args = _parse_ops(op_str), [_op2dsl(o, arch) for o in _parse_ops(op_str)]
# s_waitcnt
if mn == 's_waitcnt':
vm, exp, lgkm = 0x3f, 0x7, 0x3f
for p in op_str.replace(',', ' ').split():
if m := re.match(r'vmcnt\((\d+)\)', p): vm = int(m.group(1))
elif m := re.match(r'expcnt\((\d+)\)', p): exp = int(m.group(1))
elif m := re.match(r'lgkmcnt\((\d+)\)', p): lgkm = int(m.group(1))
elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})"
return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})"
# SDWA instructions (CDNA)
if mn.endswith('_sdwa') and arch == "cdna":
base_mn = mn[:-5] # strip _sdwa
# Get VOP1/VOP2/VOPC opcode
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOPCOp, SDWA
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
vopc_op = getattr(VOPCOp, base_mn.upper(), None)
if vop1_op is None and vop2_op is None and vopc_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
# Parse operands: vdst, [vcc,] src0[, vsrc1]
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
vdst = args[0] # keep as v[N] for VGPRField
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
has_carry = base_mn in carry_out_ops
src0_idx = 2 if has_carry else 1
src1_idx = 3 if has_carry else 2
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
src0 = args[1] if len(args) > 1 else 'v[0]'
# Parse neg/abs/sext modifiers from src0_raw
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit() and src0_raw[1:3] != '0.'
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
if src0_neg_mod: src0_raw = src0_raw[1:]
if src0_abs_mod: src0_raw = src0_raw[1:-1]
if src0_sext_mod: src0_raw = src0_raw[5:-1]
# Extract src0 register number for RawImm
_SDWA_SGPR_MAP = {'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105,
'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'ttmp0': 108, 'ttmp1': 109, 'ttmp2': 110, 'ttmp3': 111,
'ttmp4': 112, 'ttmp5': 113, 'ttmp6': 114, 'ttmp7': 115, 'ttmp8': 116, 'ttmp9': 117,
'ttmp10': 118, 'ttmp11': 119, 'ttmp12': 120, 'ttmp13': 121, 'ttmp14': 122, 'ttmp15': 123,
'm0': 124, 'exec_lo': 126, 'exec_hi': 127,
'src_vccz': 251, 'src_execz': 252, 'src_scc': 253}
# Inline constant encoding for SDWA src0
_SDWA_INLINE_CONST = {'0': 128, '0.0': 128, '1': 129, '1.0': 242, '2': 130, '3': 131, '4': 132, '-1': 193, '-2': 194, '-3': 195, '-4': 196,
'0.5': 240, '-0.5': 241, '-1.0': 243, '2.0': 244, '-2.0': 245, '4.0': 246, '-4.0': 247}
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
s0 = 0
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
s0 = 1
elif src0_raw in _SDWA_SGPR_MAP: src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
elif src0_raw.startswith('ttmp') and src0_raw[4:].isdigit(): src0_val, s0 = 108 + int(src0_raw[4:]), 1
elif src0_raw in _SDWA_INLINE_CONST: src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
# Integer or float inline constant
if '.' in src0_raw:
src0_val, s0 = _SDWA_INLINE_CONST.get(src0_raw, (0, 0))
if src0_val == 0 and src0_raw != '0.0': s0 = 0
else: s0 = 1
else:
ival = int(src0_raw)
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
else: src0_val, s0 = 0, 0 # Not an inline constant
else: src0_val, s0 = 0, 0
# For VOP2, parse vsrc1 and its modifiers
vsrc1_val, src1_neg_mod, src1_abs_mod, src1_sext_mod, s1 = 0, False, False, False, 0
if vop2_op is not None and len(ops) > src1_idx:
src1_raw = ops[src1_idx].strip().lower()
# Parse neg/abs/sext modifiers
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit() and src1_raw[1:3] != '0.'
if src1_neg_mod: src1_raw = src1_raw[1:]
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
if src1_abs_mod: src1_raw = src1_raw[1:-1]
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
if src1_sext_mod: src1_raw = src1_raw[5:-1]
# Extract vsrc1 register number
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
s1 = 0
elif src1_raw in _SDWA_SGPR_MAP: vsrc1_val, s1 = _SDWA_SGPR_MAP[src1_raw], 1
elif src1_raw in _SDWA_INLINE_CONST: vsrc1_val, s1 = _SDWA_INLINE_CONST[src1_raw], 1
# Build SDWA kwargs
# VOP1 SDWA: vop_op = VOP1 opcode, vop2_op = 0x3f (63)
# VOP2 SDWA: vop_op = vsrc1, vop2_op = VOP2 opcode
# VOPC SDWA: vop_op = src1, vop2_op = 0x3e (62), vdst = VOPC opcode, dst_sel/dst_u/clmp/omod = sdst encoding
sdwa_kw = []
if vopc_op is not None:
# VOPC SDWA: opcode goes in vdst field, vop2_op=62
# Parse sdst from first operand (e.g., vcc, s[n:n+1], flat_scratch, ttmp[n:n+1])
_SDWA_SDST_MAP = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 128+102, 'flat_scratch_lo': 128+102,
'ttmp0': 128+108, 'ttmp2': 128+110, 'ttmp4': 128+112, 'ttmp6': 128+114,
'ttmp8': 128+116, 'ttmp10': 128+118, 'ttmp12': 128+120, 'ttmp14': 128+122}
sdst_raw = ops[0].strip().lower()
if sdst_raw in _SDWA_SDST_MAP: sdst_enc = _SDWA_SDST_MAP[sdst_raw]
elif sdst_raw.startswith('s[') and ':' in sdst_raw: sdst_enc = 128 + int(sdst_raw[2:].split(':')[0])
elif sdst_raw.startswith('s') and sdst_raw[1:].isdigit(): sdst_enc = 128 + int(sdst_raw[1:])
elif sdst_raw.startswith('ttmp[') and ':' in sdst_raw: sdst_enc = 128 + 108 + int(sdst_raw[5:].split(':')[0])
else: sdst_enc = 0 # Default: vcc
# For VOPC SDWA, src0 is ops[1], src1 is ops[2]
src0_raw = ops[1].strip().lower() if len(ops) > 1 else 'v0'
src1_raw = ops[2].strip().lower() if len(ops) > 2 else 'v0'
# Parse src0 with modifiers
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
if src0_neg_mod: src0_raw = src0_raw[1:]
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
if src0_abs_mod: src0_raw = src0_raw[1:-1]
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
if src0_sext_mod: src0_raw = src0_raw[5:-1]
# Extract src0 value and type
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
s0 = 0
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
s0 = 1
elif src0_raw in _SDWA_SGPR_MAP:
src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
elif src0_raw in _SDWA_INLINE_CONST:
src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
# Integer or float inline constant
if '.' in src0_raw:
src0_val = _SDWA_INLINE_CONST.get(src0_raw, 128)
s0 = 1
else:
ival = int(src0_raw)
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
else: src0_val, s0 = 0, 0
else: src0_val, s0 = 0, 0
# Parse src1 with modifiers
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
if src1_neg_mod: src1_raw = src1_raw[1:]
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
if src1_abs_mod: src1_raw = src1_raw[1:-1]
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
if src1_sext_mod: src1_raw = src1_raw[5:-1]
# Extract src1 value and type
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
s1 = 0
else: vsrc1_val, s1 = 0, 0
sdwa_kw.append(f'vop_op={vsrc1_val}')
sdwa_kw.append('vop2_op=62') # 0x3e indicates VOPC mode
sdwa_kw.append(f'vdst=RawImm({vopc_op.value})') # VOPC opcode in vdst
sdwa_kw.append(f'src0=RawImm({src0_val})')
# Encode sdst in dst_sel/dst_u/clmp/omod fields
sdwa_kw.append(f'dst_sel={sdst_enc & 7}')
sdwa_kw.append(f'dst_u={(sdst_enc >> 3) & 3}')
sdwa_kw.append(f'clmp={(sdst_enc >> 5) & 1}')
sdwa_kw.append(f'omod={(sdst_enc >> 6) & 3}')
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
if src0_sext_mod or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
if s0: sdwa_kw.append('s0=1')
if src1_sext_mod or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
if s1: sdwa_kw.append('s1=1')
return f"SDWA({', '.join(sdwa_kw)})"
elif vop1_op is not None:
sdwa_kw.append(f'vop_op={vop1_op.value}')
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
else:
sdwa_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 SDWA
sdwa_kw.append(f'vop2_op={vop2_op.value}')
sdwa_kw.append(f'vdst={vdst}')
sdwa_kw.append(f'src0=RawImm({src0_val})')
# Defaults: dst_sel=6 (DWORD), dst_unused=2 (UNUSED_PRESERVE), src0_sel=6 (DWORD), src1_sel=6 (DWORD)
sdwa_kw.append(f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}')
sdwa_kw.append(f'dst_u={sdwa_dst_unused if sdwa_dst_unused is not None else 2}')
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
if sdwa_src0_sext or src0_sext_mod: sdwa_kw.append('src0_sext=1')
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
if s0: sdwa_kw.append('s0=1')
# VOP2 SDWA src1 modifiers and defaults
if vop2_op is not None:
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
if sdwa_src1_sext or src1_sext_mod: sdwa_kw.append('src1_sext=1')
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
if s1: sdwa_kw.append('s1=1')
# Add clamp/omod from kw if present
for k in kw:
if k.startswith('clmp='): sdwa_kw.append(k)
elif k.startswith('omod='): sdwa_kw.append(k)
return f"SDWA({', '.join(sdwa_kw)})"
# DPP instructions (CDNA)
if mn.endswith('_dpp') and arch == "cdna" and dpp_ctrl is not None:
base_mn = mn[:-4] # strip _dpp
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, DPP
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
if vop1_op is None and vop2_op is None: raise ValueError(f"unknown DPP instruction: {mn}")
# Parse operands: vdst, [vcc,] src0[, vsrc1]
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
vdst = args[0]
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
has_carry = base_mn in carry_out_ops
src0_idx = 2 if has_carry else 1
src1_idx = 3 if has_carry else 2
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
# Parse neg/abs modifiers for src0 (neg before abs for -|v1| case)
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
if src0_neg_mod: src0_raw = src0_raw[1:]
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
if src0_abs_mod: src0_raw = src0_raw[1:-1]
# Extract src0 VGPR number
if src0_raw.startswith('v') and src0_raw[1:].isdigit(): src0_val = int(src0_raw[1:])
elif 'v[' in src0_raw: src0_val = int(src0_raw.split('[')[1].split(']')[0])
else: src0_val = 0
# For VOP2, parse vsrc1 and its modifiers
vsrc1_val, src1_neg_mod, src1_abs_mod = 0, False, False
if vop2_op is not None and len(ops) > src1_idx:
src1_raw = ops[src1_idx].strip().lower()
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
if src1_neg_mod: src1_raw = src1_raw[1:]
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
if src1_abs_mod: src1_raw = src1_raw[1:-1]
if src1_raw.startswith('v') and src1_raw[1:].isdigit(): vsrc1_val = int(src1_raw[1:])
elif 'v[' in src1_raw: vsrc1_val = int(src1_raw.split('[')[1].split(']')[0])
# Build DPP kwargs
# VOP1 DPP: vop_op = VOP1 opcode, vop2_op = 0x3f
# VOP2 DPP: vop_op = vsrc1, vop2_op = VOP2 opcode
dpp_kw = []
if vop1_op is not None:
dpp_kw.append(f'vop_op={vop1_op.value}')
dpp_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
else:
dpp_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 DPP
dpp_kw.append(f'vop2_op={vop2_op.value}')
dpp_kw.append(f'vdst={vdst}')
dpp_kw.append(f'src0=RawImm({src0_val})')
dpp_kw.append(f'dpp_ctrl={dpp_ctrl}')
if dpp_bound_ctrl: dpp_kw.append('bound_ctrl=1')
if src0_neg_mod: dpp_kw.append('src0_neg=1')
if src0_abs_mod: dpp_kw.append('src0_abs=1')
if src1_neg_mod: dpp_kw.append('src1_neg=1')
if src1_abs_mod: dpp_kw.append('src1_abs=1')
# Default masks: if one is specified but not the other, the other defaults to 0xf
if dpp_bank_mask_specified or dpp_row_mask_specified:
dpp_kw.append(f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 0xf}')
dpp_kw.append(f'row_mask={dpp_row_mask if dpp_row_mask is not None else 0xf}')
return f"DPP({', '.join(dpp_kw)})"
# VOPD (RDNA3 only)
if '::' in text:
xp, yp = text.split('::')
xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split()
xo, yo = [_op2dsl(p, arch) for p in xps[1:]], [_op2dsl(p, arch) for p in yps[1:]]
vdx, sx0, vsx1 = xo[0], xo[1] if len(xo) > 1 else '0', xo[2] if len(xo) > 2 else 'v[0]'
vdy, sy0, vsy1 = yo[0], yo[1] if len(yo) > 1 else '0', yo[2] if len(yo) > 2 else 'v[0]'
lit = xo[3] if 'fmaak' in xps[0].lower() and len(xo) > 3 else yo[3] if 'fmaak' in yps[0].lower() and len(yo) > 3 else None
if 'fmamk' in xps[0].lower() and len(xo) > 3: lit, vsx1 = xo[2], xo[3]
elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3]
return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})"
# Special instructions
if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}")
# v_readfirstlane_b32 has SGPR dest but encoded in vdst field - use RawImm
if mn == 'v_readfirstlane_b32' and len(args) >= 2:
dst = ops[0].strip().lower()
if dst.startswith('s') and dst[1:].isdigit(): dst_val = int(dst[1:])
elif dst.startswith('ttmp') and dst[4:].isdigit(): dst_val = 108 + int(dst[4:])
else:
sgpr_map = {'vcc_lo': 106, 'vcc_hi': 107, 'm0': 124, 'exec_lo': 126, 'exec_hi': 127,
'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105,
'null': 124} # null register for RDNA3
dst_val = sgpr_map.get(dst, int(dst) if dst.isdigit() else 0)
return f"v_readfirstlane_b32_e32(vdst=RawImm({dst_val}), src0={args[1]})"
if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})"
if mn in ('s_cbranch_join', 's_set_gpr_idx_idx'): return f"{mn}(ssrc0={args[0]}, sdst=RawImm(0))" # No destination, only source
if mn == 's_cbranch_g_fork': return f"{mn}(ssrc0={args[0]}, ssrc1={args[1]}, sdst=RawImm(0))" # Two sources, no dest
if mn == 's_set_gpr_idx_on': return f"{mn}(ssrc0={args[0]}, ssrc1=RawImm({int(args[1], 0)}))" # Mode bits as raw value
if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))"
if mn == 's_version': return f"{mn}(simm16={args[0]})"
if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})"
# SMEM: s_dcache_discard has swapped operand layout (saddr→sbase, soffset→sdata)
if arch == "cdna" and mn.startswith('s_dcache_discard'):
gs = ", glc=1" if glc else ""
# Syntax: s_dcache_discard saddr, soffset [offset:imm]
if off_val and len(ops) >= 2:
# SGPR + immediate offset: soe=1, imm=1, soffset=SGPR, offset=imm
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={off_val}, soffset={args[1]}, soe=1, imm=1{gs})"
if len(ops) >= 2 and re.match(r'^-?[0-9]|^-?0x', ops[1].strip().lower()):
# Immediate offset only: imm=1
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0), imm=1{gs})"
# SGPR offset only: imm=0, offset=SGPR
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0){gs})"
# SMEM: s_atomic_*/s_buffer_atomic_* uses offset field for SGPR (imm=0), not soffset
if arch == "cdna" and (mn.startswith('s_buffer_atomic') or (mn.startswith('s_atomic') and not mn.startswith('s_atc'))):
gs = ", glc=1" if glc else ""
if len(ops) >= 3:
# Syntax: s_atomic_* sdata, sbase, soffset [offset:imm]
if off_val:
# SGPR + immediate offset: soe=1, imm=1
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}, soe=1, imm=1{gs})"
if re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
# Immediate offset only: imm=1
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0), imm=1{gs})"
# SGPR offset only: imm=0, offset=SGPR
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs})"
# SMEM
if mn in SMEM_OPS or (arch == "cdna" and mn.startswith(('s_load_dword', 's_buffer_load_dword'))):
gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else ""
if arch == "cdna":
# CDNA SMEM encoding: imm=1 for immediate, soe=1 for sgpr+offset combo
if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
# Immediate offset only
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0), imm=1{gs}{ds})"
if off_val and len(ops) >= 3:
# SGPR + immediate offset: soe=1, soffset=SGPR, offset=imm
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}, soe=1, imm=1{gs}{ds})"
if len(ops) >= 3:
# SGPR offset only: offset=SGPR index, soffset=0
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs}{ds})"
if len(ops) == 2:
# No offset specified: imm=1, offset=0
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset=0, soffset=RawImm(0), imm=1{gs}{ds})"
else:
# RDNA3 encoding
if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(124){gs}{ds})"
if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
# Buffer (MUBUF/MTBUF) instructions
if mn.startswith(('buffer_', 'tbuffer_')):
is_tbuf = mn.startswith('tbuffer_')
# Parse format value for tbuffer
fmt_num = None
if fmt_val is not None:
if fmt_val.isdigit(): fmt_num = int(fmt_val)
else:
fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
# CDNA-style: BUF_DATA_FORMAT_X or BUF_NUM_FORMAT_X (or comma-separated pair)
if fmt_num is None and arch == "cdna":
_dfmt = {'INVALID': 0, '8': 1, '16': 2, '8_8': 3, '32': 4, '16_16': 5, '10_11_11': 6, '11_11_10': 7,
'10_10_10_2': 8, '2_10_10_10': 9, '8_8_8_8': 10, '32_32': 11, '16_16_16_16': 12,
'32_32_32': 13, '32_32_32_32': 14, 'RESERVED_15': 15}
_nfmt = {'UNORM': 0, 'SNORM': 1, 'USCALED': 2, 'SSCALED': 3, 'UINT': 4, 'SINT': 5, 'RESERVED_6': 6, 'FLOAT': 7}
parts = [p.strip() for p in fmt_val.split(',')]
dfmt, nfmt = 1, 0 # defaults
for p in parts:
if p.startswith('BUF_DATA_FORMAT_'): dfmt = _dfmt.get(p[16:], 1)
elif p.startswith('BUF_NUM_FORMAT_'): nfmt = _nfmt.get(p[15:], 0)
fmt_num = dfmt | (nfmt << 4)
# Handle special no-arg buffer ops
if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()"
# Build modifiers string - CDNA uses sc0/nt for glc/slc
if arch == "cdna":
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", sc0=1" if glc else "", ", nt=1" if slc else "",
", offen=1" if offen else "", ", idxen=1" if idxen else "", ", lds=1" if lds else ""])
else:
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "",
", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""])
# Default format for tbuffer is dfmt=1, nfmt=0 (format=8 after encoding as (nfmt<<4)|dfmt becomes just dfmt=1)
# Actually format is (dfmt | (nfmt << 4)), so dfmt=1, nfmt=0 -> format=1
if is_tbuf: buf_mods = f", format={fmt_num if fmt_num is not None else 1}" + buf_mods
# Handle LDS mode: first operand is 'off' meaning no vdata, it goes to LDS
if len(ops) >= 1 and ops[0].strip().lower() == 'off':
# LDS mode: buffer_load_format_x off, srsrc, soffset -> no vdata, just vaddr=off
srsrc_val = args[1] if len(args) > 1 else "s[0:3]"
soff_val = args[2] if len(args) > 2 else "0"
return f"{mn}(vdata=v[0], vaddr=v[0], srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
# Determine vaddr value (v[0] for 'off', actual register otherwise)
vaddr_idx = 1
if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]"
else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]"
# srsrc and soffset indices depend on whether vaddr is 'off'
srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2)
srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]"
soff_val = args[soff_idx] if len(args) > soff_idx else "0"
# soffset: integers are inline constants, don't wrap in RawImm
return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm for off/null
# CDNA: flat uses saddr=0 for off, global/scratch use saddr=0x7F (127) for off
# RDNA: uses saddr=124 (NULL)
# CDNA: uses sc0/sc1 for glc/slc
def _saddr_off(seg): return 'RawImm(0)' if arch == 'cdna' and seg == 'flat' else ('RawImm(127)' if arch == 'cdna' else 'RawImm(124)')
def _saddr(a, seg='global'): return _saddr_off(seg) if a in ('OFF', 'NULL') else a
if arch == "cdna":
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', sc0=1' if glc else ''}{', nt=1' if slc else ''}{', lds=1' if lds else ''}"
else:
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}{', lds=1' if lds else ''}"
for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'),
('flat_store','addr,data,saddr'), ('global_store','addr,data,saddr'), ('scratch_store','addr,data,saddr')]:
if mn.startswith(pre) and len(args) >= 2:
f0, f1, f2 = flds.split(',')
seg = pre.split('_')[0] # 'flat', 'global', or 'scratch'
# LDS mode: args=[addr, saddr], vdst=0, data goes to LDS
if lds and 'load' in pre:
addr_val = 'v[0]' if seg == 'scratch' and args[0] == 'OFF' else args[0]
saddr_val = _saddr(args[1], seg) if len(args) >= 2 else _saddr_off(seg)
return f"{mn}(vdst=v[0], addr={addr_val}, saddr={saddr_val}{flat_mods})"
# For scratch, 'off' as vaddr means vaddr=0 (no offset), not null register
# For load: args=[vdst, addr, saddr], for store: args=[addr, data, saddr]
# For RDNA3 scratch with 'off' as vaddr, set sve=0 (no VGPR address)
if 'store' in pre:
addr_off = seg == 'scratch' and args[0] == 'OFF'
addr_val = 'v[0]' if addr_off else args[0]
sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else ''
return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})"
else:
addr_off = seg == 'scratch' and args[1] == 'OFF'
addr_val = 'v[0]' if addr_off else args[1]
sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else ''
return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})"
for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'):
if mn.startswith(pre):
seg = pre.split('_')[0] # 'flat', 'global', or 'scratch'
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
# DS instructions
if mn.startswith('ds_'):
# Handle offset formats: offset:N (combined), offset0:N offset1:N (separate), or none
if offset0 is not None or offset1 is not None:
off0, off1 = offset0 or "0", offset1 or "0"
elif off_val:
off0, off1 = str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)
else:
off0, off1 = "0", "0"
gds_s = ", gds=1" if gds else ""
off_kw = f", offset0={off0}, offset1={off1}{gds_s}"
if mn == 'ds_nop' or mn in ('ds_gws_sema_v', 'ds_gws_sema_p', 'ds_gws_sema_release_all'): return f"{mn}({off_kw.lstrip(', ')})"
if 'gws_' in mn: return f"{mn}(addr={args[0]}{off_kw})"
if 'consume' in mn or 'append' in mn: return f"{mn}(vdst={args[0]}{off_kw})"
if 'gs_reg' in mn: return f"{mn}(vdst={args[0]}, data0={args[1]}{off_kw})"
if '2addr' in mn:
if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
if 'load' in mn or ('read' in mn and 'read2' not in mn): return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'read2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'write2' in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
if 'xchg2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
if 'store' in mn and not _has(mn, 'cmp', 'xchg'):
return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
if _has(mn, 'cmpst', 'mskor', 'wrap'):
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
# v_fmaak/v_fmamk literal handling
# RDNA3: use literal= keyword arg; CDNA: keep literal in positional args for _e32 variant
# v_fmamk_e32(vdst, src0, K, vsrc1); v_fmaak_e32(vdst, src0, vsrc1, K)
lit_s = ""
if arch == "cdna":
# For CDNA, reorder args to match _e32 signature: fmamk(vdst, src0, K, vsrc1), fmaak(vdst, src0, vsrc1, K)
if mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: args = [args[0], args[1], args[2], args[3]] # already correct order
elif mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: args = [args[0], args[1], args[2], args[3]] # already correct order
else:
if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
# VCC ops cleanup
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]]
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
# For RDNA3 v_cmpx, destination is implicitly exec (126)
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2 and arch == 'rdna3': args = ['RawImm(126)'] + args
# v_cmp_*_e64 and v_cmpx_*_e64 have SGPR destination in vdst field - encode as RawImm
# For CDNA, v_cmpx also writes to SGPR pair (first operand)
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
if mn.startswith('v_cmp') and mn.endswith('_e64') and len(args) >= 1:
# For CDNA v_cmpx with 3 operands (sdst, src0, src1), convert sdst to RawImm
# For RDNA3, v_cmpx only has 2 operands (src0, src1) - already handled above
is_cmpx = 'cmpx' in mn
if not is_cmpx or arch == 'cdna':
dst = ops[0].strip().lower()
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
fn = mn.replace('.', '_')
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
# v_fma_mix*: extract inline neg/abs modifiers
if 'fma_mix' in mn and neg_lo is None and neg_hi is None:
inline_neg, inline_abs, clean_args = 0, 0, [args[0]]
for i, op in enumerate(ops[1:4]):
op = op.strip()
neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX'))
if neg: op = op[1:]
abs_ = op.startswith('|') and op.endswith('|')
if abs_: op = op[1:-1]
if neg: inline_neg |= (1 << i)
if abs_: inline_abs |= (1 << i)
clean_args.append(_op2dsl(op, arch))
args = clean_args + args[4:]
if inline_neg: neg_lo = inline_neg
if inline_abs: neg_hi = inline_abs
all_kw = list(kw)
if lit_s: all_kw.append(lit_s.lstrip(', '))
if opsel is not None: all_kw.append(f'opsel={opsel}')
if opsel_hi is not None:
all_kw.append(f'opsel_hi={opsel_hi & 3}')
if opsel_hi_count >= 3: all_kw.append(f'opsel_hi2={(opsel_hi >> 2) & 1}') # only set opsel_hi2 if 3 elements specified
if neg_lo is not None: all_kw.append(f'neg={neg_lo}')
if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}')
if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1'])
# For CDNA _e64 VOP instructions: use keyword args (VOP3 layout)
# Pattern: v_xxx_e64 dst, src0[, src1[, src2]] -> VOP3A with promoted opcode
# VOP1 to VOP3 promotion: VOP3 op = 384 + (VOP1_op - 64) for VOP1_op >= 64, else 256 + VOP1_op
if fn.endswith('_e64') and fn.startswith('v_') and arch == "cdna":
fn_base = fn[:-4].upper() # strip _e64 and uppercase for enum lookup
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOP3AOp, VOP3BOp
# Check if this is a VOP3B instruction (has sdst for carry-out)
vop3b_op = getattr(VOP3BOp, fn_base, None)
if vop3b_op is not None:
# VOP3B: v_xxx_e64 vdst, sdst, src0, src1[, src2]
vop3_args = []
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
if len(args) >= 2: vop3_args.append(f'sdst={args[1]}')
if len(args) >= 3: vop3_args.append(f'src0={args[2]}')
if len(args) >= 4: vop3_args.append(f'src1={args[3]}')
if len(args) >= 5: vop3_args.append(f'src2={args[4]}')
a_str = ', '.join(vop3_args + all_kw)
return f"{fn[:-4]}({a_str})"
# Check if this is a VOP1 instruction that needs promotion
vop1_op = getattr(VOP1Op, fn_base, None)
vop2_op = getattr(VOP2Op, fn_base, None)
vop3a_op = getattr(VOP3AOp, fn_base, None)
if vop1_op is not None and vop3a_op is None:
# VOP1 -> VOP3 promotion: calculate promoted opcode
promoted_op = 384 + (vop1_op.value - 64) if vop1_op.value >= 64 else 256 + vop1_op.value
vop3_args = [f'op={promoted_op}']
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
if len(args) >= 2: vop3_args.append(f'src0={args[1]}')
if len(args) >= 3: vop3_args.append(f'src1={args[2]}')
if len(args) >= 4: vop3_args.append(f'src2={args[3]}')
return f"VOP3A({', '.join(vop3_args + all_kw)})"
# Otherwise try normal VOP3 lookup
vop3_args = ['_vop3=True'] # marker for asm() to force VOP3
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
if len(args) >= 2: vop3_args.append(f'src0={args[1]}')
if len(args) >= 3: vop3_args.append(f'src1={args[2]}')
if len(args) >= 4: vop3_args.append(f'src2={args[3]}')
a_str = ', '.join(vop3_args + all_kw)
return f"{fn[:-4]}({a_str})"
a_str, kw_str = ', '.join(args), ', '.join(all_kw)
return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"
def asm(text: str, arch: str = "rdna3") -> Inst:
  """Assemble one line of AMD assembly text into an Inst.

  `text` is lowered to a Python DSL call expression by get_dsl(). That expression is evaluated against a namespace built
  from the autogenerated instruction module for `arch` ("cdna", otherwise RDNA3), plus register helpers and
  special-register constants.

  For CDNA, a bare VOP name is first tried as its `_e32` variant when one exists (bare names map to VOP3). An explicit
  `_e64` name is mapped back to the bare VOP3 name. If evaluation raises NameError, these renames are retried as
  fallbacks before re-raising.

  NOTE(review): the DSL string is derived from `text` and run through eval(); only pass trusted assembly.
  """
  dsl = get_dsl(text, arch)
  if arch == "cdna":
    from extra.assembly.amd.autogen.cdna import ins as cdna_ins
    ns = {n: getattr(cdna_ins, n) for n in dir(cdna_ins) if not n.startswith('_')}
    # CDNA special registers: m0=124, flat_scratch=102-103, xnack_mask=104-105, no NULL (use m0 for off)
    ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP,
      'VCC_LO': RawImm(106), 'VCC_HI': RawImm(107), 'VCC': RawImm(106), 'EXEC_LO': RawImm(126), 'EXEC_HI': RawImm(127), 'EXEC': RawImm(126),
      'SCC': RawImm(253), 'M0': RawImm(124), 'NULL': RawImm(124), 'OFF': RawImm(124),
      'FLAT_SCRATCH_LO': RawImm(102), 'FLAT_SCRATCH_HI': RawImm(103), 'FLAT_SCRATCH': RawImm(102),
      'XNACK_MASK_LO': RawImm(104), 'XNACK_MASK_HI': RawImm(105), 'XNACK_MASK': RawImm(104),
      'SRC_VCCZ': RawImm(251), 'SRC_EXECZ': RawImm(252), 'SRC_SCC': RawImm(253), 'SRC_LDS_DIRECT': RawImm(254)})
  else:
    ns = {n: getattr(ins, n) for n in dir(ins) if not n.startswith('_')}
    ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP,
      'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF})
  # `_vop3=True` is a routing marker from get_dsl() meaning "force the VOP3 layout". It is not a constructor argument,
  # so every expression we evaluate is built from the marker-free string, including the NameError fallbacks below.
  eval_dsl = dsl.replace('_vop3=True, ', '').replace('_vop3=True', '')
  try:
    # For CDNA, prefer _e32 variants for VOP1/VOP2 when available (bare names map to VOP3). Skip this when:
    # - the name already has an _e64 suffix (explicit VOP3 request)
    # - get_dsl() used VOP3 keyword args (vdst=/src0=) or emitted the _vop3 marker
    uses_vop3_kwargs = 'vdst=' in dsl or 'src0=' in dsl or '_vop3=True' in dsl
    if arch == "cdna" and not uses_vop3_kwargs and (m := re.match(r'^(v_\w+)(\(.*\))$', dsl)) and not m.group(1).endswith('_e64'):
      fn_name, args_str = m.group(1), m.group(2)
      e32_name = f"{fn_name}_e32"
      if e32_name in ns:
        # VOP2 carry ops: v_add_co_u32(vdst, vcc, src0, vsrc1) -> v_add_co_u32_e32(vdst, src0, vsrc1)
        # (the _e32 form has an implicit vcc, so the 2nd argument is dropped)
        if fn_name in _VOP2_CARRY_OUT | _VOP2_CARRY_INOUT:
          args_match = re.match(r'\(([^,]+),\s*[^,]+,\s*(.+)\)$', args_str)
          if args_match: args_str = f"({args_match.group(1)}, {args_match.group(2)})"
        return eval(f"{e32_name}{args_str}", ns)
    # For CDNA, an explicit _e64 suffix maps to the bare (VOP3) name
    if arch == "cdna" and (m := re.match(r'^(v_\w+)_e64(\(.*\))$', eval_dsl)):
      base_name = m.group(1)
      if base_name in ns: return eval(f"{base_name}{m.group(2)}", ns)
    return eval(eval_dsl, ns)
  except NameError:
    # For CDNA, try stripping _e64 to get the VOP3 base name
    if arch == "cdna" and (m := re.match(r'^(v_\w+)_e64(\(.*\))$', eval_dsl)):
      return eval(f"{m.group(1)}{m.group(2)}", ns)
    # Otherwise retry as the _e32 variant (but never for names that are already _e64)
    if (m := re.match(r'^(v_\w+)(\(.*\))$', eval_dsl)) and not m.group(1).endswith('_e64'):
      return eval(f"{m.group(1)}_e32{m.group(2)}", ns)
    raise
# ═══════════════════════════════════════════════════════════════════════════════
# CDNA DISASSEMBLER SUPPORT
# ═══════════════════════════════════════════════════════════════════════════════
try:
  from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
    SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
    FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)

  def _cdna_src(inst, v, neg, abs_=0, n=1):
    """Format CDNA source operand `v` (spanning `n` registers) with optional abs/neg modifiers.

    A literal operand (encoding 255) is printed via inst.lit() and negated as `neg(...)`.
    Every other operand is negated with a `-` prefix. abs is applied before neg.
    """
    txt = inst.lit(v) if v == 255 else _fmt_src(v, n, cdna=True)
    if abs_: txt = f"|{txt}|"
    if not neg: return txt
    return f"neg({txt})" if v == 255 else f"-{txt}"

  # CDNA VOP2 aliases: new opcode name -> old name expected by LLVM tests
  _CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'}

  def _disasm_vop3a(inst) -> str:
    """Disassemble a CDNA VOP3A instruction.

    Opcode ranges:
      - 0-255: promoted VOPC, printed with `_e64` and an SGPR-pair destination
      - 256-319: promoted VOP2, printed with `_e64`
      - 320-511: promoted VOP1, printed with `_e64`
      - 512+: native VOP3, printed without a suffix
    """
    op_val = inst._values.get('op', 0)  # raw opcode value, not the enum value
    if hasattr(op_val, 'value'): op_val = op_val.value  # in case it's stored as an enum
    name = inst.op_name.lower() or f'vop3a_op_{op_val}'
    from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
    # `name` is never empty (unknown opcodes get a placeholder), so operand counts/widths always come from the spec helpers
    n = spec_num_srcs(name)
    cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
    orig_name, name = name, _CDNA_VOP3_ALIASES.get(name, name)
    if name != orig_name:
      # Aliased ops print under their legacy name; sources are formatted as single registers
      s0 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, 1)
      s1 = _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, 1)
      s2, dst = "", f"v{inst.vdst}"
    else:
      dregs, r0, r1, r2 = spec_regs(name)
      s0 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0)
      s1 = _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1)
      s2 = _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
      dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
    if op_val >= 512:  # native VOP3 ops
      return f"{name} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{cl}{om}"
    if op_val < 256:  # promoted VOPC
      # VOPC writes a 64-bit SGPR pair; in CDNA VOP3, v_cmpx_ also writes sdst (unlike VOP32, where it writes exec)
      sdst = _fmt_sdst(inst.vdst, 2, cdna=True)
      return f"{name}_e64 {sdst}, {s0}, {s1}{cl}"
    if op_val >= 320:  # promoted VOP1 (320-511)
      if name in ('v_nop', 'v_clrexcp'): return f"{name}_e64"
      return f"{name}_e64 {dst}, {s0}{cl}{om}"
    # promoted VOP2 (256-319)
    if name == 'v_cndmask_b32':
      s2 = _fmt_src(inst.src2, 2, cdna=True)  # src2 is a 64-bit SGPR pair
      return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{cl}{om}"
    if name in ('v_mul_legacy_f32', 'v_mac_f32'):
      return f"{name}_e64 {dst}, {s0}, {s1}{cl}{om}"
    return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{cl}{om}"

  def _disasm_vop3b(inst) -> str:
    """Disassemble a CDNA VOP3B instruction (VOP3 with an additional 64-bit SGPR-pair `sdst`).

    Names containing 'co_' get an `_e64` suffix. The carry-in ops (v_addc/v_subb/v_subbrev) print src2 as an
    SGPR pair.
    """
    op_val = inst._values.get('op', 0)
    if hasattr(op_val, 'value'): op_val = op_val.value
    name = inst.op_name.lower() or f'vop3b_op_{op_val}'
    from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
    n = spec_num_srcs(name)
    dregs, r0, r1, r2 = spec_regs(name)
    s0 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0)
    s1 = _cdna_src(inst, inst.src1, inst.neg&2, n=r1)
    s2 = _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
    dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
    sdst = _fmt_sdst(inst.sdst, 2, cdna=True)  # VOP3B sdst is always a 64-bit SGPR pair
    cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
    if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
      s2 = _fmt_src(inst.src2, 2, cdna=True)  # src2 is the carry-in (64-bit SGPR pair)
      return f"{name}_e64 {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}"
    suf = "_e64" if 'co_' in name else ""
    srcs = f"{s0}, {s1}, {s2}" if n == 3 else f"{s0}, {s1}"
    return f"{name}{suf} {dst}, {sdst}, {srcs}{cl}{om}"

  def _disasm_cdna_vop3p(inst) -> str:
    """Disassemble a CDNA VOP3P instruction, including MFMA/SMFMAC ops with multi-register operands."""
    name, n = inst.op_name.lower(), inst.num_srcs()
    get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc, cdna=True)
    if 'mfma' in name or 'smfmac' in name:
      # Register count for src0/src1, picked from substrings of the opcode name (first match wins)
      if 'iu4' in name: sc = 2
      elif 'iu8' in name or 'i4' in name: sc = 4
      elif 'f16' in name or 'bf16' in name: sc = 8
      else: sc = 4
      src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
    else:
      src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
    opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
    mods = []
    if inst.opsel: mods.append(_fmt_bits("op_sel", inst.opsel, n))
    if opsel_hi != (7 if n == 3 else 3): mods.append(_fmt_bits("op_sel_hi", opsel_hi, n))  # only print non-default op_sel_hi
    if inst.neg: mods.append(_fmt_bits("neg_lo", inst.neg, n))
    if inst.neg_hi: mods.append(_fmt_bits("neg_hi", inst.neg_hi, n))
    if inst.clmp: mods.append("clamp")
    srcs = f"{src0}, {src1}, {src2}" if n == 3 else f"{src0}, {src1}"
    mod_str = ' ' + ' '.join(mods) if mods else ''
    return f"{name} {dst}, {srcs}{mod_str}"

  _SEL = {0: 'BYTE_0', 1: 'BYTE_1', 2: 'BYTE_2', 3: 'BYTE_3', 4: 'WORD_0', 5: 'WORD_1', 6: 'DWORD'}
  _UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
  _DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}

  def _sdwa_src0(v, is_sgpr, sext=0, neg=0, abs_=0):
    """Format SDWA src0 with sext/abs/neg modifiers.

    When is_sgpr (the s0 bit) is 0, `v` is a VGPR number; otherwise it is a normally-encoded scalar/constant source.
    """
    txt = decode_src(v, cdna=True) if is_sgpr else f"v{v}"
    if sext: txt = f"sext({txt})"
    if abs_: txt = f"|{txt}|"
    return f"-{txt}" if neg else txt

  def _sdwa_vsrc1(v, sext=0, neg=0, abs_=0):
    """Format SDWA vsrc1, which is always a raw VGPR number (stored in the vop_op field for VOP2/VOPC SDWA)."""
    return _sdwa_src0(v, False, sext, neg, abs_)

  _OMOD_SDWA = {0: "", 1: " mul:2", 2: " mul:4", 3: " div:2"}

  def _disasm_sdwa(inst) -> str:
    """Disassemble a CDNA SDWA instruction. vop2_op=63 -> VOP1, vop2_op=62 -> VOPC, vop2_op=0-61 -> VOP2."""
    vop2_op = inst.vop2_op
    src0 = _sdwa_src0(inst.src0, inst.s0, inst.src0_sext, inst.src0_neg, inst.src0_abs)
    clamp = " clamp" if inst.clmp else ""
    omod = _OMOD_SDWA.get(inst.omod, "")
    if vop2_op == 63:  # VOP1
      try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
      except ValueError: name = f"vop1_op_{inst.vop_op}"
      mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}"]
      return f"{name}_sdwa v{inst.vdst}, {src0}{clamp}{omod} " + " ".join(mods)
    src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs)  # vsrc1 is in the vop_op field
    if vop2_op == 62:  # VOPC
      try: name = CDNA_VOPCOp(inst.vdst).name.lower()  # opcode is in the vdst field for VOPC SDWA
      except ValueError: name = f"vopc_op_{inst.vdst}"
      # VOPC SDWA reuses bits 47:40 (the dst_sel/dst_u/clmp/omod fields) as SDST[6:0] plus the SD flag in bit 7.
      # SD=0 selects vcc regardless of SDST; SD=1 selects the SGPR pair starting at SDST.
      sdst_enc = inst.dst_sel | (inst.dst_u << 3) | (inst.clmp << 5) | (inst.omod << 6)
      sdst = _fmt_sdst(sdst_enc & 0x7f, 2, cdna=True) if sdst_enc & 0x80 else "vcc"
      mods = [f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
      return f"{name}_sdwa {sdst}, {src0}, {src1} " + " ".join(mods)
    # VOP2
    try: name = CDNA_VOP2Op(vop2_op).name.lower()
    except ValueError: name = f"vop2_op_{vop2_op}"
    name = _CDNA_DISASM_ALIASES.get(name, name)  # apply aliases (v_fmac -> v_mac, etc.)
    mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
    if name == 'v_cndmask_b32': operands = f"{src0}, {src1}, vcc"  # implicit vcc select
    elif name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'): operands = f"vcc, {src0}, {src1}, vcc"  # carry-out and carry-in
    elif '_co_' in name: operands = f"vcc, {src0}, {src1}"  # carry-out only
    else: operands = f"{src0}, {src1}"
    return f"{name}_sdwa v{inst.vdst}, {operands}{clamp}{omod} " + " ".join(mods)

  def _dpp_src(v, neg=0, abs_=0):
    """Format a DPP VGPR source; values >= 256 are rebased by 256 to the VGPR number. Applies abs then neg."""
    txt = f"v{v - 256}" if v >= 256 else f"v{v}"
    if abs_: txt = f"|{txt}|"
    return f"-{txt}" if neg else txt

  def _cdna_dpp_ctrl(ctrl: int) -> str:
    """Render the DPP dpp_ctrl field in assembler syntax. Unrecognized values fall back to `dpp_ctrl:0x..`."""
    if ctrl < 0x100: return f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]"
    if ctrl < 0x110: return f"row_shl:{ctrl&0xf}"
    if ctrl < 0x120: return f"row_shr:{ctrl&0xf}"
    if ctrl < 0x130: return f"row_ror:{ctrl&0xf}"
    return _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")

  def _disasm_dpp(inst) -> str:
    """Disassemble a CDNA DPP instruction. vop2_op=63 -> VOP1, vop2_op=0-62 -> VOP2."""
    vop2_op = inst.vop2_op
    src0 = _dpp_src(inst.src0, inst.src0_neg, inst.src0_abs)
    # row_mask and bank_mask are always shown; bound_ctrl:0 is shown when the bound_ctrl bit is 1
    mods = [_cdna_dpp_ctrl(inst.dpp_ctrl), f"row_mask:0x{inst.row_mask:x}", f"bank_mask:0x{inst.bank_mask:x}"] + (["bound_ctrl:0"] if inst.bound_ctrl else [])
    if vop2_op == 63:  # VOP1
      try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
      except ValueError: name = f"vop1_op_{inst.vop_op}"
      return f"{name}_dpp v{inst.vdst}, {src0} " + " ".join(mods)
    # VOP2
    try: name = CDNA_VOP2Op(vop2_op).name.lower()
    except ValueError: name = f"vop2_op_{vop2_op}"
    name = _CDNA_DISASM_ALIASES.get(name, name)
    src1 = _dpp_src(inst.vop_op, inst.src1_neg, inst.src1_abs)  # vsrc1 is in the vop_op field
    if name == 'v_cndmask_b32': operands = f"{src0}, {src1}, vcc"
    elif name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'): operands = f"vcc, {src0}, {src1}, vcc"
    elif '_co_' in name: operands = f"vcc, {src0}, {src1}"
    else: operands = f"{src0}, {src1}"
    return f"{name}_dpp v{inst.vdst}, {operands} " + " ".join(mods)

  # Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
  DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
    CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
    CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf,
    VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, SDWA: _disasm_sdwa, DPP: _disasm_dpp})
except ImportError:
  # Best-effort: if the CDNA autogen module (or one of these names) is unavailable, CDNA disassembly is not registered
  pass