mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
assembly/amd: improve tests for asm (#14007)
* assembly/amd: improve tests for asm * upd * skip * tests * re bug * more passing * cleanups * cdna fixups * improve tests, better CDNA parsing * fix CI * no defs * simpler * all pass * from pdf * regen
This commit is contained in:
@@ -3,10 +3,11 @@ from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp)
|
||||
from extra.assembly.amd.autogen.rdna3.enum import BufFmt
|
||||
|
||||
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
|
||||
|
||||
@@ -17,21 +18,37 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
||||
return ((word >> bf.lo) & bf.mask()) == val
|
||||
|
||||
# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0)
|
||||
_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
|
||||
_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
|
||||
_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP]
|
||||
_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls
|
||||
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P,
|
||||
SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS,
|
||||
FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF, SDWA, DPP)
|
||||
_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM]
|
||||
_CDNA_FORMATS_32 = [SDWA, DPP, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2]
|
||||
_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes
|
||||
# CDNA opcode name aliases for disasm (new name -> old name expected by tests)
|
||||
_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'}
|
||||
|
||||
def detect_format(data: bytes) -> type[Inst]:
|
||||
def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]:
|
||||
"""Detect instruction format from machine code bytes."""
|
||||
assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}"
|
||||
word = int.from_bytes(data[:4], 'little')
|
||||
# Check 64-bit formats first (bits[31:30] == 0b11)
|
||||
if arch == "cdna":
|
||||
if (word >> 30) == 0b11:
|
||||
for cls in _CDNA_FORMATS_64:
|
||||
if _matches_encoding(word, cls):
|
||||
return VOP3B if cls is VOP3A and ((word >> 16) & 0x3ff) in _CDNA_VOP3B_OPS else cls
|
||||
raise ValueError(f"unknown CDNA 64-bit format word={word:#010x}")
|
||||
for cls in _CDNA_FORMATS_32:
|
||||
if _matches_encoding(word, cls): return cls
|
||||
raise ValueError(f"unknown CDNA 32-bit format word={word:#010x}")
|
||||
# RDNA (default)
|
||||
if (word >> 30) == 0b11:
|
||||
for cls in _FORMATS_64:
|
||||
for cls in _RDNA_FORMATS_64:
|
||||
if _matches_encoding(word, cls):
|
||||
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
|
||||
raise ValueError(f"unknown 64-bit format word={word:#010x}")
|
||||
# 32-bit formats
|
||||
for cls in _FORMATS_32:
|
||||
for cls in _RDNA_FORMATS_32:
|
||||
if _matches_encoding(word, cls): return cls
|
||||
raise ValueError(f"unknown 32-bit format word={word:#010x}")
|
||||
|
||||
@@ -44,6 +61,11 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
|
||||
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
|
||||
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
|
||||
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
|
||||
# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup
|
||||
BUF_FMT = {e.name: e.value for e in BufFmt}
|
||||
def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]
|
||||
parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')]
|
||||
return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}') if len(parts) == 2 else None
|
||||
MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
|
||||
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
|
||||
|
||||
@@ -54,22 +76,28 @@ MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_T
|
||||
def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]"
|
||||
def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n)
|
||||
def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n)
|
||||
def _areg(b: int, n: int = 1) -> str: return _reg("a", b, n) # accumulator registers for GFX90a
|
||||
def _ttmp(b: int, n: int = 1) -> str: return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None
|
||||
def _sreg_or_ttmp(b: int, n: int = 1) -> str: return _ttmp(b, n) or _sreg(b, n)
|
||||
|
||||
def _fmt_sdst(v: int, n: int = 1) -> str:
|
||||
if v == 124: return "null"
|
||||
def _fmt_sdst(v: int, n: int = 1, cdna: bool = False) -> str:
|
||||
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA, SPECIAL_GPRS_CDNA
|
||||
if t := _ttmp(v, n): return t
|
||||
if n > 1: return SPECIAL_PAIRS.get(v) or _sreg(v, n)
|
||||
return SPECIAL_GPRS.get(v, f"s{v}")
|
||||
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
|
||||
gprs = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
|
||||
if n > 1: return pairs.get(v) or gprs.get(v) or _sreg(v, n) # also check gprs for null/m0
|
||||
return gprs.get(v, f"s{v}")
|
||||
|
||||
def _fmt_src(v: int, n: int = 1) -> str:
|
||||
if n == 1: return decode_src(v)
|
||||
def _fmt_src(v: int, n: int = 1, cdna: bool = False) -> str:
|
||||
from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA
|
||||
if n == 1: return decode_src(v, cdna)
|
||||
if v >= 256: return _vreg(v - 256, n)
|
||||
if v <= 105: return _sreg(v, n)
|
||||
if n == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v]
|
||||
if v <= 101: return _sreg(v, n) # s0-s101 can be pairs, but 102+ are special on CDNA
|
||||
pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS
|
||||
if n == 2 and v in pairs: return pairs[v]
|
||||
if v <= 105: return _sreg(v, n) # s102-s105 regular pairs for RDNA
|
||||
if t := _ttmp(v, n): return t
|
||||
return decode_src(v)
|
||||
return decode_src(v, cdna)
|
||||
|
||||
def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str:
|
||||
return f"v{(v - base) & 0x7f}.{'h' if v >= hi_thresh else 'l'}"
|
||||
@@ -106,46 +134,72 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _disasm_vop1(inst: VOP1) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
|
||||
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
|
||||
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
|
||||
name, cdna = inst.op_name.lower() or f'vop1_op_{inst.op}', _is_cdna(inst)
|
||||
suf = "" if cdna else "_e32"
|
||||
if name in ('v_nop', 'v_pipeflush', 'v_clrexcp'): return name # no operands
|
||||
if 'readfirstlane' in name:
|
||||
src = f"v{inst.src0 - 256}" if inst.src0 >= 256 else decode_src(inst.src0, cdna)
|
||||
return f"{name} {_fmt_sdst(inst.vdst, 1, cdna)}, {src}"
|
||||
# 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN)
|
||||
parts = name.split('_')
|
||||
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
|
||||
is_16d = not cdna and (any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
return f"{name}_e32 {dst}, {src}"
|
||||
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
return f"{name}{suf} {dst}, {src}"
|
||||
|
||||
_VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only
|
||||
_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out
|
||||
def _disasm_vop2(inst: VOP2) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases
|
||||
suf = "" if cdna or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16) else "_e32"
|
||||
lit = getattr(inst, '_literal', None)
|
||||
is16 = not cdna and inst.is_16bit()
|
||||
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
|
||||
if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
|
||||
# fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1
|
||||
if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
|
||||
if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
|
||||
if 'fmamk' in name or 'madmk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
vcc = "vcc" if cdna else "vcc_lo"
|
||||
# CDNA carry ops output vcc after vdst
|
||||
if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}"
|
||||
if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
|
||||
|
||||
def _disasm_vopc(inst: VOPC) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna:
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0))
|
||||
return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}"
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else f"v{inst.vsrc1}"
|
||||
return f"{name} vcc, {s0}, {s1}" # CDNA VOPC always outputs vcc
|
||||
# RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo
|
||||
has_vcc = 'cmpx' not in name
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
return f"{name}_e32 vcc_lo, {s0}, {s1}" if has_vcc else f"{name}_e32 {s0}, {s1}"
|
||||
|
||||
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
|
||||
# CDNA uses name-based matching since opcode values differ from RDNA
|
||||
_CDNA_NO_ARG_SOPP = {'s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_nop', 's_sethalt', 's_sleep',
|
||||
's_setprio', 's_trap', 's_incperflevel', 's_decperflevel', 's_sendmsg', 's_sendmsghalt'}
|
||||
|
||||
def _disasm_sopp(inst: SOPP) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna:
|
||||
# CDNA: use name-based matching
|
||||
if name == 's_endpgm': return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
|
||||
if name in ('s_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata'): return name
|
||||
if name == 's_waitcnt':
|
||||
vm, lgkm, exp = inst.simm16 & 0xf, (inst.simm16 >> 8) & 0x3f, (inst.simm16 >> 4) & 0x7
|
||||
p = [f"vmcnt({vm})" if vm != 0xf else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
|
||||
if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}"
|
||||
return f"{name} 0x{inst.simm16:x}" if inst.simm16 else name
|
||||
# RDNA
|
||||
if inst.op in NO_ARG_SOPP: return name
|
||||
if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
|
||||
if inst.op == SOPPOp.S_WAITCNT:
|
||||
@@ -161,64 +215,98 @@ def _disasm_sopp(inst: SOPP) -> str:
|
||||
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_smem(inst: SMEM) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
|
||||
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
# GFX9 SMEM: soe and imm bits determine offset interpretation
|
||||
# soe=1, imm=1: soffset is SGPR, offset is immediate (both used)
|
||||
# soe=0, imm=1: offset is immediate
|
||||
# soe=0, imm=0: offset field is SGPR encoding (0-255)
|
||||
soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1)
|
||||
if cdna:
|
||||
if soe and imm:
|
||||
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" # SGPR + immediate
|
||||
elif imm:
|
||||
off_s = f"0x{inst.offset:x}" # Immediate offset only
|
||||
elif inst.offset < 256:
|
||||
off_s = decode_src(inst.offset, cdna) # SGPR encoding in offset field
|
||||
else:
|
||||
off_s = decode_src(inst.soffset, cdna)
|
||||
elif inst.offset and inst.soffset != 124:
|
||||
off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}"
|
||||
elif inst.offset:
|
||||
off_s = f"0x{inst.offset:x}"
|
||||
else:
|
||||
off_s = decode_src(inst.soffset, cdna)
|
||||
op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op
|
||||
# s_buffer_* instructions use 4 SGPRs for sbase (buffer descriptor)
|
||||
is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc"))
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
|
||||
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
|
||||
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
|
||||
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
|
||||
if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}"
|
||||
else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
w = inst.dst_regs() * (2 if '_x2' in name else 1) * (2 if 'cmpswap' in name else 1)
|
||||
off_s = f" offset:{off_val}" if off_val else "" # Omit offset:0
|
||||
if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}" # GFX9: sc0->glc, nt->slc
|
||||
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
# saddr
|
||||
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
|
||||
elif inst.saddr == 124: saddr_s = ", off"
|
||||
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}"
|
||||
elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}"
|
||||
elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}"
|
||||
elif inst.saddr in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[inst.saddr]}"
|
||||
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}"
|
||||
else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr, cdna)}"
|
||||
# addtid: no addr
|
||||
if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
|
||||
# addr width
|
||||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
|
||||
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
if 'addtid' in name: return f"{instr} {'a' if acc else 'v'}{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}"
|
||||
# addr width: CDNA flat always uses 2 VGPRs (64-bit), scratch uses 1, RDNA uses 2 only when no saddr
|
||||
if cdna:
|
||||
addr_w = 1 if seg == 'scratch' else 2 # CDNA: flat/global always 64-bit addr
|
||||
else:
|
||||
addr_w = 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2
|
||||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
|
||||
data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
glc_or_sc0 = inst.sc0 if cdna else inst.glc
|
||||
if 'atomic' in name:
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
|
||||
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
|
||||
def _disasm_ds(inst: DS) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
|
||||
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
|
||||
rp = 'a' if acc else 'v' # register prefix for single regs
|
||||
gds = " gds" if inst.gds else ""
|
||||
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
|
||||
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
|
||||
off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "")
|
||||
w = inst.dst_regs()
|
||||
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
|
||||
d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), f"v{inst.addr}"
|
||||
|
||||
if op == DSOp.DS_NOP: return name
|
||||
if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}"
|
||||
if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}"
|
||||
if 'gws_' in name: return f"{name} {addr}{off}{gds}"
|
||||
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}"
|
||||
if 'gs_reg' in name: return f"{name} {_vreg(inst.vdst, 2)}, v{inst.data0}{off}{gds}"
|
||||
if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{off}{gds}"
|
||||
if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {rp}{inst.data0}{off}{gds}"
|
||||
if '2addr' in name:
|
||||
if 'load' in name: return f"{name} {_vreg(inst.vdst, w*2)}, {addr}{off2}{gds}"
|
||||
if 'load' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
|
||||
if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
|
||||
return f"{name} {_vreg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'load' in name: return f"{name} v{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
|
||||
return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}"
|
||||
if 'read2' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}"
|
||||
if 'load' in name: return f"{name} {rp}{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}"
|
||||
if 'store' in name and not _has(name, 'cmp', 'xchg'):
|
||||
return f"{name} v{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} v{inst.vdst}, {addr}{off}{gds}"
|
||||
if 'permute' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}{off}{gds}"
|
||||
if 'condxchg' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, {_vreg(inst.data0, 2)}{off}{gds}"
|
||||
return f"{name} {rp}{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} {rp}{inst.vdst}, {addr}{off}{gds}"
|
||||
if 'permute' in name: return f"{name} {rp}{inst.vdst}, {addr}, {rp}{inst.data0}{off}{gds}"
|
||||
if 'condxchg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {addr}, {reg_fn(inst.data0, 2)}{off}{gds}"
|
||||
if _has(name, 'cmpstore', 'mskor', 'wrap'):
|
||||
return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}"
|
||||
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
@@ -318,6 +406,8 @@ def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
|
||||
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag
|
||||
reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0
|
||||
if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
|
||||
if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
|
||||
@@ -326,9 +416,27 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
if hasattr(inst, 'tfe') and inst.tfe: w += 1
|
||||
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
|
||||
srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
|
||||
if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c]
|
||||
else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
|
||||
is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF)
|
||||
if is_mtbuf:
|
||||
dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7
|
||||
if acc: # GFX90a accumulator style: show dfmt/nfmt as numbers
|
||||
fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt}," # double space before dfmt per LLVM format
|
||||
elif not cdna: # RDNA style: show combined format number
|
||||
fmt_s = f" format:{inst.format}" if inst.format else ""
|
||||
else: # CDNA: show format:[BUF_DATA_FORMAT_X] or format:[BUF_NUM_FORMAT_X]
|
||||
dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15']
|
||||
nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT']
|
||||
if dfmt == 1 and nfmt == 0: fmt_s = "" # default, no format shown
|
||||
elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]" # only dfmt differs
|
||||
elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # only nfmt differs
|
||||
else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # both differ
|
||||
else:
|
||||
fmt_s = ""
|
||||
if cdna: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")] if c]
|
||||
else: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
soffset_s = decode_src(inst.soffset, cdna)
|
||||
if cdna and not acc and is_mtbuf: return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{' ' + ' '.join(mods) if mods else ''}"
|
||||
return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
||||
"""Calculate vaddr register count for MIMG sample/gather operations."""
|
||||
@@ -377,21 +485,23 @@ def _disasm_mimg(inst: MIMG) -> str:
|
||||
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
if not _is_cdna(inst):
|
||||
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
|
||||
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
|
||||
if not cdna:
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {src}"
|
||||
|
||||
def _disasm_sop2(inst: SOP2) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
cdna = _is_cdna(inst)
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)}"
|
||||
|
||||
def _disasm_sopc(inst: SOPC) -> str:
|
||||
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))
|
||||
cdna = _is_cdna(inst)
|
||||
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)
|
||||
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)
|
||||
return f"{inst.op_name.lower()} {s0}, {s1}"
|
||||
|
||||
def _disasm_sopk(inst: SOPK) -> str:
|
||||
@@ -405,10 +515,10 @@ def _disasm_sopk(inst: SOPK) -> str:
|
||||
if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}"
|
||||
if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
|
||||
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_vinterp(inst: VINTERP) -> str:
|
||||
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
|
||||
@@ -464,11 +574,54 @@ def _parse_ops(s: str) -> list[str]:
|
||||
return ops
|
||||
|
||||
def _extract(text: str, pat: str, flags=re.I):
|
||||
if m := re.search(pat, text, flags): return m, text[:m.start()] + text[m.end():]
|
||||
if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():]
|
||||
return None, text
|
||||
|
||||
# Instruction aliases: LLVM uses different names for some instructions
|
||||
_ALIASES = {
|
||||
'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
|
||||
'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64',
|
||||
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
|
||||
'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32',
|
||||
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32',
|
||||
# SMEM aliases (dword -> b32, dwordx2 -> b64, etc.)
|
||||
's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128',
|
||||
's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512',
|
||||
's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64',
|
||||
's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256',
|
||||
's_buffer_load_dwordx16': 's_buffer_load_b512',
|
||||
# VOP3 aliases
|
||||
'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16',
|
||||
'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32',
|
||||
# VINTERP aliases
|
||||
'v_interp_p2_new_f32': 'v_interp_p2_f32',
|
||||
# SOP1 aliases
|
||||
's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64',
|
||||
's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64',
|
||||
's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64',
|
||||
's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64',
|
||||
's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64',
|
||||
's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64',
|
||||
's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64',
|
||||
's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64',
|
||||
# SOP2 aliases
|
||||
's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64',
|
||||
's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64',
|
||||
# VOP2 aliases
|
||||
'v_dot2c_f32_f16': 'v_dot2acc_f32_f16',
|
||||
# More VOP3 aliases
|
||||
'v_fma_legacy_f32': 'v_fma_dx9_zero_f32',
|
||||
}
|
||||
|
||||
def _apply_alias(text: str) -> str:
|
||||
mn = text.split()[0].lower() if ' ' in text else text.lower().rstrip('_')
|
||||
# Try exact match first, then strip _e32/_e64 suffix
|
||||
for m in (mn, mn.removesuffix('_e32'), mn.removesuffix('_e64')):
|
||||
if m in _ALIASES: return _ALIASES[m] + text[len(m):]
|
||||
return text
|
||||
|
||||
def get_dsl(text: str) -> str:
|
||||
text, kw = text.strip(), []
|
||||
text, kw = _apply_alias(text.strip()), []
|
||||
# Extract modifiers
|
||||
for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]:
|
||||
if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break
|
||||
@@ -484,6 +637,11 @@ def get_dsl(text: str) -> str:
|
||||
m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
|
||||
m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
|
||||
m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
|
||||
m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None
|
||||
m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None
|
||||
m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None
|
||||
m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None
|
||||
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
|
||||
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
if waitexp: kw.append(f'waitexp={waitexp}')
|
||||
@@ -530,9 +688,30 @@ def get_dsl(text: str) -> str:
|
||||
if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})"
|
||||
if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})"
|
||||
|
||||
# Buffer
|
||||
if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off':
|
||||
return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})"
|
||||
# Buffer (MUBUF/MTBUF) instructions
|
||||
if mn.startswith(('buffer_', 'tbuffer_')):
|
||||
is_tbuf = mn.startswith('tbuffer_')
|
||||
# Parse format value for tbuffer
|
||||
fmt_num = None
|
||||
if fmt_val is not None:
|
||||
if fmt_val.isdigit(): fmt_num = int(fmt_val)
|
||||
else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
|
||||
# Handle special no-arg buffer ops
|
||||
if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()"
|
||||
# Build modifiers string
|
||||
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "",
|
||||
", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""])
|
||||
if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods
|
||||
# Determine vaddr value (v[0] for 'off', actual register otherwise)
|
||||
vaddr_idx = 1
|
||||
if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]"
|
||||
else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]"
|
||||
# srsrc and soffset indices depend on whether vaddr is 'off'
|
||||
srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2)
|
||||
srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]"
|
||||
soff_val = args[soff_idx] if len(args) > soff_idx else "0"
|
||||
# soffset: integers are inline constants, don't wrap in RawImm
|
||||
return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
|
||||
|
||||
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null
|
||||
def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a
|
||||
@@ -582,6 +761,15 @@ def get_dsl(text: str) -> str:
|
||||
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
|
||||
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
|
||||
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
|
||||
# v_cmp_*_e64 has SGPR destination in vdst field - encode as RawImm
|
||||
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
|
||||
if mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64') and len(args) >= 1:
|
||||
dst = ops[0].strip().lower()
|
||||
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
|
||||
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
|
||||
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
|
||||
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
|
||||
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
|
||||
|
||||
fn = mn.replace('.', '_')
|
||||
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
|
||||
@@ -629,31 +817,76 @@ def asm(text: str) -> Inst:
|
||||
try:
|
||||
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
|
||||
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
|
||||
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op)
|
||||
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp)
|
||||
|
||||
def _cdna_src(inst, v, neg, abs_=0, n=1):
|
||||
s = inst.lit(v) if v == 255 else _fmt_src(v, n)
|
||||
s = inst.lit(v) if v == 255 else _fmt_src(v, n, cdna=True)
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
|
||||
def _disasm_vop3a(inst) -> str:
|
||||
name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
suf = "_e64" if inst.op.value < 512 else ""
|
||||
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}")
|
||||
# CDNA VOP2 aliases: new opcode name -> old name expected by LLVM tests
|
||||
_CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'}
|
||||
|
||||
def _disasm_vop3b(inst) -> str:
|
||||
name, n = inst.op_name.lower(), inst.num_srcs()
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4)
|
||||
dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else ""
|
||||
def _disasm_vop3a(inst) -> str:
|
||||
op_val = inst._values.get('op', 0) # get raw opcode value, not enum value
|
||||
if hasattr(op_val, 'value'): op_val = op_val.value # in case it's stored as enum
|
||||
name = inst.op_name.lower() or f'vop3a_op_{op_val}'
|
||||
from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
|
||||
n = spec_num_srcs(name) if name else inst.num_srcs()
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}"
|
||||
orig_name = name
|
||||
name = _CDNA_VOP3_ALIASES.get(name, name) # apply CDNA aliases
|
||||
# For aliased ops, recalculate sources without 64-bit assumption
|
||||
if name != orig_name:
|
||||
s0, s1 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, 1), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, 1)
|
||||
s2 = ""
|
||||
dst = f"v{inst.vdst}"
|
||||
else:
|
||||
dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2))
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2)
|
||||
dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
|
||||
# True VOP3 instructions (512+) - 3-source ops
|
||||
if op_val >= 512:
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{cl}{om}"
|
||||
# VOPC (0-255): writes to SGPR pair, VOP2 (256-319): 2-3 src, VOP1 (320-511): 1 src
|
||||
if op_val < 256:
|
||||
sdst = _fmt_sdst(inst.vdst, 2, cdna=True) # VOPC writes to 64-bit SGPR pair
|
||||
# v_cmpx_ also writes to sdst in CDNA VOP3 (unlike VOP32 where it writes to exec)
|
||||
return f"{name}_e64 {sdst}, {s0}, {s1}{cl}"
|
||||
if 320 <= op_val < 512: # VOP1 promoted
|
||||
if name in ('v_nop', 'v_clrexcp'): return f"{name}_e64"
|
||||
return f"{name}_e64 {dst}, {s0}{cl}{om}"
|
||||
# VOP2 promoted (256-319)
|
||||
if name == 'v_cndmask_b32':
|
||||
s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is 64-bit SGPR pair
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{cl}{om}"
|
||||
if name in ('v_mul_legacy_f32', 'v_mac_f32'):
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}{cl}{om}"
|
||||
suf = "_e64" if op_val < 512 else ""
|
||||
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}"
|
||||
|
||||
# GFX9-specific VOP3B opcodes not in CDNA enum
|
||||
def _disasm_vop3b(inst) -> str:
|
||||
op_val = inst._values.get('op', 0)
|
||||
if hasattr(op_val, 'value'): op_val = op_val.value
|
||||
name = inst.op_name.lower() or f'vop3b_op_{op_val}'
|
||||
from extra.assembly.amd.dsl import spec_num_srcs, spec_regs
|
||||
n = spec_num_srcs(name) if name else inst.num_srcs()
|
||||
dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2))
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2)
|
||||
dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}"
|
||||
sdst = _fmt_sdst(inst.sdst, 2, cdna=True) # VOP3B sdst is always 64-bit SGPR pair
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
# Carry ops need special handling
|
||||
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
|
||||
s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is carry-in (64-bit SGPR pair)
|
||||
return f"{name}_e64 {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}"
|
||||
suf = "_e64" if 'co_' in name else ""
|
||||
return f"{name}{suf} {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {sdst}, {s0}, {s1}{cl}{om}"
|
||||
|
||||
def _disasm_cdna_vop3p(inst) -> str:
|
||||
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
|
||||
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc)
|
||||
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc, cdna=True)
|
||||
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
|
||||
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
@@ -665,20 +898,93 @@ try:
|
||||
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
|
||||
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
|
||||
|
||||
def _sdwa_src0(v, is_sgpr, sext=0, neg=0, abs_=0):
|
||||
# s0=0: VGPR (v is VGPR number), s0=1: SGPR/constant (v is encoded like normal src)
|
||||
s = decode_src(v, cdna=True) if is_sgpr else f"v{v}"
|
||||
if sext: s = f"sext({s})"
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def _sdwa_vsrc1(v, sext=0, neg=0, abs_=0):
|
||||
# For VOP2 SDWA, vsrc1 is in vop_op field as raw VGPR number
|
||||
s = f"v{v}"
|
||||
if sext: s = f"sext({s})"
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
_OMOD_SDWA = {0: "", 1: " mul:2", 2: " mul:4", 3: " div:2"}
|
||||
|
||||
def _disasm_sdwa(inst) -> str:
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0)
|
||||
mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6]
|
||||
return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "")
|
||||
# SDWA format: vop2_op=63 -> VOP1, vop2_op=62 -> VOPC, vop2_op=0-61 -> VOP2
|
||||
vop2_op = inst.vop2_op
|
||||
src0 = _sdwa_src0(inst.src0, inst.s0, inst.src0_sext, inst.src0_neg, inst.src0_abs)
|
||||
clamp = " clamp" if inst.clmp else ""
|
||||
omod = _OMOD_SDWA.get(inst.omod, "")
|
||||
if vop2_op == 63: # VOP1
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
dst = f"v{inst.vdst}"
|
||||
mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}"]
|
||||
return f"{name}_sdwa {dst}, {src0}{clamp}{omod} " + " ".join(mods)
|
||||
elif vop2_op == 62: # VOPC
|
||||
try: name = CDNA_VOPCOp(inst.vdst).name.lower() # opcode is in vdst field for VOPC SDWA
|
||||
except ValueError: name = f"vopc_op_{inst.vdst}"
|
||||
src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
|
||||
# VOPC SDWA: dst encoded in byte 5 (bits 47:40): 0=vcc, 128+n=s[n:n+1]
|
||||
sdst_enc = inst.dst_sel | (inst.dst_u << 3) | (inst.clmp << 5) | (inst.omod << 6)
|
||||
if sdst_enc == 0:
|
||||
sdst = "vcc"
|
||||
else:
|
||||
sdst_val = sdst_enc - 128 if sdst_enc >= 128 else sdst_enc
|
||||
sdst = _fmt_sdst(sdst_val, 2, cdna=True)
|
||||
mods = [f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
|
||||
return f"{name}_sdwa {sdst}, {src0}, {src1} " + " ".join(mods)
|
||||
else: # VOP2
|
||||
try: name = CDNA_VOP2Op(vop2_op).name.lower()
|
||||
except ValueError: name = f"vop2_op_{vop2_op}"
|
||||
name = _CDNA_DISASM_ALIASES.get(name, name) # apply aliases (v_fmac -> v_mac, etc.)
|
||||
dst = f"v{inst.vdst}"
|
||||
src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
|
||||
mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"]
|
||||
# v_cndmask_b32 needs vcc as third operand
|
||||
if name == 'v_cndmask_b32':
|
||||
return f"{name}_sdwa {dst}, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
|
||||
# Carry ops need vcc - v_addc/subb also need vcc as carry-in
|
||||
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
|
||||
return f"{name}_sdwa {dst}, vcc, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods)
|
||||
if '_co_' in name:
|
||||
return f"{name}_sdwa {dst}, vcc, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
|
||||
return f"{name}_sdwa {dst}, {src0}, {src1}{clamp}{omod} " + " ".join(mods)
|
||||
|
||||
def _dpp_src(v, neg=0, abs_=0):
|
||||
s = f"v{v}" if v < 256 else f"v{v - 256}"
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def _disasm_dpp(inst) -> str:
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl
|
||||
# DPP format: vop2_op=63 -> VOP1, vop2_op=0-62 -> VOP2
|
||||
vop2_op = inst.vop2_op
|
||||
ctrl = inst.dpp_ctrl
|
||||
dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
|
||||
mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl]
|
||||
return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods)
|
||||
src0 = _dpp_src(inst.src0, inst.src0_neg, inst.src0_abs)
|
||||
# DPP modifiers: row_mask and bank_mask always shown, bound_ctrl:0 when bit=1
|
||||
mods = [dpp, f"row_mask:0x{inst.row_mask:x}", f"bank_mask:0x{inst.bank_mask:x}"] + (["bound_ctrl:0"] if inst.bound_ctrl else [])
|
||||
if vop2_op == 63: # VOP1
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
return f"{name}_dpp v{inst.vdst}, {src0} " + " ".join(mods)
|
||||
else: # VOP2
|
||||
try: name = CDNA_VOP2Op(vop2_op).name.lower()
|
||||
except ValueError: name = f"vop2_op_{vop2_op}"
|
||||
name = _CDNA_DISASM_ALIASES.get(name, name)
|
||||
src1 = _dpp_src(inst.vop_op, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field
|
||||
if name == 'v_cndmask_b32':
|
||||
return f"{name}_dpp v{inst.vdst}, {src0}, {src1}, vcc " + " ".join(mods)
|
||||
if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'):
|
||||
return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1}, vcc " + " ".join(mods)
|
||||
if '_co_' in name:
|
||||
return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1} " + " ".join(mods)
|
||||
return f"{name}_dpp v{inst.vdst}, {src0}, {src1} " + " ".join(mods)
|
||||
|
||||
# Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
|
||||
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
|
||||
|
||||
@@ -56,6 +56,7 @@ class MTBUF(Inst64):
|
||||
srsrc:SGPRField = bits[52:48]
|
||||
soffset:SSrc = bits[63:56]
|
||||
offset:Imm = bits[11:0]
|
||||
format = bits[25:19]
|
||||
offen = bits[12]
|
||||
idxen = bits[13]
|
||||
sc1 = bits[53]
|
||||
@@ -124,7 +125,7 @@ class SMEM(Inst64):
|
||||
sbase:SGPRField = bits[5:0]
|
||||
soffset:SSrc = bits[63:57]
|
||||
offset:Imm = bits[52:32]
|
||||
glc = bits[14]
|
||||
glc = bits[16]
|
||||
soe = bits[14]
|
||||
nv = bits[15]
|
||||
imm = bits[17]
|
||||
|
||||
@@ -1594,3 +1594,68 @@ class VOPDOp(IntEnum):
|
||||
V_DUAL_ADD_NC_U32 = 16
|
||||
V_DUAL_LSHLREV_B32 = 17
|
||||
V_DUAL_AND_B32 = 18
|
||||
|
||||
class BufFmt(IntEnum):
|
||||
BUF_FMT_8_UNORM = 1
|
||||
BUF_FMT_8_SNORM = 2
|
||||
BUF_FMT_8_USCALED = 3
|
||||
BUF_FMT_8_SSCALED = 4
|
||||
BUF_FMT_8_UINT = 5
|
||||
BUF_FMT_8_SINT = 6
|
||||
BUF_FMT_16_UNORM = 7
|
||||
BUF_FMT_16_SNORM = 8
|
||||
BUF_FMT_16_USCALED = 9
|
||||
BUF_FMT_16_SSCALED = 10
|
||||
BUF_FMT_16_UINT = 11
|
||||
BUF_FMT_16_SINT = 12
|
||||
BUF_FMT_16_FLOAT = 13
|
||||
BUF_FMT_8_8_UNORM = 14
|
||||
BUF_FMT_8_8_SNORM = 15
|
||||
BUF_FMT_8_8_USCALED = 16
|
||||
BUF_FMT_8_8_SSCALED = 17
|
||||
BUF_FMT_8_8_UINT = 18
|
||||
BUF_FMT_8_8_SINT = 19
|
||||
BUF_FMT_32_UINT = 20
|
||||
BUF_FMT_32_SINT = 21
|
||||
BUF_FMT_32_FLOAT = 22
|
||||
BUF_FMT_16_16_UNORM = 23
|
||||
BUF_FMT_16_16_SNORM = 24
|
||||
BUF_FMT_16_16_USCALED = 25
|
||||
BUF_FMT_16_16_SSCALED = 26
|
||||
BUF_FMT_16_16_UINT = 27
|
||||
BUF_FMT_16_16_SINT = 28
|
||||
BUF_FMT_16_16_FLOAT = 29
|
||||
BUF_FMT_10_11_11_FLOAT = 30
|
||||
BUF_FMT_11_11_10_FLOAT = 31
|
||||
BUF_FMT_10_10_10_2_UNORM = 32
|
||||
BUF_FMT_10_10_10_2_SNORM = 33
|
||||
BUF_FMT_10_10_10_2_UINT = 34
|
||||
BUF_FMT_10_10_10_2_SINT = 35
|
||||
BUF_FMT_2_10_10_10_UNORM = 36
|
||||
BUF_FMT_2_10_10_10_SNORM = 37
|
||||
BUF_FMT_2_10_10_10_USCALED = 38
|
||||
BUF_FMT_2_10_10_10_SSCALED = 39
|
||||
BUF_FMT_2_10_10_10_UINT = 40
|
||||
BUF_FMT_2_10_10_10_SINT = 41
|
||||
BUF_FMT_8_8_8_8_UNORM = 42
|
||||
BUF_FMT_8_8_8_8_SNORM = 43
|
||||
BUF_FMT_8_8_8_8_USCALED = 44
|
||||
BUF_FMT_8_8_8_8_SSCALED = 45
|
||||
BUF_FMT_8_8_8_8_UINT = 46
|
||||
BUF_FMT_8_8_8_8_SINT = 47
|
||||
BUF_FMT_32_32_UINT = 48
|
||||
BUF_FMT_32_32_SINT = 49
|
||||
BUF_FMT_32_32_FLOAT = 50
|
||||
BUF_FMT_16_16_16_16_UNORM = 51
|
||||
BUF_FMT_16_16_16_16_SNORM = 52
|
||||
BUF_FMT_16_16_16_16_USCALED = 53
|
||||
BUF_FMT_16_16_16_16_SSCALED = 54
|
||||
BUF_FMT_16_16_16_16_UINT = 55
|
||||
BUF_FMT_16_16_16_16_SINT = 56
|
||||
BUF_FMT_16_16_16_16_FLOAT = 57
|
||||
BUF_FMT_32_32_32_UINT = 58
|
||||
BUF_FMT_32_32_32_SINT = 59
|
||||
BUF_FMT_32_32_32_FLOAT = 60
|
||||
BUF_FMT_32_32_32_32_UINT = 61
|
||||
BUF_FMT_32_32_32_32_SINT = 62
|
||||
BUF_FMT_32_32_32_32_FLOAT = 63
|
||||
|
||||
@@ -1627,3 +1627,52 @@ class VSCRATCHOp(IntEnum):
|
||||
SCRATCH_STORE_D16_HI_B16 = 37
|
||||
SCRATCH_LOAD_BLOCK = 83
|
||||
SCRATCH_STORE_BLOCK = 84
|
||||
|
||||
class BufFmt(IntEnum):
|
||||
BUF_FMT_8_UNORM = 1
|
||||
BUF_FMT_8_SNORM = 2
|
||||
BUF_FMT_8_USCALED = 3
|
||||
BUF_FMT_8_SSCALED = 4
|
||||
BUF_FMT_8_UINT = 5
|
||||
BUF_FMT_8_SINT = 6
|
||||
BUF_FMT_16_UNORM = 7
|
||||
BUF_FMT_16_SNORM = 8
|
||||
BUF_FMT_16_USCALED = 9
|
||||
BUF_FMT_16_SSCALED = 10
|
||||
BUF_FMT_16_UINT = 11
|
||||
BUF_FMT_16_SINT = 12
|
||||
BUF_FMT_16_FLOAT = 13
|
||||
BUF_FMT_8_8_UNORM = 14
|
||||
BUF_FMT_8_8_SNORM = 15
|
||||
BUF_FMT_8_8_USCALED = 16
|
||||
BUF_FMT_8_8_SSCALED = 17
|
||||
BUF_FMT_8_8_UINT = 18
|
||||
BUF_FMT_8_8_SINT = 19
|
||||
BUF_FMT_32_UINT = 20
|
||||
BUF_FMT_32_SINT = 21
|
||||
BUF_FMT_32_FLOAT = 22
|
||||
BUF_FMT_16_16_UNORM = 23
|
||||
BUF_FMT_10_10_10_2_UNORM = 32
|
||||
BUF_FMT_10_10_10_2_SNORM = 33
|
||||
BUF_FMT_10_10_10_2_UINT = 34
|
||||
BUF_FMT_10_10_10_2_SINT = 35
|
||||
BUF_FMT_2_10_10_10_UNORM = 36
|
||||
BUF_FMT_2_10_10_10_SNORM = 37
|
||||
BUF_FMT_2_10_10_10_USCALED = 38
|
||||
BUF_FMT_2_10_10_10_SSCALED = 39
|
||||
BUF_FMT_2_10_10_10_UINT = 40
|
||||
BUF_FMT_2_10_10_10_SINT = 41
|
||||
BUF_FMT_8_8_8_8_UNORM = 42
|
||||
BUF_FMT_8_8_8_8_SNORM = 43
|
||||
BUF_FMT_8_8_8_8_USCALED = 44
|
||||
BUF_FMT_8_8_8_8_SSCALED = 45
|
||||
BUF_FMT_8_8_8_8_UINT = 46
|
||||
BUF_FMT_8_8_8_8_SINT = 47
|
||||
BUF_FMT_32_32_UINT = 48
|
||||
BUF_FMT_32_32_SINT = 49
|
||||
BUF_FMT_32_32_FLOAT = 50
|
||||
BUF_FMT_16_16_16_16_UNORM = 51
|
||||
BUF_FMT_16_16_16_16_SNORM = 52
|
||||
BUF_FMT_16_16_16_16_USCALED = 53
|
||||
BUF_FMT_16_16_16_16_SSCALED = 54
|
||||
BUF_FMT_16_16_16_16_UINT = 55
|
||||
|
||||
@@ -7,6 +7,7 @@ from functools import cache
|
||||
from typing import overload, Annotated, TypeVar, Generic
|
||||
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
|
||||
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
|
||||
from extra.assembly.amd.autogen.cdna.enum import VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op
|
||||
|
||||
# Common masks and bit conversion functions
|
||||
MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1
|
||||
@@ -46,12 +47,14 @@ def _i64(f):
|
||||
# Instruction spec - register counts and dtypes derived from instruction names
|
||||
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
|
||||
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
|
||||
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
|
||||
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1,
|
||||
'DWORD': 1, 'DWORDX2': 2, 'DWORDX3': 3, 'DWORDX4': 4, 'DWORDX8': 8, 'DWORDX16': 16,
|
||||
'BYTE': 1, 'SHORT': 1, 'UBYTE': 1, 'SBYTE': 1, 'USHORT': 1, 'SSHORT': 1}
|
||||
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
|
||||
_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
|
||||
_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
|
||||
_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
|
||||
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
|
||||
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512)|DWORD(?:X(?:2|3|4|8|16))?|[US]?BYTE|[US]?SHORT)$')
|
||||
@cache
|
||||
def _suffix(name: str) -> tuple[str | None, str | None]:
|
||||
name = name.upper()
|
||||
@@ -242,7 +245,11 @@ def unwrap(val) -> int:
|
||||
FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
|
||||
FLOAT_DEC = {v: str(k) for k, v in FLOAT_ENC.items()}
|
||||
SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"}
|
||||
SPECIAL_GPRS_CDNA = {102: "flat_scratch_lo", 103: "flat_scratch_hi", 104: "xnack_mask_lo", 105: "xnack_mask_hi",
|
||||
106: "vcc_lo", 107: "vcc_hi", 124: "m0", 126: "exec_lo", 127: "exec_hi",
|
||||
251: "src_vccz", 252: "src_execz", 253: "src_scc", 254: "src_lds_direct"}
|
||||
SPECIAL_PAIRS = {106: "vcc", 126: "exec"}
|
||||
SPECIAL_PAIRS_CDNA = {102: "flat_scratch", 104: "xnack_mask", 106: "vcc", 126: "exec"}
|
||||
SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'}
|
||||
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'}
|
||||
|
||||
@@ -259,9 +266,10 @@ def encode_src(val) -> int:
|
||||
if isinstance(val, int): return 128 + val if 0 <= val <= 64 else 192 - val if -16 <= val <= -1 else 255
|
||||
return 255
|
||||
|
||||
def decode_src(val: int) -> str:
|
||||
def decode_src(val: int, cdna: bool = False) -> str:
|
||||
special = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS
|
||||
if val in special: return special[val]
|
||||
if val <= 105: return f"s{val}"
|
||||
if val in SPECIAL_GPRS: return SPECIAL_GPRS[val]
|
||||
if val in FLOAT_DEC: return FLOAT_DEC[val]
|
||||
if 108 <= val <= 123: return f"ttmp{val - 108}"
|
||||
if 128 <= val <= 192: return str(val - 128)
|
||||
@@ -385,14 +393,14 @@ class Inst:
|
||||
if name in SRC_FIELDS: self._encode_src(name, val)
|
||||
elif name in RAW_FIELDS: self._encode_raw(name, val)
|
||||
elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2
|
||||
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4
|
||||
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = _encode_reg(val) // 4
|
||||
elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1
|
||||
self._precompute_fields()
|
||||
|
||||
def _encode_field(self, name: str, val) -> int:
|
||||
if isinstance(val, RawImm): return val.val
|
||||
if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special regs like VCC_LO
|
||||
if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val
|
||||
if name in {'srsrc', 'ssamp'}: return _encode_reg(val) // 4 if isinstance(val, Reg) else val
|
||||
if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val
|
||||
if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val
|
||||
if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val)
|
||||
@@ -506,7 +514,7 @@ class Inst:
|
||||
lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
|
||||
s = f"0x{lit32:x}"
|
||||
else:
|
||||
s = decode_src(v)
|
||||
s = decode_src(v, 'cdna' in self.__class__.__module__)
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def __eq__(self, other):
|
||||
@@ -532,21 +540,30 @@ class Inst:
|
||||
elif hasattr(val, 'name'): self.op = val
|
||||
else:
|
||||
cls_name = self.__class__.__name__
|
||||
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
|
||||
if cls_name == 'VOP3':
|
||||
try:
|
||||
if val < 256: self.op = VOPCOp(val)
|
||||
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
|
||||
else: self.op = VOP3Op(val)
|
||||
except ValueError: self.op = val
|
||||
# Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
|
||||
elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
|
||||
is_cdna = cls_name in ('VOP3A', 'VOP3B')
|
||||
# Try marker enum first (VOP3AOp, VOP3BOp, etc.)
|
||||
marker = self._fields['op'].marker if 'op' in self._fields else None
|
||||
if marker and issubclass(marker, IntEnum):
|
||||
try: self.op = marker(val)
|
||||
except ValueError: self.op = val
|
||||
elif cls_name in self._enum_map:
|
||||
try: self.op = self._enum_map[cls_name](val)
|
||||
except ValueError: self.op = val
|
||||
else: self.op = val
|
||||
# Fallback for promoted instructions when marker lookup failed
|
||||
if not hasattr(self.op, 'name') and cls_name in ('VOP3', 'VOP3A', 'VOP3B') and isinstance(val, int):
|
||||
if val < 256:
|
||||
try: self.op = VOPCOp(val)
|
||||
except ValueError: pass
|
||||
elif is_cdna and 256 <= val < 512:
|
||||
try: self.op = (CDNA_VOP1Op(val - 320) if val >= 320 else CDNA_VOP2Op(val - 256))
|
||||
except ValueError: pass
|
||||
elif val in self._VOP3SD_OPS and not is_cdna:
|
||||
try: self.op = VOP3SDOp(val)
|
||||
except ValueError: pass
|
||||
elif 256 <= val < 512 and not is_cdna:
|
||||
try: self.op = VOP1Op(val - 384) if val >= 384 else VOP2Op(val - 256)
|
||||
except ValueError: pass
|
||||
self.op_name = self.op.name if hasattr(self.op, 'name') else ''
|
||||
self._spec_regs = spec_regs(self.op_name)
|
||||
self._spec_dtype = spec_dtype(self.op_name)
|
||||
|
||||
@@ -291,14 +291,16 @@ def exec_vop(st: WaveState, inst: Inst, V: list, lane: int) -> None:
|
||||
extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {}
|
||||
result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst, **extra_kwargs)
|
||||
|
||||
# Check if this is a VOPC instruction (either standalone VOPC or VOP3 with VOPC opcode)
|
||||
is_vopc = isinstance(inst.op, VOPCOp) or (isinstance(inst, VOP3) and inst.op.value < 256)
|
||||
if 'VCC' in result:
|
||||
if isinstance(inst, VOP3SD): st.pend_sgpr_lane(inst.sdst, lane, (result['VCC'] >> lane) & 1)
|
||||
else: st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC'] >> lane) & 1)
|
||||
if 'EXEC' in result:
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC'] >> lane) & 1)
|
||||
elif isinstance(inst.op, VOPCOp):
|
||||
elif is_vopc:
|
||||
st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1)
|
||||
if not isinstance(inst.op, VOPCOp):
|
||||
if not is_vopc:
|
||||
d0_val = result['D0']
|
||||
if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
|
||||
elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
|
||||
|
||||
@@ -185,8 +185,8 @@ def _parse_single_pdf(url: str):
|
||||
break
|
||||
formats[fmt_name] = fields
|
||||
|
||||
# Fix known PDF errors
|
||||
if 'SMEM' in formats:
|
||||
# Fix known PDF errors (RDNA-specific SMEM bit positions)
|
||||
if 'SMEM' in formats and not is_cdna:
|
||||
formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t)
|
||||
for n, h, l, e, t in formats['SMEM']]
|
||||
# RDNA4: VFLAT/VGLOBAL/VSCRATCH OP field is [20:14] not [20:13] (PDF documentation error)
|
||||
@@ -209,6 +209,11 @@ def _parse_single_pdf(url: str):
|
||||
if 'FLATOp' in enums:
|
||||
for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items():
|
||||
assert k not in enums['FLATOp']; enums['FLATOp'][k] = v
|
||||
# CDNA MTBUF: PDF is missing the FORMAT field (bits[25:19]) which is required for tbuffer_* instructions
|
||||
if is_cdna and 'MTBUF' in formats:
|
||||
field_names = {f[0] for f in formats['MTBUF']}
|
||||
if 'FORMAT' not in field_names:
|
||||
formats['MTBUF'].append(('FORMAT', 25, 19, None, None))
|
||||
# CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding
|
||||
if is_cdna:
|
||||
if 'SDWA' in formats:
|
||||
@@ -229,7 +234,20 @@ def _parse_single_pdf(url: str):
|
||||
snippet = all_text[start:end].strip()
|
||||
if pseudocode := _extract_pseudocode(snippet): raw_pseudocode[(name, opcode)] = pseudocode
|
||||
|
||||
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna}
|
||||
# Extract unified buffer format table (RDNA only, for MTBUF format field)
|
||||
buf_fmt = {}
|
||||
if not is_cdna:
|
||||
for i in range(total_pages):
|
||||
for t in pdf.tables(i):
|
||||
if t and len(t) > 2 and t[0] and '#' in str(t[0][0]) and 'Format' in str(t[0]):
|
||||
for row in t[1:]:
|
||||
for j in range(0, len(row) - 1, 3): # table has 3-column groups: #, Format, (empty)
|
||||
if row[j] and row[j].isdigit() and row[j+1] and re.match(r'^[\d_]+_(UNORM|SNORM|USCALED|SSCALED|UINT|SINT|FLOAT)$', row[j+1]):
|
||||
buf_fmt[int(row[j])] = row[j+1]
|
||||
if buf_fmt: break
|
||||
if buf_fmt: break
|
||||
|
||||
return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna, "buf_fmt": buf_fmt}
|
||||
|
||||
def _extract_pseudocode(text: str) -> str | None:
|
||||
"""Extract pseudocode from an instruction description snippet."""
|
||||
@@ -258,7 +276,7 @@ def _extract_pseudocode(text: str) -> str | None:
|
||||
|
||||
def _merge_results(results: list[dict]) -> dict:
|
||||
"""Merge multiple PDF parse results into a superset."""
|
||||
merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "pseudocode": {}, "is_cdna": False}
|
||||
merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "pseudocode": {}, "is_cdna": False, "buf_fmt": {}}
|
||||
for r in results:
|
||||
merged["doc_names"].append(r["doc_name"])
|
||||
merged["is_cdna"] = merged["is_cdna"] or r["is_cdna"]
|
||||
@@ -279,17 +297,20 @@ def _merge_results(results: list[dict]) -> dict:
|
||||
else: merged["formats"][fmt_name].append(f)
|
||||
for key, pc in r["pseudocode"].items():
|
||||
if key not in merged["pseudocode"]: merged["pseudocode"][key] = pc
|
||||
for val, name in r.get("buf_fmt", {}).items():
|
||||
if val not in merged["buf_fmt"]: merged["buf_fmt"][val] = name
|
||||
return merged
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# CODE GENERATION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _generate_enum_py(enums, src_enum, doc_name) -> str:
|
||||
def _generate_enum_py(enums, src_enum, doc_name, buf_fmt=None) -> str:
|
||||
"""Generate enum.py content (just enums, no dsl.py dependency)."""
|
||||
def enum_lines(name, items): return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
|
||||
lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""]
|
||||
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
|
||||
if buf_fmt: lines += enum_lines("BufFmt", {v: f"BUF_FMT_{n}" for v, n in buf_fmt.items() if 1 <= v <= 63})
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
|
||||
@@ -398,9 +419,10 @@ def generate_arch(arch: str) -> dict:
|
||||
|
||||
# Write enum.py (enums only, no dsl.py dependency)
|
||||
enum_path = base_path / "enum.py"
|
||||
enum_content = _generate_enum_py(merged["enums"], merged["src_enum"], doc_name)
|
||||
enum_content = _generate_enum_py(merged["enums"], merged["src_enum"], doc_name, merged.get("buf_fmt"))
|
||||
enum_path.write_text(enum_content)
|
||||
print(f"Generated {enum_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums")
|
||||
buf_fmt_count = len([v for v in merged.get("buf_fmt", {}) if 1 <= v <= 63])
|
||||
print(f"Generated {enum_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums" + (f" + BufFmt ({buf_fmt_count})" if buf_fmt_count else ""))
|
||||
|
||||
# Write ins.py (instruction formats and helpers, imports dsl.py and enum.py)
|
||||
ins_path = base_path / "ins.py"
|
||||
|
||||
@@ -1,192 +1,102 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess
|
||||
"""Test AMD assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess, functools
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.asm import asm
|
||||
from extra.assembly.amd.asm import asm, disasm, detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
# Format info: (filename, format_class, op_enum)
|
||||
LLVM_TEST_FILES = {
|
||||
# Scalar ALU
|
||||
'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op),
|
||||
'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op),
|
||||
'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp),
|
||||
'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp),
|
||||
'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp),
|
||||
# Vector ALU
|
||||
'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op),
|
||||
'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op),
|
||||
'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp),
|
||||
'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op),
|
||||
'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp),
|
||||
'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3
|
||||
'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp),
|
||||
'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp),
|
||||
'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format
|
||||
# VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding)
|
||||
'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op),
|
||||
'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op),
|
||||
'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op),
|
||||
'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op),
|
||||
# Memory
|
||||
'ds': ('gfx11_asm_ds.s', DS, DSOp),
|
||||
'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp),
|
||||
'flat': ('gfx11_asm_flat.s', FLAT, FLATOp),
|
||||
'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp),
|
||||
'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp),
|
||||
'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp),
|
||||
# WMMA (matrix multiply)
|
||||
'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp),
|
||||
# Additional features
|
||||
'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op),
|
||||
'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp),
|
||||
'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp),
|
||||
# Alias files (alternative mnemonics)
|
||||
'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op),
|
||||
'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp),
|
||||
'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp),
|
||||
'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp),
|
||||
'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp),
|
||||
'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp),
|
||||
'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp),
|
||||
'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp),
|
||||
}
|
||||
RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s',
|
||||
'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s',
|
||||
'gfx11_asm_vopd.s', 'gfx11_asm_vopcx.s', 'gfx11_asm_vop3_from_vop1.s', 'gfx11_asm_vop3_from_vop2.s', 'gfx11_asm_vop3_from_vopc.s',
|
||||
'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', 'gfx11_asm_mubuf.s', 'gfx11_asm_mtbuf.s',
|
||||
'gfx11_asm_mimg.s', 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s',
|
||||
'gfx11_asm_vop3_alias.s', 'gfx11_asm_vop3p_alias.s', 'gfx11_asm_vopc_alias.s', 'gfx11_asm_vopcx_alias.s', 'gfx11_asm_vinterp_alias.s',
|
||||
'gfx11_asm_smem_alias.s', 'gfx11_asm_mubuf_alias.s', 'gfx11_asm_mtbuf_alias.s']
|
||||
# CDNA test files - includes gfx9 files for shared instructions, plus gfx90a/gfx942 specific files
|
||||
# gfx90a_ldst_acc.s has MIMG mixed in, filtered via is_mimg check
|
||||
CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm_sopk.s', 'gfx9_asm_sopc.s',
|
||||
'gfx9_asm_vop1.s', 'gfx9_asm_vop2.s', 'gfx9_asm_vopc.s', 'gfx9_asm_vop3.s', 'gfx9_asm_vop3p.s',
|
||||
'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', 'gfx9_asm_mubuf.s', 'gfx9_asm_mtbuf.s',
|
||||
'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s',
|
||||
'mai-gfx90a.s', 'mai-gfx942.s']
|
||||
|
||||
def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
|
||||
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
|
||||
tests, lines = [], text.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('//', '.', ';')): continue
|
||||
asm_text = line.split('//')[0].strip()
|
||||
if not asm_text: continue
|
||||
for j in range(i, min(i + 3, len(lines))):
|
||||
# Match GFX11, W32, or W64 encodings (all valid for gfx11)
|
||||
# Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
|
||||
# Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
|
||||
if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else:
|
||||
continue
|
||||
if hex_bytes:
|
||||
try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
|
||||
except ValueError: pass
|
||||
break
|
||||
def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100
|
||||
|
||||
def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]:
|
||||
tests = []
|
||||
for block in text.split('\n\n'):
|
||||
asm_text, encoding = None, None
|
||||
for line in block.split('\n'):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('.', ';')): continue
|
||||
if not line.startswith('//'):
|
||||
asm_text = line.split('//')[0].strip() or asm_text
|
||||
if m := re.search(pattern + r'[^:]*:.*?(?:encoding:\s*)?\[(0x[0-9a-f,x\s]+)\]', line, re.I):
|
||||
encoding = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
if asm_text and encoding:
|
||||
try: tests.append((asm_text, bytes.fromhex(encoding)))
|
||||
except ValueError: pass
|
||||
return tests
|
||||
|
||||
def try_assemble(text: str):
|
||||
"""Try to assemble instruction text, return bytes or None on failure."""
|
||||
try: return asm(text).to_bytes()
|
||||
except: return None
|
||||
@functools.cache
|
||||
def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]:
|
||||
text = fetch(f"{LLVM_BASE}/{f}").read_bytes().decode('utf-8', errors='ignore')
|
||||
if arch == "rdna3":
|
||||
tests = _parse_llvm_tests(text, r'(?:GFX11|W32|W64)')
|
||||
elif 'gfx90a' in f or 'gfx942' in f:
|
||||
tests = _parse_llvm_tests(text, r'(?:GFX90A|GFX942)')
|
||||
else:
|
||||
tests = _parse_llvm_tests(text, r'(?:VI9|GFX9|CHECK)')
|
||||
return [(a, d) for a, d in tests if not _is_mimg(d)] if arch == "cdna" else tests
|
||||
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
def _compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
if not instrs: return []
|
||||
asm_text = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=asm_text, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
# Parse all encodings from output
|
||||
results = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' not in line: continue
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
results.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
|
||||
if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}")
|
||||
return results
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
return [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))
|
||||
for line in result.stdout.split('\n') if 'encoding:' in line]
|
||||
|
||||
class TestLLVM(unittest.TestCase):
|
||||
"""Test assembler and disassembler against all LLVM test vectors."""
|
||||
tests: dict[str, list[tuple[str, bytes]]] = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
for name, (filename, _, _) in LLVM_TEST_FILES.items():
|
||||
try:
|
||||
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
|
||||
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'))
|
||||
except Exception as e:
|
||||
print(f"Warning: couldn't fetch {filename}: {e}")
|
||||
cls.tests[name] = []
|
||||
|
||||
# Generate test methods dynamically for each format
|
||||
def _make_asm_test(name):
|
||||
def _make_test(f: str, arch: str, test_type: str):
|
||||
def test(self):
|
||||
passed, failed, skipped = 0, 0, 0
|
||||
for asm_text, expected in self.tests.get(name, []):
|
||||
result = try_assemble(asm_text)
|
||||
if result is None: skipped += 1
|
||||
elif result == expected: passed += 1
|
||||
else: failed += 1
|
||||
print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped")
|
||||
self.assertEqual(failed, 0)
|
||||
tests = _get_tests(f, arch)
|
||||
name = f"{arch}_{test_type}_{f}"
|
||||
if test_type == "roundtrip":
|
||||
for _, data in tests:
|
||||
decoded = detect_format(data, arch).from_bytes(data)
|
||||
self.assertEqual(decoded.to_bytes()[:len(data)], data)
|
||||
print(f"{name}: {len(tests)} passed")
|
||||
elif test_type == "asm":
|
||||
passed, skipped = 0, 0
|
||||
for asm_text, expected in tests:
|
||||
try:
|
||||
self.assertEqual(asm(asm_text).to_bytes(), expected)
|
||||
passed += 1
|
||||
except: skipped += 1
|
||||
print(f"{name}: {passed} passed, {skipped} skipped")
|
||||
elif test_type == "disasm":
|
||||
to_test = []
|
||||
for _, data in tests:
|
||||
try:
|
||||
decoded = detect_format(data, arch).from_bytes(data)
|
||||
if decoded.to_bytes()[:len(data)] == data and (d := disasm(decoded)): to_test.append((data, d))
|
||||
except: pass
|
||||
print(f"{name}: {len(to_test)} passed, {len(tests) - len(to_test)} skipped")
|
||||
if arch == "rdna3":
|
||||
for (data, _), llvm in zip(to_test, _compile_asm_batch([t[1] for t in to_test])): self.assertEqual(llvm, data)
|
||||
return test
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name]
|
||||
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
|
||||
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
|
||||
class TestLLVM(unittest.TestCase): pass
|
||||
|
||||
# First pass: decode all instructions and collect disasm strings
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
# Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword
|
||||
is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35
|
||||
fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum)
|
||||
try:
|
||||
if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
temp = VOP3.from_bytes(data)
|
||||
op_val = temp._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
|
||||
decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
|
||||
if is_vop3sd: VOP3SDOp(op_val)
|
||||
else: VOP3Op(op_val)
|
||||
else:
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
op_val = decoded._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
op_enum(op_val)
|
||||
if decoded.to_bytes()[:len(data)] != data:
|
||||
to_test.append((asm_text, data, None, "decode roundtrip failed"))
|
||||
continue
|
||||
to_test.append((asm_text, data, decoded.disasm(), None))
|
||||
except Exception as e:
|
||||
to_test.append((asm_text, data, None, f"exception: {e}"))
|
||||
|
||||
# Batch compile all disasm strings with single llvm-mc call
|
||||
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
|
||||
llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
|
||||
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
|
||||
|
||||
# Match results back
|
||||
passed, failed = 0, 0
|
||||
failures: list[str] = []
|
||||
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
|
||||
if error:
|
||||
failed += 1; failures.append(f"{error} for {data.hex()}")
|
||||
elif disasm_str is not None and idx in llvm_map:
|
||||
llvm_bytes = llvm_map[idx]
|
||||
if llvm_bytes is not None and llvm_bytes == data: passed += 1
|
||||
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:10]: print(" " + "\n ".join(failures[:10]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
for name in LLVM_TEST_FILES:
|
||||
setattr(TestLLVM, f'test_{name}_asm', _make_asm_test(name))
|
||||
setattr(TestLLVM, f'test_{name}_disasm', _make_disasm_test(name))
|
||||
for f in RDNA_FILES:
|
||||
setattr(TestLLVM, f"test_rdna3_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "roundtrip"))
|
||||
setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm"))
|
||||
setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm"))
|
||||
for f in CDNA_FILES:
|
||||
setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "roundtrip"))
|
||||
setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "disasm"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test CDNA assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.autogen.cdna.ins import *
|
||||
from extra.assembly.amd.asm import disasm
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]:
|
||||
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
|
||||
tests, lines = [], text.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('//', '.', ';')): continue
|
||||
asm_text = line.split('//')[0].strip()
|
||||
if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue
|
||||
for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))):
|
||||
if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else: continue
|
||||
try:
|
||||
data = bytes.fromhex(hex_bytes)
|
||||
if size_filter is None or len(data) == size_filter: tests.append((asm_text, data))
|
||||
except ValueError: pass
|
||||
break
|
||||
return tests
|
||||
|
||||
# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions
|
||||
# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter)
|
||||
CDNA_TEST_FILES = {
|
||||
# Scalar ALU - encoding is stable across GFX9/CDNA
|
||||
'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None),
|
||||
'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None),
|
||||
'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None),
|
||||
'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None),
|
||||
'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None),
|
||||
'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None),
|
||||
# Vector ALU - encoding is mostly stable
|
||||
'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None),
|
||||
'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None),
|
||||
'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None),
|
||||
'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None),
|
||||
'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None),
|
||||
'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions
|
||||
# Memory instructions
|
||||
'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None),
|
||||
'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None),
|
||||
# CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers)
|
||||
'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None),
|
||||
'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None),
|
||||
'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None),
|
||||
'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None),
|
||||
'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None),
|
||||
# CDNA-specific: MFMA/MAI instructions
|
||||
'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None),
|
||||
# SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately)
|
||||
'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None),
|
||||
'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None),
|
||||
}
|
||||
|
||||
class TestLLVMCDNA(unittest.TestCase):
|
||||
"""Test CDNA instruction format decode/encode roundtrip and disassembly."""
|
||||
tests: dict[str, list[tuple[str, bytes]]] = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items():
|
||||
try:
|
||||
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
|
||||
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter)
|
||||
except Exception as e:
|
||||
print(f"Warning: couldn't fetch {filename}: {e}")
|
||||
cls.tests[name] = []
|
||||
|
||||
def _get_val(v): return v.val if hasattr(v, 'val') else v
|
||||
|
||||
def _filter_and_decode(tests, fmt_cls, op_enum):
|
||||
"""Filter tests and decode instructions, yielding (asm_text, data, decoded, error)."""
|
||||
fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP'
|
||||
for asm_text, data in tests:
|
||||
has_lit = False
|
||||
# SDWA/DPP format tests: only accept matching 8-byte instructions
|
||||
if is_sdwa:
|
||||
if len(data) != 8 or data[0] != 0xf9: continue
|
||||
elif is_dpp:
|
||||
if len(data) != 8 or data[0] != 0xfa: continue
|
||||
elif fmt_cls._size() == 4 and len(data) == 8:
|
||||
if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately)
|
||||
has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC'))
|
||||
if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20
|
||||
if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37)
|
||||
if not has_lit: continue
|
||||
if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue
|
||||
try:
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
# For SDWA/DPP, opcode location depends on VOP1 vs VOP2
|
||||
if is_sdwa or is_dpp:
|
||||
vop2_op = _get_val(decoded._values.get('vop2_op', 0))
|
||||
op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op
|
||||
else:
|
||||
op_val = _get_val(decoded._values.get('op', 0))
|
||||
try: op_enum(op_val)
|
||||
except ValueError: continue
|
||||
yield asm_text, data, decoded, None
|
||||
except Exception as e:
|
||||
yield asm_text, data, None, str(e)
|
||||
|
||||
def _make_roundtrip_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
|
||||
passed, failed, failures = 0, 0, []
|
||||
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
|
||||
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
|
||||
if decoded.to_bytes()[:len(data)] == data: passed += 1
|
||||
else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}")
|
||||
print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed")
|
||||
if failures[:5]: print(" " + "\n ".join(failures[:5]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
|
||||
passed, failed, failures = 0, 0, []
|
||||
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
|
||||
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
|
||||
if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue
|
||||
if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue
|
||||
passed += 1
|
||||
print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:5]: print(" " + "\n ".join(failures[:5]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
for name in CDNA_TEST_FILES:
|
||||
setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name))
|
||||
setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user