assembly/amd: improve tests for asm (#14007)

* assembly/amd: improve tests for asm

* upd

* skip

* tests

* re bug

* more passing

* cleanups

* cdna fixups

* improve tests, better CDNA parsing

* fix CI

* no defs

* simpler

* all pass

* from pdf

* regen
This commit is contained in:
George Hotz
2026-01-04 15:14:08 -08:00
committed by GitHub
parent f550f9204c
commit 404eed6172
9 changed files with 673 additions and 445 deletions

View File

@@ -3,10 +3,11 @@ from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp)
from extra.assembly.amd.autogen.rdna3.enum import BufFmt
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
@@ -17,21 +18,37 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool:
return ((word >> bf.lo) & bf.mask()) == val
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P,
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS,
FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF, SDWA, DPP)
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM]
_CDNA_FORMATS_32 = [SDWA, DPP, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2]
_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes
# CDNA opcode name aliases for disasm (new name -> old name expected by tests)
_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'}
def detect_format(data: bytes) -> type[Inst]:
def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]:
"""Detect instruction format from machine code bytes."""
assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
word = int.from_bytes(data[:4], 'little')
# Check 64-bit formats first (bits[31:30] == 0b11)
if arch == "cdna":
if (word >> 30) == 0b11:
for cls in _CDNA_FORMATS_64:
if _matches_encoding(word, cls):
return VOP3B if cls is VOP3A and ((word >> 16) & 0x3ff) in _CDNA_VOP3B_OPS else cls
raise ValueError(f"unknown CDNA 64-bit format word={word:#010x}")
for cls in _CDNA_FORMATS_32:
if _matches_encoding(word, cls): return cls
raise ValueError(f"unknown CDNA 32-bit format word={word:#010x}")
# RDNA (default)
if (word >> 30) == 0b11:
for cls in _FORMATS_64:
for cls in _RDNA_FORMATS_64:
if _matches_encoding(word, cls):
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
raise ValueError(f"unknown 64-bit format word={word:#010x}")
# 32-bit formats
for cls in _FORMATS_32:
for cls in _RDNA_FORMATS_32:
if _matches_encoding(word, cls): return cls
raise ValueError(f"unknown 32-bit format word={word:#010x}")
@@ -44,6 +61,11 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup
BUF_FMT = {e.name: e.value for e in BufFmt}
def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]
parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')]
return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}') if len(parts) == 2 else None
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
@@ -54,22 +76,28 @@ MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_T
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]"
def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n)
def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n)
def _areg(b: int, n: int = 1) -> str: return _reg("a", b, n) # accumulator registers for GFX90a
def _ttmp(b: int, n: int = 1) -> str: return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
def _sreg_or_ttmp(b: int, n: int = 1) -> str: return _ttmp(b, n) or _sreg(b, n)
def _fmt_sdst(v: int, n: int = 1) -> str:
if v == 124: return "null"
def _fmt_sdst(v: int, n: int = 1, cdna: bool = False) -> str:
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA, SPECIAL_GPRS_CDNA
if t := _ttmp(v, n): return t
if n > 1: return SPECIAL_PAIRS.get(v) or _sreg(v, n)
return SPECIAL_GPRS.get(v, f"s{v}")
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
gprs = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
if n > 1: return pairs.get(v) or gprs.get(v) or _sreg(v, n) # also check gprs for null/m0
return gprs.get(v, f"s{v}")
def _fmt_src(v: int, n: int = 1) -> str:
if n == 1: return decode_src(v)
def _fmt_src(v: int, n: int = 1, cdna: bool = False) -> str:
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA
if n == 1: return decode_src(v, cdna)
if v >= 256: return _vreg(v - 256, n)
if v <= 105: return _sreg(v, n)
if n == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v]
if v <= 101: return _sreg(v, n) # s0-s101 can be pairs, but 102+ are special on CDNA
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
if n == 2 and v in pairs: return pairs[v]
if v <= 105: return _sreg(v, n) # s102-s105 regular pairs for RDNA
if t := _ttmp(v, n): return t
return decode_src(v)
return decode_src(v, cdna)
def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str:
return f"v{(v - base) & 0x7f}.{'h' if v >= hi_thresh else 'l'}"
@@ -106,46 +134,72 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
# ═══════════════════════════════════════════════════════════════════════════════
def _disasm_vop1(inst: VOP1) -> str:
name = inst.op_name.lower()
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
name, cdna = inst.op_name.lower() or f'vop1_op_{inst.op}', _is_cdna(inst)
suf = "" if cdna else "_e32"
if name in ('v_nop', 'v_pipeflush', 'v_clrexcp'): return name # no operands
if 'readfirstlane' in name:
src = f"v{inst.src0 - 256}" if inst.src0 >= 256 else decode_src(inst.src0, cdna)
return f"{name} {_fmt_sdst(inst.vdst, 1, cdna)}, {src}"
# 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN)
parts = name.split('_')
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
is_16d = not cdna and (any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
return f"{name}_e32 {dst}, {src}"
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
return f"{name}{suf} {dst}, {src}"
_VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only
_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out
def _disasm_vop2(inst: VOP2) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases
suf = "" if cdna or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16) else "_e32"
lit = getattr(inst, '_literal', None)
is16 = not cdna and inst.is_16bit()
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
# fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1
if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
if 'fmamk' in name or 'madmk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
vcc = "vcc" if cdna else "vcc_lo"
# CDNA carry ops output vcc after vdst
if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}"
if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
def _disasm_vopc(inst: VOPC) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if cdna:
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0))
return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}"
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna)
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else f"v{inst.vsrc1}"
return f"{name} vcc, {s0}, {s1}" # CDNA VOPC always outputs vcc
# RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
has_vcc = 'cmpx' not in name
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
return f"{name}_e32 vcc_lo, {s0}, {s1}" if has_vcc else f"{name}_e32 {s0}, {s1}"
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
# CDNA uses name-based matching since opcode values differ from RDNA
_CDNA_NO_ARG_SOPP = {'s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_nop', 's_sethalt', 's_sleep',
's_setprio', 's_trap', 's_incperflevel', 's_decperflevel', 's_sendmsg', 's_sendmsghalt'}
def _disasm_sopp(inst: SOPP) -> str:
name = inst.op_name.lower()
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if cdna:
# CDNA: use name-based matching
if name == 's_endpgm': return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
if name in ('s_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata'): return name
if name == 's_waitcnt':
vm, lgkm, exp = inst.simm16 & 0xf, (inst.simm16 >> 8) & 0x3f, (inst.simm16 >> 4) & 0x7
p = [f"vmcnt({vm})" if vm != 0xf else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
return f"{name} 0x{inst.simm16:x}" if inst.simm16 else name
# RDNA
if inst.op in NO_ARG_SOPP: return name
if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
if inst.op == SOPPOp.S_WAITCNT:
@@ -161,64 +215,98 @@ def _disasm_sopp(inst: SOPP) -> str:
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
def _disasm_smem(inst: SMEM) -> str:
name = inst.op_name.lower()
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
# GFX9 SMEM: soe and imm bits determine offset interpretation
# soe=1, imm=1: soffset is SGPR, offset is immediate (both used)
# soe=0, imm=1: offset is immediate
# soe=0, imm=0: offset field is SGPR encoding (0-255)
soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1)
if cdna:
if soe and imm:
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" # SGPR + immediate
elif imm:
off_s = f"0x{inst.offset:x}" # Immediate offset only
elif inst.offset < 256:
off_s = decode_src(inst.offset, cdna) # SGPR encoding in offset field
else:
off_s = decode_src(inst.soffset, cdna)
elif inst.offset and inst.soffset != 124:
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}"
elif inst.offset:
off_s = f"0x{inst.offset:x}"
else:
off_s = decode_src(inst.soffset, cdna)
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
# s_buffer_* instructions use 4 SGPRs for sbase (buffer descriptor)
is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2
sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc"))
def _disasm_flat(inst: FLAT) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}"
else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
w = inst.dst_regs() * (2 if '_x2' in name else 1) * (2 if 'cmpswap' in name else 1)
off_s = f" offset:{off_val}" if off_val else "" # Omit offset:0
if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}" # GFX9: sc0->glc, nt->slc
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
# saddr
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
elif inst.saddr == 124: saddr_s = ", off"
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}"
elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
elif inst.saddr in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[inst.saddr]}"
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}"
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr, cdna)}"
# addtid: no addr
if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
# addr width
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
if 'addtid' in name: return f"{instr} {'a' if acc else 'v'}{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
# addr width: CDNA flat always uses 2 VGPRs (64-bit), scratch uses 1, RDNA uses 2 only when no saddr
if cdna:
addr_w = 1 if seg == 'scratch' else 2 # CDNA: flat/global always 64-bit addr
else:
addr_w = 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
glc_or_sc0 = inst.sc0 if cdna else inst.glc
if 'atomic' in name:
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
def _disasm_ds(inst: DS) -> str:
op, name = inst.op, inst.op_name.lower()
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
rp = 'a' if acc else 'v' # register prefix for single regs
gds = " gds" if inst.gds else ""
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "")
w = inst.dst_regs()
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), f"v{inst.addr}"
if op == DSOp.DS_NOP: return name
if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
if 'gws_' in name: return f"{name} {addr}{off}{gds}"
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}"
if 'gs_reg' in name: return f"{name} {_vreg(inst.vdst, 2)}, v{inst.data0}{off}{gds}"
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{off}{gds}"
if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {rp}{inst.data0}{off}{gds}"
if '2addr' in name:
if 'load' in name: return f"{name} {_vreg(inst.vdst, w*2)}, {addr}{off2}{gds}"
if 'load' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
return f"{name} {_vreg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
if 'load' in name: return f"{name} v{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
if 'load' in name: return f"{name} {rp}{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
if 'store' in name and not _has(name, 'cmp', 'xchg'):
return f"{name} v{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} v{inst.vdst}, {addr}{off}{gds}"
if 'permute' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}{off}{gds}"
if 'condxchg' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, {_vreg(inst.data0, 2)}{off}{gds}"
return f"{name} {rp}{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} {rp}{inst.vdst}, {addr}{off}{gds}"
if 'permute' in name: return f"{name} {rp}{inst.vdst}, {addr}, {rp}{inst.data0}{off}{gds}"
if 'condxchg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {addr}, {reg_fn(inst.data0, 2)}{off}{gds}"
if _has(name, 'cmpstore', 'mskor', 'wrap'):
return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}"
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
@@ -318,6 +406,8 @@ def _disasm_vop3p(inst: VOP3P) -> str:
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
@@ -326,9 +416,27 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str:
if hasattr(inst, 'tfe') and inst.tfe: w += 1
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c]
else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF)
if is_mtbuf:
dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7
if acc: # GFX90a accumulator style: show dfmt/nfmt as numbers
fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt}," # double space before dfmt per LLVM format
elif not cdna: # RDNA style: show combined format number
fmt_s = f" format:{inst.format}" if inst.format else ""
else: # CDNA: show format:[BUF_DATA_FORMAT_X] or format:[BUF_NUM_FORMAT_X]
dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15']
nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT']
if dfmt == 1 and nfmt == 0: fmt_s = "" # default, no format shown
elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]" # only dfmt differs
elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # only nfmt differs
else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # both differ
else:
fmt_s = ""
if cdna: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")] if c]
else: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
soffset_s = decode_src(inst.soffset, cdna)
if cdna and not acc and is_mtbuf: return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{' ' + ' '.join(mods) if mods else ''}"
return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{' ' + ' '.join(mods) if mods else ''}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
"""Calculate vaddr register count for MIMG sample/gather operations."""
@@ -377,21 +485,23 @@ def _disasm_mimg(inst: MIMG) -> str:
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
def _disasm_sop1(inst: SOP1) -> str:
op, name = inst.op, inst.op_name.lower()
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
if not _is_cdna(inst):
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
if not cdna:
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {src}"
def _disasm_sop2(inst: SOP2) -> str:
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
cdna = _is_cdna(inst)
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)}"
def _disasm_sopc(inst: SOPC) -> str:
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))
cdna = _is_cdna(inst)
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)
return f"{inst.op_name.lower()} {s0}, {s1}"
def _disasm_sopk(inst: SOPK) -> str:
@@ -405,10 +515,10 @@ def _disasm_sopk(inst: SOPK) -> str:
if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}"
if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
@@ -464,11 +574,54 @@ def _parse_ops(s: str) -> list[str]:
return ops
def _extract(text: str, pat: str, flags=re.I):
if m := re.search(pat, text, flags): return m, text[:m.start()] + text[m.end():]
if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():]
return None, text
# Instruction aliases: LLVM uses different names for some instructions
_ALIASES = {
'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64',
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32',
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32',
# SMEM aliases (dword -> b32, dwordx2 -> b64, etc.)
's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128',
's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512',
's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64',
's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256',
's_buffer_load_dwordx16': 's_buffer_load_b512',
# VOP3 aliases
'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16',
'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32',
# VINTERP aliases
'v_interp_p2_new_f32': 'v_interp_p2_f32',
# SOP1 aliases
's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64',
's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64',
's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64',
's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64',
's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64',
's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64',
's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64',
's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64',
# SOP2 aliases
's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64',
's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64',
# VOP2 aliases
'v_dot2c_f32_f16': 'v_dot2acc_f32_f16',
# More VOP3 aliases
'v_fma_legacy_f32': 'v_fma_dx9_zero_f32',
}
def _apply_alias(text: str) -> str:
mn = text.split()[0].lower() if ' ' in text else text.lower().rstrip('_')
# Try exact match first, then strip _e32/_e64 suffix
for m in (mn, mn.removesuffix('_e32'), mn.removesuffix('_e64')):
if m in _ALIASES: return _ALIASES[m] + text[len(m):]
return text
def get_dsl(text: str) -> str:
text, kw = text.strip(), []
text, kw = _apply_alias(text.strip()), []
# Extract modifiers
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
@@ -484,6 +637,11 @@ def get_dsl(text: str) -> str:
m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None
m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None
m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None
m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
if waitexp: kw.append(f'waitexp={waitexp}')
@@ -530,9 +688,30 @@ def get_dsl(text: str) -> str:
if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
# Buffer
if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off':
return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})"
# Buffer (MUBUF/MTBUF) instructions
if mn.startswith(('buffer_', 'tbuffer_')):
is_tbuf = mn.startswith('tbuffer_')
# Parse format value for tbuffer
fmt_num = None
if fmt_val is not None:
if fmt_val.isdigit(): fmt_num = int(fmt_val)
else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
# Handle special no-arg buffer ops
if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()"
# Build modifiers string
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "",
", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""])
if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods
# Determine vaddr value (v[0] for 'off', actual register otherwise)
vaddr_idx = 1
if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]"
else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]"
# srsrc and soffset indices depend on whether vaddr is 'off'
srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2)
srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]"
soff_val = args[soff_idx] if len(args) > soff_idx else "0"
# soffset: integers are inline constants, don't wrap in RawImm
return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null
def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
@@ -582,6 +761,15 @@ def get_dsl(text: str) -> str:
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
# v_cmp_*_e64 has SGPR destination in vdst field - encode as RawImm
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
if mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64') and len(args) >= 1:
dst = ops[0].strip().lower()
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
fn = mn.replace('.', '_')
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
@@ -629,31 +817,76 @@ def asm(text: str) -> Inst:
try:
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op)
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)
def _cdna_src(inst, v, neg, abs_=0, n=1):
s = inst.lit(v) if v == 255 else _fmt_src(v, n)
s = inst.lit(v) if v == 255 else _fmt_src(v, n, cdna=True)
if abs_: s = f"|{s}|"
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
def _disasm_vop3a(inst) -> str:
name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod)
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
suf = "_e64" if inst.op.value < 512 else ""
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}")
# CDNA VOP2 aliases: new opcode name -> old name expected by LLVM tests
_CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'}
def _disasm_vop3b(inst) -> str:
name, n = inst.op_name.lower(), inst.num_srcs()
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4)
dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else ""
def _disasm_vop3a(inst) -> str:
op_val = inst._values.get('op', 0) # get raw opcode value, not enum value
if hasattr(op_val, 'value'): op_val = op_val.value # in case it's stored as enum
name = inst.op_name.lower() or f'vop3a_op_{op_val}'
from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
n = spec_num_srcs(name) if name else inst.num_srcs()
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}"
orig_name = name
name = _CDNA_VOP3_ALIASES.get(name, name) # apply CDNA aliases
# For aliased ops, recalculate sources without 64-bit assumption
if name != orig_name:
s0, s1 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, 1), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, 1)
s2 = ""
dst = f"v{inst.vdst}"
else:
dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2))
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
# True VOP3 instructions (512+) - 3-source ops
if op_val >= 512:
return f"{name} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{cl}{om}"
# VOPC (0-255): writes to SGPR pair, VOP2 (256-319): 2-3 src, VOP1 (320-511): 1 src
if op_val < 256:
sdst = _fmt_sdst(inst.vdst, 2, cdna=True) # VOPC writes to 64-bit SGPR pair
# v_cmpx_ also writes to sdst in CDNA VOP3 (unlike VOP32 where it writes to exec)
return f"{name}_e64 {sdst}, {s0}, {s1}{cl}"
if 320 <= op_val < 512: # VOP1 promoted
if name in ('v_nop', 'v_clrexcp'): return f"{name}_e64"
return f"{name}_e64 {dst}, {s0}{cl}{om}"
# VOP2 promoted (256-319)
if name == 'v_cndmask_b32':
s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is 64-bit SGPR pair
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{cl}{om}"
if name in ('v_mul_legacy_f32', 'v_mac_f32'):
return f"{name}_e64 {dst}, {s0}, {s1}{cl}{om}"
suf = "_e64" if op_val < 512 else ""
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}"
# GFX9-specific VOP3B opcodes not in CDNA enum
def _disasm_vop3b(inst) -> str:
op_val = inst._values.get('op', 0)
if hasattr(op_val, 'value'): op_val = op_val.value
name = inst.op_name.lower() or f'vop3b_op_{op_val}'
from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
n = spec_num_srcs(name) if name else inst.num_srcs()
dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2))
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
sdst = _fmt_sdst(inst.sdst, 2, cdna=True) # VOP3B sdst is always 64-bit SGPR pair
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
# Carry ops need special handling
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is carry-in (64-bit SGPR pair)
return f"{name}_e64 {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}"
suf = "_e64" if 'co_' in name else ""
return f"{name}{suf} {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {sdst}, {s0}, {s1}{cl}{om}"
def _disasm_cdna_vop3p(inst) -> str:
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc)
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc, cdna=True)
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
@@ -665,20 +898,93 @@ try:
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
def _sdwa_src0(v, is_sgpr, sext=0, neg=0, abs_=0):
# s0=0: VGPR (v is VGPR number), s0=1: SGPR/constant (v is encoded like normal src)
s = decode_src(v, cdna=True) if is_sgpr else f"v{v}"
if sext: s = f"sext({s})"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _sdwa_vsrc1(v, sext=0, neg=0, abs_=0):
# For VOP2 SDWA, vsrc1 is in vop_op field as raw VGPR number
s = f"v{v}"
if sext: s = f"sext({s})"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
_OMOD_SDWA = {0: "", 1: " mul:2", 2: " mul:4", 3: " div:2"}
def _disasm_sdwa(inst) -> str:
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0)
mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6]
return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "")
# SDWA format: vop2_op=63 -> VOP1, vop2_op=62 -> VOPC, vop2_op=0-61 -> VOP2
vop2_op = inst.vop2_op
src0 = _sdwa_src0(inst.src0, inst.s0, inst.src0_sext, inst.src0_neg, inst.src0_abs)
clamp = " clamp" if inst.clmp else ""
omod = _OMOD_SDWA.get(inst.omod, "")
if vop2_op == 63: # VOP1
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
dst = f"v{inst.vdst}"
mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}"]
return f"{name}_sdwa {dst}, {src0}{clamp}{omod} " + " ".join(mods)
elif vop2_op == 62: # VOPC
try: name = CDNA_VOPCOp(inst.vdst).name.lower() # opcode is in vdst field for VOPC SDWA
except ValueError: name = f"vopc_op_{inst.vdst}"
src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
# VOPC SDWA: dst encoded in byte 5 (bits 47:40): 0=vcc, 128+n=s[n:n+1]
sdst_enc = inst.dst_sel | (inst.dst_u << 3) | (inst.clmp << 5) | (inst.omod << 6)
if sdst_enc == 0:
sdst = "vcc"
else:
sdst_val = sdst_enc - 128 if sdst_enc >= 128 else sdst_enc
sdst = _fmt_sdst(sdst_val, 2, cdna=True)
mods = [f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
return f"{name}_sdwa {sdst}, {src0}, {src1} " + " ".join(mods)
else: # VOP2
try: name = CDNA_VOP2Op(vop2_op).name.lower()
except ValueError: name = f"vop2_op_{vop2_op}"
name = _CDNA_DISASM_ALIASES.get(name, name) # apply aliases (v_fmac -> v_mac, etc.)
dst = f"v{inst.vdst}"
src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
# v_cndmask_b32 needs vcc as third operand
if name == 'v_cndmask_b32':
return f"{name}_sdwa {dst}, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
# Carry ops need vcc - v_addc/subb also need vcc as carry-in
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
return f"{name}_sdwa {dst}, vcc, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
if '_co_' in name:
return f"{name}_sdwa {dst}, vcc, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
return f"{name}_sdwa {dst}, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
def _dpp_src(v, neg=0, abs_=0):
s = f"v{v}" if v < 256 else f"v{v - 256}"
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
def _disasm_dpp(inst) -> str:
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl
# DPP format: vop2_op=63 -> VOP1, vop2_op=0-62 -> VOP2
vop2_op = inst.vop2_op
ctrl = inst.dpp_ctrl
dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl]
return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods)
src0 = _dpp_src(inst.src0, inst.src0_neg, inst.src0_abs)
# DPP modifiers: row_mask and bank_mask always shown, bound_ctrl:0 when bit=1
mods = [dpp, f"row_mask:0x{inst.row_mask:x}", f"bank_mask:0x{inst.bank_mask:x}"] + (["bound_ctrl:0"] if inst.bound_ctrl else [])
if vop2_op == 63: # VOP1
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
return f"{name}_dpp v{inst.vdst}, {src0} " + " ".join(mods)
else: # VOP2
try: name = CDNA_VOP2Op(vop2_op).name.lower()
except ValueError: name = f"vop2_op_{vop2_op}"
name = _CDNA_DISASM_ALIASES.get(name, name)
src1 = _dpp_src(inst.vop_op, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
if name == 'v_cndmask_b32':
return f"{name}_dpp v{inst.vdst}, {src0}, {src1}, vcc " + " ".join(mods)
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1}, vcc " + " ".join(mods)
if '_co_' in name:
return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1} " + " ".join(mods)
return f"{name}_dpp v{inst.vdst}, {src0}, {src1} " + " ".join(mods)
# Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,

View File

@@ -56,6 +56,7 @@ class MTBUF(Inst64):
srsrc:SGPRField = bits[52:48]
soffset:SSrc = bits[63:56]
offset:Imm = bits[11:0]
format = bits[25:19]
offen = bits[12]
idxen = bits[13]
sc1 = bits[53]
@@ -124,7 +125,7 @@ class SMEM(Inst64):
sbase:SGPRField = bits[5:0]
soffset:SSrc = bits[63:57]
offset:Imm = bits[52:32]
glc = bits[14]
glc = bits[16]
soe = bits[14]
nv = bits[15]
imm = bits[17]

View File

@@ -1594,3 +1594,68 @@ class VOPDOp(IntEnum):
V_DUAL_ADD_NC_U32 = 16
V_DUAL_LSHLREV_B32 = 17
V_DUAL_AND_B32 = 18
class BufFmt(IntEnum):
BUF_FMT_8_UNORM = 1
BUF_FMT_8_SNORM = 2
BUF_FMT_8_USCALED = 3
BUF_FMT_8_SSCALED = 4
BUF_FMT_8_UINT = 5
BUF_FMT_8_SINT = 6
BUF_FMT_16_UNORM = 7
BUF_FMT_16_SNORM = 8
BUF_FMT_16_USCALED = 9
BUF_FMT_16_SSCALED = 10
BUF_FMT_16_UINT = 11
BUF_FMT_16_SINT = 12
BUF_FMT_16_FLOAT = 13
BUF_FMT_8_8_UNORM = 14
BUF_FMT_8_8_SNORM = 15
BUF_FMT_8_8_USCALED = 16
BUF_FMT_8_8_SSCALED = 17
BUF_FMT_8_8_UINT = 18
BUF_FMT_8_8_SINT = 19
BUF_FMT_32_UINT = 20
BUF_FMT_32_SINT = 21
BUF_FMT_32_FLOAT = 22
BUF_FMT_16_16_UNORM = 23
BUF_FMT_16_16_SNORM = 24
BUF_FMT_16_16_USCALED = 25
BUF_FMT_16_16_SSCALED = 26
BUF_FMT_16_16_UINT = 27
BUF_FMT_16_16_SINT = 28
BUF_FMT_16_16_FLOAT = 29
BUF_FMT_10_11_11_FLOAT = 30
BUF_FMT_11_11_10_FLOAT = 31
BUF_FMT_10_10_10_2_UNORM = 32
BUF_FMT_10_10_10_2_SNORM = 33
BUF_FMT_10_10_10_2_UINT = 34
BUF_FMT_10_10_10_2_SINT = 35
BUF_FMT_2_10_10_10_UNORM = 36
BUF_FMT_2_10_10_10_SNORM = 37
BUF_FMT_2_10_10_10_USCALED = 38
BUF_FMT_2_10_10_10_SSCALED = 39
BUF_FMT_2_10_10_10_UINT = 40
BUF_FMT_2_10_10_10_SINT = 41
BUF_FMT_8_8_8_8_UNORM = 42
BUF_FMT_8_8_8_8_SNORM = 43
BUF_FMT_8_8_8_8_USCALED = 44
BUF_FMT_8_8_8_8_SSCALED = 45
BUF_FMT_8_8_8_8_UINT = 46
BUF_FMT_8_8_8_8_SINT = 47
BUF_FMT_32_32_UINT = 48
BUF_FMT_32_32_SINT = 49
BUF_FMT_32_32_FLOAT = 50
BUF_FMT_16_16_16_16_UNORM = 51
BUF_FMT_16_16_16_16_SNORM = 52
BUF_FMT_16_16_16_16_USCALED = 53
BUF_FMT_16_16_16_16_SSCALED = 54
BUF_FMT_16_16_16_16_UINT = 55
BUF_FMT_16_16_16_16_SINT = 56
BUF_FMT_16_16_16_16_FLOAT = 57
BUF_FMT_32_32_32_UINT = 58
BUF_FMT_32_32_32_SINT = 59
BUF_FMT_32_32_32_FLOAT = 60
BUF_FMT_32_32_32_32_UINT = 61
BUF_FMT_32_32_32_32_SINT = 62
BUF_FMT_32_32_32_32_FLOAT = 63

View File

@@ -1627,3 +1627,52 @@ class VSCRATCHOp(IntEnum):
SCRATCH_STORE_D16_HI_B16 = 37
SCRATCH_LOAD_BLOCK = 83
SCRATCH_STORE_BLOCK = 84
class BufFmt(IntEnum):
BUF_FMT_8_UNORM = 1
BUF_FMT_8_SNORM = 2
BUF_FMT_8_USCALED = 3
BUF_FMT_8_SSCALED = 4
BUF_FMT_8_UINT = 5
BUF_FMT_8_SINT = 6
BUF_FMT_16_UNORM = 7
BUF_FMT_16_SNORM = 8
BUF_FMT_16_USCALED = 9
BUF_FMT_16_SSCALED = 10
BUF_FMT_16_UINT = 11
BUF_FMT_16_SINT = 12
BUF_FMT_16_FLOAT = 13
BUF_FMT_8_8_UNORM = 14
BUF_FMT_8_8_SNORM = 15
BUF_FMT_8_8_USCALED = 16
BUF_FMT_8_8_SSCALED = 17
BUF_FMT_8_8_UINT = 18
BUF_FMT_8_8_SINT = 19
BUF_FMT_32_UINT = 20
BUF_FMT_32_SINT = 21
BUF_FMT_32_FLOAT = 22
BUF_FMT_16_16_UNORM = 23
BUF_FMT_10_10_10_2_UNORM = 32
BUF_FMT_10_10_10_2_SNORM = 33
BUF_FMT_10_10_10_2_UINT = 34
BUF_FMT_10_10_10_2_SINT = 35
BUF_FMT_2_10_10_10_UNORM = 36
BUF_FMT_2_10_10_10_SNORM = 37
BUF_FMT_2_10_10_10_USCALED = 38
BUF_FMT_2_10_10_10_SSCALED = 39
BUF_FMT_2_10_10_10_UINT = 40
BUF_FMT_2_10_10_10_SINT = 41
BUF_FMT_8_8_8_8_UNORM = 42
BUF_FMT_8_8_8_8_SNORM = 43
BUF_FMT_8_8_8_8_USCALED = 44
BUF_FMT_8_8_8_8_SSCALED = 45
BUF_FMT_8_8_8_8_UINT = 46
BUF_FMT_8_8_8_8_SINT = 47
BUF_FMT_32_32_UINT = 48
BUF_FMT_32_32_SINT = 49
BUF_FMT_32_32_FLOAT = 50
BUF_FMT_16_16_16_16_UNORM = 51
BUF_FMT_16_16_16_16_SNORM = 52
BUF_FMT_16_16_16_16_USCALED = 53
BUF_FMT_16_16_16_16_SSCALED = 54
BUF_FMT_16_16_16_16_UINT = 55

View File

@@ -7,6 +7,7 @@ from functools import cache
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
from extra.assembly.amd.autogen.cdna.enum import VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op
# Common masks and bit conversion functions
MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1
@@ -46,12 +47,14 @@ def _i64(f):
# Instruction spec - register counts and dtypes derived from instruction names
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1,
'DWORD': 1, 'DWORDX2': 2, 'DWORDX3': 3, 'DWORDX4': 4, 'DWORDX8': 8, 'DWORDX16': 16,
'BYTE': 1, 'SHORT': 1, 'UBYTE': 1, 'SBYTE': 1, 'USHORT': 1, 'SSHORT': 1}
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512)|DWORD(?:X(?:2|3|4|8|16))?|[US]?BYTE|[US]?SHORT)$')
@cache
def _suffix(name: str) -> tuple[str | None, str | None]:
name = name.upper()
@@ -242,7 +245,11 @@ def unwrap(val) -> int:
FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
FLOAT_DEC = {v: str(k) for k, v in FLOAT_ENC.items()}
SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"}
SPECIAL_GPRS_CDNA = {102: "flat_scratch_lo", 103: "flat_scratch_hi", 104: "xnack_mask_lo", 105: "xnack_mask_hi",
106: "vcc_lo", 107: "vcc_hi", 124: "m0", 126: "exec_lo", 127: "exec_hi",
251: "src_vccz", 252: "src_execz", 253: "src_scc", 254: "src_lds_direct"}
SPECIAL_PAIRS = {106: "vcc", 126: "exec"}
SPECIAL_PAIRS_CDNA = {102: "flat_scratch", 104: "xnack_mask", 106: "vcc", 126: "exec"}
SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'}
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'}
@@ -259,9 +266,10 @@ def encode_src(val) -> int:
if isinstance(val, int): return 128 + val if 0 <= val <= 64 else 192 - val if -16 <= val <= -1 else 255
return 255
def decode_src(val: int) -> str:
def decode_src(val: int, cdna: bool = False) -> str:
special = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
if val in special: return special[val]
if val <= 105: return f"s{val}"
if val in SPECIAL_GPRS: return SPECIAL_GPRS[val]
if val in FLOAT_DEC: return FLOAT_DEC[val]
if 108 <= val <= 123: return f"ttmp{val - 108}"
if 128 <= val <= 192: return str(val - 128)
@@ -385,14 +393,14 @@ class Inst:
if name in SRC_FIELDS: self._encode_src(name, val)
elif name in RAW_FIELDS: self._encode_raw(name, val)
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = _encode_reg(val) // 4
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
self._precompute_fields()
def _encode_field(self, name: str, val) -> int:
if isinstance(val, RawImm): return val.val
if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special regs like VCC_LO
if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val
if name in {'srsrc', 'ssamp'}: return _encode_reg(val) // 4 if isinstance(val, Reg) else val
if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val
if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val
if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val)
@@ -506,7 +514,7 @@ class Inst:
lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
s = f"0x{lit32:x}"
else:
s = decode_src(v)
s = decode_src(v, 'cdna' in self.__class__.__module__)
return f"-{s}" if neg else s
def __eq__(self, other):
@@ -532,21 +540,30 @@ class Inst:
elif hasattr(val, 'name'): self.op = val
else:
cls_name = self.__class__.__name__
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
if cls_name == 'VOP3':
try:
if val < 256: self.op = VOPCOp(val)
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
else: self.op = VOP3Op(val)
except ValueError: self.op = val
# Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
is_cdna = cls_name in ('VOP3A', 'VOP3B')
# Try marker enum first (VOP3AOp, VOP3BOp, etc.)
marker = self._fields['op'].marker if 'op' in self._fields else None
if marker and issubclass(marker, IntEnum):
try: self.op = marker(val)
except ValueError: self.op = val
elif cls_name in self._enum_map:
try: self.op = self._enum_map[cls_name](val)
except ValueError: self.op = val
else: self.op = val
# Fallback for promoted instructions when marker lookup failed
if not hasattr(self.op, 'name') and cls_name in ('VOP3', 'VOP3A', 'VOP3B') and isinstance(val, int):
if val < 256:
try: self.op = VOPCOp(val)
except ValueError: pass
elif is_cdna and 256 <= val < 512:
try: self.op = (CDNA_VOP1Op(val - 320) if val >= 320 else CDNA_VOP2Op(val - 256))
except ValueError: pass
elif val in self._VOP3SD_OPS and not is_cdna:
try: self.op = VOP3SDOp(val)
except ValueError: pass
elif 256 <= val < 512 and not is_cdna:
try: self.op = VOP1Op(val - 384) if val >= 384 else VOP2Op(val - 256)
except ValueError: pass
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
self._spec_regs = spec_regs(self.op_name)
self._spec_dtype = spec_dtype(self.op_name)

View File

@@ -291,14 +291,16 @@ def exec_vop(st: WaveState, inst: Inst, V: list, lane: int) -> None:
extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {}
result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst, **extra_kwargs)
# Check if this is a VOPC instruction (either standalone VOPC or VOP3 with VOPC opcode)
is_vopc = isinstance(inst.op, VOPCOp) or (isinstance(inst, VOP3) and inst.op.value < 256)
if 'VCC' in result:
if isinstance(inst, VOP3SD): st.pend_sgpr_lane(inst.sdst, lane, (result['VCC'] >> lane) & 1)
else: st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC'] >> lane) & 1)
if 'EXEC' in result:
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC'] >> lane) & 1)
elif isinstance(inst.op, VOPCOp):
elif is_vopc:
st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1)
if not isinstance(inst.op, VOPCOp):
if not is_vopc:
d0_val = result['D0']
if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)

View File

@@ -185,8 +185,8 @@ def _parse_single_pdf(url: str):
break
formats[fmt_name] = fields
# Fix known PDF errors
if 'SMEM' in formats:
# Fix known PDF errors (RDNA-specific SMEM bit positions)
if 'SMEM' in formats and not is_cdna:
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
for n, h, l, e, t in formats['SMEM']]
# RDNA4: VFLAT/VGLOBAL/VSCRATCH OP field is [20:14] not [20:13] (PDF documentation error)
@@ -209,6 +209,11 @@ def _parse_single_pdf(url: str):
if 'FLATOp' in enums:
for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items():
assert k not in enums['FLATOp']; enums['FLATOp'][k] = v
# CDNA MTBUF: PDF is missing the FORMAT field (bits[25:19]) which is required for tbuffer_* instructions
if is_cdna and 'MTBUF' in formats:
field_names = {f[0] for f in formats['MTBUF']}
if 'FORMAT' not in field_names:
formats['MTBUF'].append(('FORMAT', 25, 19, None, None))
# CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding
if is_cdna:
if 'SDWA' in formats:
@@ -229,7 +234,20 @@ def _parse_single_pdf(url: str):
snippet = all_text[start:end].strip()
if pseudocode := _extract_pseudocode(snippet): raw_pseudocode[(name, opcode)] = pseudocode
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna}
# Extract unified buffer format table (RDNA only, for MTBUF format field)
buf_fmt = {}
if not is_cdna:
for i in range(total_pages):
for t in pdf.tables(i):
if t and len(t) > 2 and t[0] and '#' in str(t[0][0]) and 'Format' in str(t[0]):
for row in t[1:]:
for j in range(0, len(row) - 1, 3): # table has 3-column groups: #, Format, (empty)
if row[j] and row[j].isdigit() and row[j+1] and re.match(r'^[\d_]+_(UNORM|SNORM|USCALED|SSCALED|UINT|SINT|FLOAT)$', row[j+1]):
buf_fmt[int(row[j])] = row[j+1]
if buf_fmt: break
if buf_fmt: break
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna, "buf_fmt": buf_fmt}
def _extract_pseudocode(text: str) -> str | None:
"""Extract pseudocode from an instruction description snippet."""
@@ -258,7 +276,7 @@ def _extract_pseudocode(text: str) -> str | None:
def _merge_results(results: list[dict]) -> dict:
"""Merge multiple PDF parse results into a superset."""
merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "pseudocode": {}, "is_cdna": False}
merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "pseudocode": {}, "is_cdna": False, "buf_fmt": {}}
for r in results:
merged["doc_names"].append(r["doc_name"])
merged["is_cdna"] = merged["is_cdna"] or r["is_cdna"]
@@ -279,17 +297,20 @@ def _merge_results(results: list[dict]) -> dict:
else: merged["formats"][fmt_name].append(f)
for key, pc in r["pseudocode"].items():
if key not in merged["pseudocode"]: merged["pseudocode"][key] = pc
for val, name in r.get("buf_fmt", {}).items():
if val not in merged["buf_fmt"]: merged["buf_fmt"][val] = name
return merged
# ═══════════════════════════════════════════════════════════════════════════════
# CODE GENERATION
# ═══════════════════════════════════════════════════════════════════════════════
def _generate_enum_py(enums, src_enum, doc_name) -> str:
def _generate_enum_py(enums, src_enum, doc_name, buf_fmt=None) -> str:
"""Generate enum.py content (just enums, no dsl.py dependency)."""
def enum_lines(name, items): return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""]
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
if buf_fmt: lines += enum_lines("BufFmt", {v: f"BUF_FMT_{n}" for v, n in buf_fmt.items() if 1 <= v <= 63})
return '\n'.join(lines)
def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
@@ -398,9 +419,10 @@ def generate_arch(arch: str) -> dict:
# Write enum.py (enums only, no dsl.py dependency)
enum_path = base_path / "enum.py"
enum_content = _generate_enum_py(merged["enums"], merged["src_enum"], doc_name)
enum_content = _generate_enum_py(merged["enums"], merged["src_enum"], doc_name, merged.get("buf_fmt"))
enum_path.write_text(enum_content)
print(f"Generated {enum_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums")
buf_fmt_count = len([v for v in merged.get("buf_fmt", {}) if 1 <= v <= 63])
print(f"Generated {enum_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums" + (f" + BufFmt ({buf_fmt_count})" if buf_fmt_count else ""))
# Write ins.py (instruction formats and helpers, imports dsl.py and enum.py)
ins_path = base_path / "ins.py"

View File

@@ -1,192 +1,102 @@
#!/usr/bin/env python3
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess
"""Test AMD assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess, functools
from tinygrad.helpers import fetch
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import asm
from extra.assembly.amd.asm import asm, disasm, detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
# Format info: (filename, format_class, op_enum)
LLVM_TEST_FILES = {
# Scalar ALU
'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op),
'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op),
'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp),
'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp),
'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp),
# Vector ALU
'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op),
'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op),
'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp),
'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op),
'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp),
'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3
'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp),
'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp),
'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format
# VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding)
'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op),
'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op),
'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op),
'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op),
# Memory
'ds': ('gfx11_asm_ds.s', DS, DSOp),
'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp),
'flat': ('gfx11_asm_flat.s', FLAT, FLATOp),
'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp),
'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp),
'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp),
# WMMA (matrix multiply)
'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp),
# Additional features
'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op),
'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp),
'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp),
# Alias files (alternative mnemonics)
'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op),
'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp),
'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp),
'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp),
'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp),
'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp),
'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp),
'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp),
}
RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s',
'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s',
'gfx11_asm_vopd.s', 'gfx11_asm_vopcx.s', 'gfx11_asm_vop3_from_vop1.s', 'gfx11_asm_vop3_from_vop2.s', 'gfx11_asm_vop3_from_vopc.s',
'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', 'gfx11_asm_mubuf.s', 'gfx11_asm_mtbuf.s',
'gfx11_asm_mimg.s', 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s',
'gfx11_asm_vop3_alias.s', 'gfx11_asm_vop3p_alias.s', 'gfx11_asm_vopc_alias.s', 'gfx11_asm_vopcx_alias.s', 'gfx11_asm_vinterp_alias.s',
'gfx11_asm_smem_alias.s', 'gfx11_asm_mubuf_alias.s', 'gfx11_asm_mtbuf_alias.s']
# CDNA test files - includes gfx9 files for shared instructions, plus gfx90a/gfx942 specific files
# gfx90a_ldst_acc.s has MIMG mixed in, filtered via is_mimg check
CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm_sopk.s', 'gfx9_asm_sopc.s',
'gfx9_asm_vop1.s', 'gfx9_asm_vop2.s', 'gfx9_asm_vopc.s', 'gfx9_asm_vop3.s', 'gfx9_asm_vop3p.s',
'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', 'gfx9_asm_mubuf.s', 'gfx9_asm_mtbuf.s',
'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s',
'mai-gfx90a.s', 'mai-gfx942.s']
def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
tests, lines = [], text.split('\n')
for i, line in enumerate(lines):
line = line.strip()
if not line or line.startswith(('//', '.', ';')): continue
asm_text = line.split('//')[0].strip()
if not asm_text: continue
for j in range(i, min(i + 3, len(lines))):
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else:
continue
if hex_bytes:
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
except ValueError: pass
break
def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100
def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]:
tests = []
for block in text.split('\n\n'):
asm_text, encoding = None, None
for line in block.split('\n'):
line = line.strip()
if not line or line.startswith(('.', ';')): continue
if not line.startswith('//'):
asm_text = line.split('//')[0].strip() or asm_text
if m := re.search(pattern + r'[^:]*:.*?(?:encoding:\s*)?\[(0x[0-9a-f,x\s]+)\]', line, re.I):
encoding = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
if asm_text and encoding:
try: tests.append((asm_text, bytes.fromhex(encoding)))
except ValueError: pass
return tests
def try_assemble(text: str):
"""Try to assemble instruction text, return bytes or None on failure."""
try: return asm(text).to_bytes()
except: return None
@functools.cache
def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]:
text = fetch(f"{LLVM_BASE}/{f}").read_bytes().decode('utf-8', errors='ignore')
if arch == "rdna3":
tests = _parse_llvm_tests(text, r'(?:GFX11|W32|W64)')
elif 'gfx90a' in f or 'gfx942' in f:
tests = _parse_llvm_tests(text, r'(?:GFX90A|GFX942)')
else:
tests = _parse_llvm_tests(text, r'(?:VI9|GFX9|CHECK)')
return [(a, d) for a, d in tests if not _is_mimg(d)] if arch == "cdna" else tests
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
def _compile_asm_batch(instrs: list[str]) -> list[bytes]:
if not instrs: return []
asm_text = ".text\n" + "\n".join(instrs) + "\n"
result = subprocess.run(
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=asm_text, capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
# Parse all encodings from output
results = []
for line in result.stdout.split('\n'):
if 'encoding:' not in line: continue
enc = line.split('encoding:')[1].strip()
if enc.startswith('[') and enc.endswith(']'):
results.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
return results
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
return [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))
for line in result.stdout.split('\n') if 'encoding:' in line]
class TestLLVM(unittest.TestCase):
"""Test assembler and disassembler against all LLVM test vectors."""
tests: dict[str, list[tuple[str, bytes]]] = {}
@classmethod
def setUpClass(cls):
for name, (filename, _, _) in LLVM_TEST_FILES.items():
try:
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'))
except Exception as e:
print(f"Warning: couldn't fetch {filename}: {e}")
cls.tests[name] = []
# Generate test methods dynamically for each format
def _make_asm_test(name):
def _make_test(f: str, arch: str, test_type: str):
def test(self):
passed, failed, skipped = 0, 0, 0
for asm_text, expected in self.tests.get(name, []):
result = try_assemble(asm_text)
if result is None: skipped += 1
elif result == expected: passed += 1
else: failed += 1
print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped")
self.assertEqual(failed, 0)
tests = _get_tests(f, arch)
name = f"{arch}_{test_type}_{f}"
if test_type == "roundtrip":
for _, data in tests:
decoded = detect_format(data, arch).from_bytes(data)
self.assertEqual(decoded.to_bytes()[:len(data)], data)
print(f"{name}: {len(tests)} passed")
elif test_type == "asm":
passed, skipped = 0, 0
for asm_text, expected in tests:
try:
self.assertEqual(asm(asm_text).to_bytes(), expected)
passed += 1
except: skipped += 1
print(f"{name}: {passed} passed, {skipped} skipped")
elif test_type == "disasm":
to_test = []
for _, data in tests:
try:
decoded = detect_format(data, arch).from_bytes(data)
if decoded.to_bytes()[:len(data)] == data and (d := disasm(decoded)): to_test.append((data, d))
except: pass
print(f"{name}: {len(to_test)} passed, {len(tests) - len(to_test)} skipped")
if arch == "rdna3":
for (data, _), llvm in zip(to_test, _compile_asm_batch([t[1] for t in to_test])): self.assertEqual(llvm, data)
return test
def _make_disasm_test(name):
def test(self):
_, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name]
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
class TestLLVM(unittest.TestCase): pass
# First pass: decode all instructions and collect disasm strings
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
for asm_text, data in self.tests.get(name, []):
# Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword
is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35
fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum)
try:
if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
temp = VOP3.from_bytes(data)
op_val = temp._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
if is_vop3sd: VOP3SDOp(op_val)
else: VOP3Op(op_val)
else:
decoded = fmt_cls.from_bytes(data)
op_val = decoded._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
op_enum(op_val)
if decoded.to_bytes()[:len(data)] != data:
to_test.append((asm_text, data, None, "decode roundtrip failed"))
continue
to_test.append((asm_text, data, decoded.disasm(), None))
except Exception as e:
to_test.append((asm_text, data, None, f"exception: {e}"))
# Batch compile all disasm strings with single llvm-mc call
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
# Match results back
passed, failed = 0, 0
failures: list[str] = []
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
if error:
failed += 1; failures.append(f"{error} for {data.hex()}")
elif disasm_str is not None and idx in llvm_map:
llvm_bytes = llvm_map[idx]
if llvm_bytes is not None and llvm_bytes == data: passed += 1
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
print(f"{name.upper()} disasm: {passed} passed, {failed} failed")
if failures[:10]: print(" " + "\n ".join(failures[:10]))
self.assertEqual(failed, 0)
return test
for name in LLVM_TEST_FILES:
setattr(TestLLVM, f'test_{name}_asm', _make_asm_test(name))
setattr(TestLLVM, f'test_{name}_disasm', _make_disasm_test(name))
for f in RDNA_FILES:
setattr(TestLLVM, f"test_rdna3_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "roundtrip"))
setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm"))
setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm"))
for f in CDNA_FILES:
setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "roundtrip"))
setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "disasm"))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,144 +0,0 @@
#!/usr/bin/env python3
"""Test CDNA assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess
from tinygrad.helpers import fetch
from extra.assembly.amd.autogen.cdna.ins import *
from extra.assembly.amd.asm import disasm
from extra.assembly.amd.test.helpers import get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]:
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
tests, lines = [], text.split('\n')
for i, line in enumerate(lines):
line = line.strip()
if not line or line.startswith(('//', '.', ';')): continue
asm_text = line.split('//')[0].strip()
if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue
for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))):
if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else: continue
try:
data = bytes.fromhex(hex_bytes)
if size_filter is None or len(data) == size_filter: tests.append((asm_text, data))
except ValueError: pass
break
return tests
# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions
# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter)
CDNA_TEST_FILES = {
# Scalar ALU - encoding is stable across GFX9/CDNA
'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None),
'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None),
'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None),
'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None),
'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None),
'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None),
# Vector ALU - encoding is mostly stable
'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None),
'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None),
'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None),
'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None),
'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None),
'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions
# Memory instructions
'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None),
'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None),
# CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers)
'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None),
'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None),
'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None),
'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None),
'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None),
# CDNA-specific: MFMA/MAI instructions
'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None),
# SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately)
'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None),
'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None),
}
class TestLLVMCDNA(unittest.TestCase):
"""Test CDNA instruction format decode/encode roundtrip and disassembly."""
tests: dict[str, list[tuple[str, bytes]]] = {}
@classmethod
def setUpClass(cls):
for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items():
try:
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter)
except Exception as e:
print(f"Warning: couldn't fetch {filename}: {e}")
cls.tests[name] = []
def _get_val(v): return v.val if hasattr(v, 'val') else v
def _filter_and_decode(tests, fmt_cls, op_enum):
"""Filter tests and decode instructions, yielding (asm_text, data, decoded, error)."""
fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP'
for asm_text, data in tests:
has_lit = False
# SDWA/DPP format tests: only accept matching 8-byte instructions
if is_sdwa:
if len(data) != 8 or data[0] != 0xf9: continue
elif is_dpp:
if len(data) != 8 or data[0] != 0xfa: continue
elif fmt_cls._size() == 4 and len(data) == 8:
if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately)
has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC'))
if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20
if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37)
if not has_lit: continue
if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue
try:
decoded = fmt_cls.from_bytes(data)
# For SDWA/DPP, opcode location depends on VOP1 vs VOP2
if is_sdwa or is_dpp:
vop2_op = _get_val(decoded._values.get('vop2_op', 0))
op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op
else:
op_val = _get_val(decoded._values.get('op', 0))
try: op_enum(op_val)
except ValueError: continue
yield asm_text, data, decoded, None
except Exception as e:
yield asm_text, data, None, str(e)
def _make_roundtrip_test(name):
def test(self):
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
passed, failed, failures = 0, 0, []
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
if decoded.to_bytes()[:len(data)] == data: passed += 1
else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}")
print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed")
if failures[:5]: print(" " + "\n ".join(failures[:5]))
self.assertEqual(failed, 0)
return test
def _make_disasm_test(name):
def test(self):
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
passed, failed, failures = 0, 0, []
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue
if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue
passed += 1
print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed")
if failures[:5]: print(" " + "\n ".join(failures[:5]))
self.assertEqual(failed, 0)
return test
for name in CDNA_TEST_FILES:
setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name))
setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name))
if __name__ == "__main__":
unittest.main()