diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index de2e2da6c8..17b71d568a 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -3,10 +3,11 @@ from __future__ import annotations import re from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF -from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src +from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src from extra.assembly.amd.autogen.rdna3 import ins from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, - VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp) + VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp) +from extra.assembly.amd.autogen.rdna3.enum import BufFmt def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__ @@ -17,21 +18,37 @@ def _matches_encoding(word: int, cls: type[Inst]) -> bool: return ((word >> bf.lo) & bf.mask()) == val # Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0) -_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP] -_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls +_RDNA_FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP] +_RDNA_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls +from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP2 as C_VOP2, VOPC as C_VOPC, VOP3A, VOP3B, VOP3P as C_VOP3P, + SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS, + FLAT as C_FLAT, MUBUF as C_MUBUF, MTBUF as C_MTBUF, SDWA, DPP) +_CDNA_FORMATS_64 = [C_VOP3P, VOP3A, C_DS, C_FLAT, C_MUBUF, C_MTBUF, C_SMEM] +_CDNA_FORMATS_32 = [SDWA, DPP, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_VOPC, C_VOP1, C_SOP2, C_VOP2] +_CDNA_VOP3B_OPS = {281, 282, 283, 284, 285, 286, 480, 481, 488, 489} # VOP3B opcodes +# CDNA opcode name aliases for disasm (new name -> old name expected by tests) +_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'} -def detect_format(data: bytes) -> type[Inst]: +def detect_format(data: bytes, arch: str = "rdna3") -> type[Inst]: """Detect instruction format from machine code bytes.""" assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}" word = int.from_bytes(data[:4], 'little') - # Check 64-bit formats first (bits[31:30] == 0b11) + if arch == "cdna": + if (word >> 30) == 0b11: + for cls in _CDNA_FORMATS_64: + if _matches_encoding(word, cls): + return VOP3B if cls is VOP3A and ((word >> 16) & 0x3ff) in _CDNA_VOP3B_OPS else cls + raise ValueError(f"unknown CDNA 64-bit format word={word:#010x}") + for cls in _CDNA_FORMATS_32: + if _matches_encoding(word, cls): return cls + raise ValueError(f"unknown CDNA 32-bit format word={word:#010x}") + # RDNA (default) if (word >> 30) == 0b11: - for cls in _FORMATS_64: + for cls in _RDNA_FORMATS_64: if _matches_encoding(word, cls): return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls raise ValueError(f"unknown 64-bit format word={word:#010x}") - # 32-bit formats - for cls in _FORMATS_32: + for cls in _RDNA_FORMATS_32: if _matches_encoding(word, cls): return cls raise ValueError(f"unknown 32-bit format word={word:#010x}") @@ -44,6 +61,11 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H 19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'} HWREG_IDS = {v.lower(): k for k, v in HWREG.items()} +# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup +BUF_FMT = {e.name: e.value for e in BufFmt} +def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y] + parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')] + return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}') if len(parts) == 2 else None MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA', 131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'} @@ -54,22 +76,28 @@ MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_T def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]" def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n) def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n) +def _areg(b: int, n: int = 1) -> str: return _reg("a", b, n) # accumulator registers for GFX90a def _ttmp(b: int, n: int = 1) -> str: return _reg("ttmp", b - 108, n) if 108 <= b <= 123 else None def _sreg_or_ttmp(b: int, n: int = 1) -> str: return _ttmp(b, n) or _sreg(b, n) -def _fmt_sdst(v: int, n: int = 1) -> str: - if v == 124: return "null" +def _fmt_sdst(v: int, n: int = 1, cdna: bool = False) -> str: + from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA, SPECIAL_GPRS_CDNA if t := _ttmp(v, n): return t - if n > 1: return SPECIAL_PAIRS.get(v) or _sreg(v, n) - return SPECIAL_GPRS.get(v, f"s{v}") + pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS + gprs = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS + if n > 1: return pairs.get(v) or gprs.get(v) or _sreg(v, n) # also check gprs for null/m0 + return gprs.get(v, f"s{v}") -def _fmt_src(v: int, n: int = 1) -> str: - if n == 1: return decode_src(v) +def _fmt_src(v: int, n: int = 1, cdna: bool = False) -> str: + from extra.assembly.amd.dsl import SPECIAL_PAIRS_CDNA + if n == 1: return decode_src(v, cdna) if v >= 256: return _vreg(v - 256, n) - if v <= 105: return _sreg(v, n) - if n == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v] + if v <= 101: return _sreg(v, n) # s0-s101 can be pairs, but 102+ are special on CDNA + pairs = SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS + if n == 2 and v in pairs: return pairs[v] + if v <= 105: return _sreg(v, n) # s102-s105 regular pairs for RDNA if t := _ttmp(v, n): return t - return decode_src(v) + return decode_src(v, cdna) def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str: return f"v{(v - base) & 0x7f}.{'h' if v >= hi_thresh else 'l'}" @@ -106,46 +134,72 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str: # ═══════════════════════════════════════════════════════════════════════════════ def _disasm_vop1(inst: VOP1) -> str: - name = inst.op_name.lower() - if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name - if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" - # 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding) + name, cdna = inst.op_name.lower() or f'vop1_op_{inst.op}', _is_cdna(inst) + suf = "" if cdna else "_e32" + if name in ('v_nop', 'v_pipeflush', 'v_clrexcp'): return name # no operands + if 'readfirstlane' in name: + src = f"v{inst.src0 - 256}" if inst.src0 >= 256 else decode_src(inst.src0, cdna) + return f"{name} {_fmt_sdst(inst.vdst, 1, cdna)}, {src}" + # 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN) parts = name.split('_') - is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name) + is_16d = not cdna and (any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" - src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) - return f"{name}_e32 {dst}, {src}" + src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) + return f"{name}{suf} {dst}, {src}" +_VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only +_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out def _disasm_vop2(inst: VOP2) -> str: name, cdna = inst.op_name.lower(), _is_cdna(inst) - suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32" + if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases + suf = "" if cdna or (not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16) else "_e32" lit = getattr(inst, '_literal', None) is16 = not cdna and inst.is_16bit() - # fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1 - if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)): + # fmaak/madak: dst = src0 * vsrc1 + K, fmamk/madmk: dst = src0 * K + vsrc1 + if 'fmaak' in name or 'madak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)): if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}" return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}" - if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)): + if 'fmamk' in name or 'madmk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)): if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}" return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}" if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" vcc = "vcc" if cdna else "vcc_lo" + # CDNA carry ops output vcc after vdst + if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}" return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "") def _disasm_vopc(inst: VOPC) -> str: name, cdna = inst.op_name.lower(), _is_cdna(inst) if cdna: - s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) - return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}" + s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) + s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else f"v{inst.vsrc1}" + return f"{name} vcc, {s0}, {s1}" # CDNA VOPC always outputs vcc + # RDNA: v_cmpx_* writes to exec (no vcc), v_cmp_* writes to vcc_lo + has_vcc = 'cmpx' not in name s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0) s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}" - return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" + return f"{name}_e32 vcc_lo, {s0}, {s1}" if has_vcc else f"{name}_e32 {s0}, {s1}" NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA} +# CDNA uses name-based matching since opcode values differ from RDNA +_CDNA_NO_ARG_SOPP = {'s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_nop', 's_sethalt', 's_sleep', + 's_setprio', 's_trap', 's_incperflevel', 's_decperflevel', 's_sendmsg', 's_sendmsghalt'} def _disasm_sopp(inst: SOPP) -> str: - name = inst.op_name.lower() + name, cdna = inst.op_name.lower(), _is_cdna(inst) + if cdna: + # CDNA: use name-based matching + if name == 's_endpgm': return name if inst.simm16 == 0 else f"{name} {inst.simm16}" + if name in ('s_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata'): return name + if name == 's_waitcnt': + vm, lgkm, exp = inst.simm16 & 0xf, (inst.simm16 >> 8) & 0x3f, (inst.simm16 >> 4) & 0x7 + p = [f"vmcnt({vm})" if vm != 0xf else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""] + return f"s_waitcnt {' '.join(x for x in p if x) or '0'}" + if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}" + return f"{name} 0x{inst.simm16:x}" if inst.simm16 else name + # RDNA if inst.op in NO_ARG_SOPP: return name if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}" if inst.op == SOPPOp.S_WAITCNT: @@ -161,64 +215,98 @@ def _disasm_sopp(inst: SOPP) -> str: return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}" def _disasm_smem(inst: SMEM) -> str: - name = inst.op_name.lower() + name, cdna = inst.op_name.lower(), _is_cdna(inst) if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name - off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset) - sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2 - sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) + # GFX9 SMEM: soe and imm bits determine offset interpretation + # soe=1, imm=1: soffset is SGPR, offset is immediate (both used) + # soe=0, imm=1: offset is immediate + # soe=0, imm=0: offset field is SGPR encoding (0-255) + soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1) + if cdna: + if soe and imm: + off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" # SGPR + immediate + elif imm: + off_s = f"0x{inst.offset:x}" # Immediate offset only + elif inst.offset < 256: + off_s = decode_src(inst.offset, cdna) # SGPR encoding in offset field + else: + off_s = decode_src(inst.soffset, cdna) + elif inst.offset and inst.soffset != 124: + off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" + elif inst.offset: + off_s = f"0x{inst.offset:x}" + else: + off_s = decode_src(inst.soffset, cdna) + op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op + # s_buffer_* instructions use 4 SGPRs for sbase (buffer descriptor) + is_buffer = 'buffer' in name or 's_atc_probe_buffer' == name + sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2 + sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}" - return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) + return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc")) def _disasm_flat(inst: FLAT) -> str: name, cdna = inst.op_name.lower(), _is_cdna(inst) + acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag + reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0 seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192) - w = inst.dst_regs() * (2 if 'cmpswap' in name else 1) - if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}" - else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" + w = inst.dst_regs() * (2 if '_x2' in name else 1) * (2 if 'cmpswap' in name else 1) + off_s = f" offset:{off_val}" if off_val else "" # Omit offset:0 + if cdna: mods = f"{off_s}{' glc' if inst.sc0 else ''}{' slc' if inst.nt else ''}" # GFX9: sc0->glc, nt->slc + else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" # saddr if seg == 'flat' or inst.saddr == 0x7F: saddr_s = "" elif inst.saddr == 124: saddr_s = ", off" - elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}" - elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}" + elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr, cdna)}" + elif inst.saddr in (SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS): saddr_s = f", {(SPECIAL_PAIRS_CDNA if cdna else SPECIAL_PAIRS)[inst.saddr]}" elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}" - else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}" + else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr, cdna)}" # addtid: no addr - if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}" - # addr width - addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2) - data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w) + if 'addtid' in name: return f"{instr} {'a' if acc else 'v'}{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}" + # addr width: CDNA flat always uses 2 VGPRs (64-bit), scratch uses 1, RDNA uses 2 only when no saddr + if cdna: + addr_w = 1 if seg == 'scratch' else 2 # CDNA: flat/global always 64-bit addr + else: + addr_w = 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2 + addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w) + data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w) glc_or_sc0 = inst.sc0 if cdna else inst.glc if 'atomic' in name: return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}" - return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" + return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" def _disasm_ds(inst: DS) -> str: op, name = inst.op, inst.op_name.lower() + acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag + reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0 + rp = 'a' if acc else 'v' # register prefix for single regs gds = " gds" if inst.gds else "" off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else "" - off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else "" + off2 = (" offset0:" + str(inst.offset0) if inst.offset0 else "") + (" offset1:" + str(inst.offset1) if inst.offset1 else "") w = inst.dst_regs() - d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}" + d0, d1, dst, addr = reg_fn(inst.data0, w), reg_fn(inst.data1, w), reg_fn(inst.vdst, w), f"v{inst.addr}" if op == DSOp.DS_NOP: return name if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}" if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}" if 'gws_' in name: return f"{name} {addr}{off}{gds}" - if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}" - if 'gs_reg' in name: return f"{name} {_vreg(inst.vdst, 2)}, v{inst.data0}{off}{gds}" + if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{off}{gds}" + if 'gs_reg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {rp}{inst.data0}{off}{gds}" if '2addr' in name: - if 'load' in name: return f"{name} {_vreg(inst.vdst, w*2)}, {addr}{off2}{gds}" + if 'load' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}" if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}" - return f"{name} {_vreg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}" - if 'load' in name: return f"{name} v{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}" + return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}" + if 'write2' in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}" + if 'read2' in name: return f"{name} {reg_fn(inst.vdst, w*2)}, {addr}{off2}{gds}" + if 'load' in name: return f"{name} {rp}{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}" if 'store' in name and not _has(name, 'cmp', 'xchg'): - return f"{name} v{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}" - if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} v{inst.vdst}, {addr}{off}{gds}" - if 'permute' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}{off}{gds}" - if 'condxchg' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, {_vreg(inst.data0, 2)}{off}{gds}" + return f"{name} {rp}{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}" + if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} {rp}{inst.vdst}, {addr}{off}{gds}" + if 'permute' in name: return f"{name} {rp}{inst.vdst}, {addr}, {rp}{inst.data0}{off}{gds}" + if 'condxchg' in name: return f"{name} {reg_fn(inst.vdst, 2)}, {addr}, {reg_fn(inst.data0, 2)}{off}{gds}" if _has(name, 'cmpstore', 'mskor', 'wrap'): return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}" return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}" @@ -318,6 +406,8 @@ def _disasm_vop3p(inst: VOP3P) -> str: def _disasm_buf(inst: MUBUF | MTBUF) -> str: name, cdna = inst.op_name.lower(), _is_cdna(inst) + acc = getattr(inst, 'acc', 0) # GFX90a accumulator register flag + reg_fn = _areg if acc else _vreg # use a[n] for acc=1, v[n] for acc=0 if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ @@ -326,9 +416,27 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str: if hasattr(inst, 'tfe') and inst.tfe: w += 1 vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off" srsrc = _sreg_or_ttmp(inst.srsrc*4, 4) - if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c] - else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] - return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}" + is_mtbuf = isinstance(inst, MTBUF) or isinstance(inst, C_MTBUF) + if is_mtbuf: + dfmt, nfmt = inst.format & 0xf, (inst.format >> 4) & 0x7 + if acc: # GFX90a accumulator style: show dfmt/nfmt as numbers + fmt_s = f" dfmt:{dfmt}, nfmt:{nfmt}," # double space before dfmt per LLVM format + elif not cdna: # RDNA style: show combined format number + fmt_s = f" format:{inst.format}" if inst.format else "" + else: # CDNA: show format:[BUF_DATA_FORMAT_X] or format:[BUF_NUM_FORMAT_X] + dfmt_names = ['INVALID', '8', '16', '8_8', '32', '16_16', '10_11_11', '11_11_10', '10_10_10_2', '2_10_10_10', '8_8_8_8', '32_32', '16_16_16_16', '32_32_32', '32_32_32_32', 'RESERVED_15'] + nfmt_names = ['UNORM', 'SNORM', 'USCALED', 'SSCALED', 'UINT', 'SINT', 'RESERVED_6', 'FLOAT'] + if dfmt == 1 and nfmt == 0: fmt_s = "" # default, no format shown + elif nfmt == 0: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]}]" # only dfmt differs + elif dfmt == 1: fmt_s = f" format:[BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # only nfmt differs + else: fmt_s = f" format:[BUF_DATA_FORMAT_{dfmt_names[dfmt]},BUF_NUM_FORMAT_{nfmt_names[nfmt]}]" # both differ + else: + fmt_s = "" + if cdna: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"glc"),(inst.nt,"slc"),(inst.sc1,"sc1")] if c] + else: mods = [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] + soffset_s = decode_src(inst.soffset, cdna) + if cdna and not acc and is_mtbuf: return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc}, {soffset_s}{fmt_s}{' ' + ' '.join(mods) if mods else ''}" + return f"{name} {reg_fn(inst.vdata, w)}, {vaddr}, {srsrc},{fmt_s} {soffset_s}{' ' + ' '.join(mods) if mods else ''}" def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: """Calculate vaddr register count for MIMG sample/gather operations.""" @@ -377,21 +485,23 @@ def _disasm_mimg(inst: MIMG) -> str: return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}" def _disasm_sop1(inst: SOP1) -> str: - op, name = inst.op, inst.op_name.lower() - src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0)) - if not _is_cdna(inst): + op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst) + src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna) + if not cdna: if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}" if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}" if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" - return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}" + return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {src}" def _disasm_sop2(inst: SOP2) -> str: - return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}" + cdna = _is_cdna(inst) + return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)}" def _disasm_sopc(inst: SOPC) -> str: - s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0)) - s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1)) + cdna = _is_cdna(inst) + s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna) + s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna) return f"{inst.op_name.lower()} {s0}, {s1}" def _disasm_sopk(inst: SOPK) -> str: @@ -405,10 +515,10 @@ def _disasm_sopk(inst: SOPK) -> str: if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')): hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" - return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" + return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}" if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END): return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}" - return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}" + return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, 0x{inst.simm16:x}" def _disasm_vinterp(inst: VINTERP) -> str: mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp")) @@ -464,11 +574,54 @@ def _parse_ops(s: str) -> list[str]: return ops def _extract(text: str, pat: str, flags=re.I): - if m := re.search(pat, text, flags): return m, text[:m.start()] + text[m.end():] + if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():] return None, text +# Instruction aliases: LLVM uses different names for some instructions +_ALIASES = { + 'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64', + 'v_cmpx_tru_f16': 'v_cmpx_t_f16', 'v_cmpx_tru_f32': 'v_cmpx_t_f32', 'v_cmpx_tru_f64': 'v_cmpx_t_f64', + 'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32', + 'v_ffbh_i32': 'v_cls_i32', 'v_ffbh_u32': 'v_clz_i32_u32', 'v_ffbl_b32': 'v_ctz_i32_b32', + 'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_fmac_legacy_f32': 'v_fmac_dx9_zero_f32', 'v_mul_legacy_f32': 'v_mul_dx9_zero_f32', + # SMEM aliases (dword -> b32, dwordx2 -> b64, etc.) + 's_load_dword': 's_load_b32', 's_load_dwordx2': 's_load_b64', 's_load_dwordx4': 's_load_b128', + 's_load_dwordx8': 's_load_b256', 's_load_dwordx16': 's_load_b512', + 's_buffer_load_dword': 's_buffer_load_b32', 's_buffer_load_dwordx2': 's_buffer_load_b64', + 's_buffer_load_dwordx4': 's_buffer_load_b128', 's_buffer_load_dwordx8': 's_buffer_load_b256', + 's_buffer_load_dwordx16': 's_buffer_load_b512', + # VOP3 aliases + 'v_cvt_pknorm_i16_f16': 'v_cvt_pk_norm_i16_f16', 'v_cvt_pknorm_u16_f16': 'v_cvt_pk_norm_u16_f16', + 'v_add3_nc_u32': 'v_add3_u32', 'v_xor_add_u32': 'v_xad_u32', + # VINTERP aliases + 'v_interp_p2_new_f32': 'v_interp_p2_f32', + # SOP1 aliases + 's_ff1_i32_b32': 's_ctz_i32_b32', 's_ff1_i32_b64': 's_ctz_i32_b64', + 's_flbit_i32_b32': 's_clz_i32_u32', 's_flbit_i32_b64': 's_clz_i32_u64', 's_flbit_i32': 's_cls_i32', 's_flbit_i32_i64': 's_cls_i32_i64', + 's_andn1_saveexec_b32': 's_and_not0_saveexec_b32', 's_andn1_saveexec_b64': 's_and_not0_saveexec_b64', + 's_andn1_wrexec_b32': 's_and_not0_wrexec_b32', 's_andn1_wrexec_b64': 's_and_not0_wrexec_b64', + 's_andn2_saveexec_b32': 's_and_not1_saveexec_b32', 's_andn2_saveexec_b64': 's_and_not1_saveexec_b64', + 's_andn2_wrexec_b32': 's_and_not1_wrexec_b32', 's_andn2_wrexec_b64': 's_and_not1_wrexec_b64', + 's_orn1_saveexec_b32': 's_or_not0_saveexec_b32', 's_orn1_saveexec_b64': 's_or_not0_saveexec_b64', + 's_orn2_saveexec_b32': 's_or_not1_saveexec_b32', 's_orn2_saveexec_b64': 's_or_not1_saveexec_b64', + # SOP2 aliases + 's_andn2_b32': 's_and_not1_b32', 's_andn2_b64': 's_and_not1_b64', + 's_orn2_b32': 's_or_not1_b32', 's_orn2_b64': 's_or_not1_b64', + # VOP2 aliases + 'v_dot2c_f32_f16': 'v_dot2acc_f32_f16', + # More VOP3 aliases + 'v_fma_legacy_f32': 'v_fma_dx9_zero_f32', +} + +def _apply_alias(text: str) -> str: + mn = text.split()[0].lower() if ' ' in text else text.lower().rstrip('_') + # Try exact match first, then strip _e32/_e64 suffix + for m in (mn, mn.removesuffix('_e32'), mn.removesuffix('_e64')): + if m in _ALIASES: return _ALIASES[m] + text[len(m):] + return text + def get_dsl(text: str) -> str: - text, kw = text.strip(), [] + text, kw = _apply_alias(text.strip()), [] # Extract modifiers for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]: if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break @@ -484,6 +637,11 @@ def get_dsl(text: str) -> str: m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None + m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None + m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None + m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None + m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None + m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None if waitexp: kw.append(f'waitexp={waitexp}') @@ -530,9 +688,30 @@ def get_dsl(text: str) -> str: if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})" if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})" - # Buffer - if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off': - return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})" + # Buffer (MUBUF/MTBUF) instructions + if mn.startswith(('buffer_', 'tbuffer_')): + is_tbuf = mn.startswith('tbuffer_') + # Parse format value for tbuffer + fmt_num = None + if fmt_val is not None: + if fmt_val.isdigit(): fmt_num = int(fmt_val) + else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val) + # Handle special no-arg buffer ops + if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()" + # Build modifiers string + buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "", + ", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""]) + if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods + # Determine vaddr value (v[0] for 'off', actual register otherwise) + vaddr_idx = 1 + if len(ops) > vaddr_idx and ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]" + else: vaddr_val = args[vaddr_idx] if len(args) > vaddr_idx else "v[0]" + # srsrc and soffset indices depend on whether vaddr is 'off' + srsrc_idx, soff_idx = (2, 3) if len(ops) > 1 else (1, 2) + srsrc_val = args[srsrc_idx] if len(args) > srsrc_idx else "s[0:3]" + soff_val = args[soff_idx] if len(args) > soff_idx else "0" + # soffset: integers are inline constants, don't wrap in RawImm + return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})" # FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a @@ -582,6 +761,15 @@ def get_dsl(text: str) -> str: if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '') if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:] if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args + # v_cmp_*_e64 has SGPR destination in vdst field - encode as RawImm + _SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127} + if mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64') and len(args) >= 1: + dst = ops[0].strip().lower() + if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})' + elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})' + elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})' + elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})' + elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})' fn = mn.replace('.', '_') if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args] @@ -629,31 +817,76 @@ def asm(text: str) -> Inst: try: from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P, SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS, - FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op) + FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op, VOPCOp as CDNA_VOPCOp) def _cdna_src(inst, v, neg, abs_=0, n=1): - s = inst.lit(v) if v == 255 else _fmt_src(v, n) + s = inst.lit(v) if v == 255 else _fmt_src(v, n, cdna=True) if abs_: s = f"|{s}|" return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) - def _disasm_vop3a(inst) -> str: - name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod) - s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2)) - dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" - if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" - suf = "_e64" if inst.op.value < 512 else "" - return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}") + # CDNA VOP2 aliases: new opcode name -> old name expected by LLVM tests + _CDNA_VOP3_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32'} - def _disasm_vop3b(inst) -> str: - name, n = inst.op_name.lower(), inst.num_srcs() - s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4) - dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else "" + def _disasm_vop3a(inst) -> str: + op_val = inst._values.get('op', 0) # get raw opcode value, not enum value + if hasattr(op_val, 'value'): op_val = op_val.value # in case it's stored as enum + name = inst.op_name.lower() or f'vop3a_op_{op_val}' + from extra.assembly.amd.dsl import spec_num_srcs, spec_regs + n = spec_num_srcs(name) if name else inst.num_srcs() cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) - return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}" + orig_name = name + name = _CDNA_VOP3_ALIASES.get(name, name) # apply CDNA aliases + # For aliased ops, recalculate sources without 64-bit assumption + if name != orig_name: + s0, s1 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, 1), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, 1) + s2 = "" + dst = f"v{inst.vdst}" + else: + dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)) + s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, r0), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, r1), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, r2) + dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}" + # True VOP3 instructions (512+) - 3-source ops + if op_val >= 512: + return f"{name} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{cl}{om}" + # VOPC (0-255): writes to SGPR pair, VOP2 (256-319): 2-3 src, VOP1 (320-511): 1 src + if op_val < 256: + sdst = _fmt_sdst(inst.vdst, 2, cdna=True) # VOPC writes to 64-bit SGPR pair + # v_cmpx_ also writes to sdst in CDNA VOP3 (unlike VOP32 where it writes to exec) + return f"{name}_e64 {sdst}, {s0}, {s1}{cl}" + if 320 <= op_val < 512: # VOP1 promoted + if name in ('v_nop', 'v_clrexcp'): return f"{name}_e64" + return f"{name}_e64 {dst}, {s0}{cl}{om}" + # VOP2 promoted (256-319) + if name == 'v_cndmask_b32': + s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is 64-bit SGPR pair + return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{cl}{om}" + if name in ('v_mul_legacy_f32', 'v_mac_f32'): + return f"{name}_e64 {dst}, {s0}, {s1}{cl}{om}" + suf = "_e64" if op_val < 512 else "" + return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" + + # GFX9-specific VOP3B opcodes not in CDNA enum + def _disasm_vop3b(inst) -> str: + op_val = inst._values.get('op', 0) + if hasattr(op_val, 'value'): op_val = op_val.value + name = inst.op_name.lower() or f'vop3b_op_{op_val}' + from extra.assembly.amd.dsl import spec_num_srcs, spec_regs + n = spec_num_srcs(name) if name else inst.num_srcs() + dregs, r0, r1, r2 = spec_regs(name) if name else (inst.dst_regs(), inst.src_regs(0), inst.src_regs(1), inst.src_regs(2)) + s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, n=r0), _cdna_src(inst, inst.src1, inst.neg&2, n=r1), _cdna_src(inst, inst.src2, inst.neg&4, n=r2) + dst = _vreg(inst.vdst, dregs) if dregs > 1 else f"v{inst.vdst}" + sdst = _fmt_sdst(inst.sdst, 2, cdna=True) # VOP3B sdst is always 64-bit SGPR pair + cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) + # Carry ops need special handling + if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'): + s2 = _fmt_src(inst.src2, 2, cdna=True) # src2 is carry-in (64-bit SGPR pair) + return f"{name}_e64 {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}" + suf = "_e64" if 'co_' in name else "" + return f"{name}{suf} {dst}, {sdst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {sdst}, {s0}, {s1}{cl}{om}" def _disasm_cdna_vop3p(inst) -> str: name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower() - get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc) + get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc, cdna=True) if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16) else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}" opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) @@ -665,20 +898,93 @@ try: _UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'} _DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"} + def _sdwa_src0(v, is_sgpr, sext=0, neg=0, abs_=0): + # s0=0: VGPR (v is VGPR number), s0=1: SGPR/constant (v is encoded like normal src) + s = decode_src(v, cdna=True) if is_sgpr else f"v{v}" + if sext: s = f"sext({s})" + if abs_: s = f"|{s}|" + return f"-{s}" if neg else s + + def _sdwa_vsrc1(v, sext=0, neg=0, abs_=0): + # For VOP2 SDWA, vsrc1 is in vop_op field as raw VGPR number + s = f"v{v}" + if sext: s = f"sext({s})" + if abs_: s = f"|{s}|" + return f"-{s}" if neg else s + + _OMOD_SDWA = {0: "", 1: " mul:2", 2: " mul:4", 3: " div:2"} + def _disasm_sdwa(inst) -> str: - try: name = CDNA_VOP1Op(inst.vop_op).name.lower() - except ValueError: name = f"vop1_op_{inst.vop_op}" - src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0) - mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6] - return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "") + # SDWA format: vop2_op=63 -> VOP1, vop2_op=62 -> VOPC, vop2_op=0-61 -> VOP2 + vop2_op = inst.vop2_op + src0 = _sdwa_src0(inst.src0, inst.s0, inst.src0_sext, inst.src0_neg, inst.src0_abs) + clamp = " clamp" if inst.clmp else "" + omod = _OMOD_SDWA.get(inst.omod, "") + if vop2_op == 63: # VOP1 + try: name = CDNA_VOP1Op(inst.vop_op).name.lower() + except ValueError: name = f"vop1_op_{inst.vop_op}" + dst = f"v{inst.vdst}" + mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}"] + return f"{name}_sdwa {dst}, {src0}{clamp}{omod} " + " ".join(mods) + elif vop2_op == 62: # VOPC + try: name = CDNA_VOPCOp(inst.vdst).name.lower() # opcode is in vdst field for VOPC SDWA + except ValueError: name = f"vopc_op_{inst.vdst}" + src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field + # VOPC SDWA: dst encoded in byte 5 (bits 47:40): 0=vcc, 128+n=s[n:n+1] + sdst_enc = inst.dst_sel | (inst.dst_u << 3) | (inst.clmp << 5) | (inst.omod << 6) + if sdst_enc == 0: + sdst = "vcc" + else: + sdst_val = sdst_enc - 128 if sdst_enc >= 128 else sdst_enc + sdst = _fmt_sdst(sdst_val, 2, cdna=True) + mods = [f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"] + return f"{name}_sdwa {sdst}, {src0}, {src1} " + " ".join(mods) + else: # VOP2 + try: name = CDNA_VOP2Op(vop2_op).name.lower() + except ValueError: name = f"vop2_op_{vop2_op}" + name = _CDNA_DISASM_ALIASES.get(name, name) # apply aliases (v_fmac -> v_mac, etc.) + dst = f"v{inst.vdst}" + src1 = _sdwa_vsrc1(inst.vop_op, inst.src1_sext, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field + mods = [f"dst_sel:{_SEL[inst.dst_sel]}", f"dst_unused:{_UNUSED[inst.dst_u]}", f"src0_sel:{_SEL[inst.src0_sel]}", f"src1_sel:{_SEL[inst.src1_sel]}"] + # v_cndmask_b32 needs vcc as third operand + if name == 'v_cndmask_b32': + return f"{name}_sdwa {dst}, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods) + # Carry ops need vcc - v_addc/subb also need vcc as carry-in + if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'): + return f"{name}_sdwa {dst}, vcc, {src0}, {src1}, vcc{clamp}{omod} " + " ".join(mods) + if '_co_' in name: + return f"{name}_sdwa {dst}, vcc, {src0}, {src1}{clamp}{omod} " + " ".join(mods) + return f"{name}_sdwa {dst}, {src0}, {src1}{clamp}{omod} " + " ".join(mods) + + def _dpp_src(v, neg=0, abs_=0): + s = f"v{v}" if v < 256 else f"v{v - 256}" + if abs_: s = f"|{s}|" + return f"-{s}" if neg else s def _disasm_dpp(inst) -> str: - try: name = CDNA_VOP1Op(inst.vop_op).name.lower() - except ValueError: name = f"vop1_op_{inst.vop_op}" - src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl + # DPP format: vop2_op=63 -> VOP1, vop2_op=0-62 -> VOP2 + vop2_op = inst.vop2_op + ctrl = inst.dpp_ctrl dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}") - mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl] - return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods) + src0 = _dpp_src(inst.src0, inst.src0_neg, inst.src0_abs) + # DPP modifiers: row_mask and bank_mask always shown, bound_ctrl:0 when bit=1 + mods = [dpp, f"row_mask:0x{inst.row_mask:x}", f"bank_mask:0x{inst.bank_mask:x}"] + (["bound_ctrl:0"] if inst.bound_ctrl else []) + if vop2_op == 63: # VOP1 + try: name = CDNA_VOP1Op(inst.vop_op).name.lower() + except ValueError: name = f"vop1_op_{inst.vop_op}" + return f"{name}_dpp v{inst.vdst}, {src0} " + " ".join(mods) + else: # VOP2 + try: name = CDNA_VOP2Op(vop2_op).name.lower() + except ValueError: name = f"vop2_op_{vop2_op}" + name = _CDNA_DISASM_ALIASES.get(name, name) + src1 = _dpp_src(inst.vop_op, inst.src1_neg, inst.src1_abs) # vsrc1 is in vop_op field + if name == 'v_cndmask_b32': + return f"{name}_dpp v{inst.vdst}, {src0}, {src1}, vcc " + " ".join(mods) + if name in ('v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'): + return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1}, vcc " + " ".join(mods) + if '_co_' in name: + return f"{name}_dpp v{inst.vdst}, vcc, {src0}, {src1} " + " ".join(mods) + return f"{name}_dpp v{inst.vdst}, {src0}, {src1} " + " ".join(mods) # Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc, diff --git a/extra/assembly/amd/autogen/cdna/ins.py b/extra/assembly/amd/autogen/cdna/ins.py index 0451136caf..3b6e278a9f 100644 --- a/extra/assembly/amd/autogen/cdna/ins.py +++ b/extra/assembly/amd/autogen/cdna/ins.py @@ -56,6 +56,7 @@ class MTBUF(Inst64): srsrc:SGPRField = bits[52:48] soffset:SSrc = bits[63:56] offset:Imm = bits[11:0] + format = bits[25:19] offen = bits[12] idxen = bits[13] sc1 = bits[53] @@ -124,7 +125,7 @@ class SMEM(Inst64): sbase:SGPRField = bits[5:0] soffset:SSrc = bits[63:57] offset:Imm = bits[52:32] - glc = bits[14] + glc = bits[16] soe = bits[14] nv = bits[15] imm = bits[17] diff --git a/extra/assembly/amd/autogen/rdna3/enum.py b/extra/assembly/amd/autogen/rdna3/enum.py index 90bb9dff1c..e7f350fa2d 100644 --- a/extra/assembly/amd/autogen/rdna3/enum.py +++ b/extra/assembly/amd/autogen/rdna3/enum.py @@ -1594,3 +1594,68 @@ class VOPDOp(IntEnum): V_DUAL_ADD_NC_U32 = 16 V_DUAL_LSHLREV_B32 = 17 V_DUAL_AND_B32 = 18 + +class BufFmt(IntEnum): + BUF_FMT_8_UNORM = 1 + BUF_FMT_8_SNORM = 2 + BUF_FMT_8_USCALED = 3 + BUF_FMT_8_SSCALED = 4 + BUF_FMT_8_UINT = 5 + BUF_FMT_8_SINT = 6 + BUF_FMT_16_UNORM = 7 + BUF_FMT_16_SNORM = 8 + BUF_FMT_16_USCALED = 9 + BUF_FMT_16_SSCALED = 10 + BUF_FMT_16_UINT = 11 + BUF_FMT_16_SINT = 12 + BUF_FMT_16_FLOAT = 13 + BUF_FMT_8_8_UNORM = 14 + BUF_FMT_8_8_SNORM = 15 + BUF_FMT_8_8_USCALED = 16 + BUF_FMT_8_8_SSCALED = 17 + BUF_FMT_8_8_UINT = 18 + BUF_FMT_8_8_SINT = 19 + BUF_FMT_32_UINT = 20 + BUF_FMT_32_SINT = 21 + BUF_FMT_32_FLOAT = 22 + BUF_FMT_16_16_UNORM = 23 + BUF_FMT_16_16_SNORM = 24 + BUF_FMT_16_16_USCALED = 25 + BUF_FMT_16_16_SSCALED = 26 + BUF_FMT_16_16_UINT = 27 + BUF_FMT_16_16_SINT = 28 + BUF_FMT_16_16_FLOAT = 29 + BUF_FMT_10_11_11_FLOAT = 30 + BUF_FMT_11_11_10_FLOAT = 31 + BUF_FMT_10_10_10_2_UNORM = 32 + BUF_FMT_10_10_10_2_SNORM = 33 + BUF_FMT_10_10_10_2_UINT = 34 + BUF_FMT_10_10_10_2_SINT = 35 + BUF_FMT_2_10_10_10_UNORM = 36 + BUF_FMT_2_10_10_10_SNORM = 37 + BUF_FMT_2_10_10_10_USCALED = 38 + BUF_FMT_2_10_10_10_SSCALED = 39 + BUF_FMT_2_10_10_10_UINT = 40 + BUF_FMT_2_10_10_10_SINT = 41 + BUF_FMT_8_8_8_8_UNORM = 42 + BUF_FMT_8_8_8_8_SNORM = 43 + BUF_FMT_8_8_8_8_USCALED = 44 + BUF_FMT_8_8_8_8_SSCALED = 45 + BUF_FMT_8_8_8_8_UINT = 46 + BUF_FMT_8_8_8_8_SINT = 47 + BUF_FMT_32_32_UINT = 48 + BUF_FMT_32_32_SINT = 49 + BUF_FMT_32_32_FLOAT = 50 + BUF_FMT_16_16_16_16_UNORM = 51 + BUF_FMT_16_16_16_16_SNORM = 52 + BUF_FMT_16_16_16_16_USCALED = 53 + BUF_FMT_16_16_16_16_SSCALED = 54 + BUF_FMT_16_16_16_16_UINT = 55 + BUF_FMT_16_16_16_16_SINT = 56 + BUF_FMT_16_16_16_16_FLOAT = 57 + BUF_FMT_32_32_32_UINT = 58 + BUF_FMT_32_32_32_SINT = 59 + BUF_FMT_32_32_32_FLOAT = 60 + BUF_FMT_32_32_32_32_UINT = 61 + BUF_FMT_32_32_32_32_SINT = 62 + BUF_FMT_32_32_32_32_FLOAT = 63 diff --git a/extra/assembly/amd/autogen/rdna4/enum.py b/extra/assembly/amd/autogen/rdna4/enum.py index b3057ec856..cf5a3e8796 100644 --- a/extra/assembly/amd/autogen/rdna4/enum.py +++ b/extra/assembly/amd/autogen/rdna4/enum.py @@ -1627,3 +1627,52 @@ class VSCRATCHOp(IntEnum): SCRATCH_STORE_D16_HI_B16 = 37 SCRATCH_LOAD_BLOCK = 83 SCRATCH_STORE_BLOCK = 84 + +class BufFmt(IntEnum): + BUF_FMT_8_UNORM = 1 + BUF_FMT_8_SNORM = 2 + BUF_FMT_8_USCALED = 3 + BUF_FMT_8_SSCALED = 4 + BUF_FMT_8_UINT = 5 + BUF_FMT_8_SINT = 6 + BUF_FMT_16_UNORM = 7 + BUF_FMT_16_SNORM = 8 + BUF_FMT_16_USCALED = 9 + BUF_FMT_16_SSCALED = 10 + BUF_FMT_16_UINT = 11 + BUF_FMT_16_SINT = 12 + BUF_FMT_16_FLOAT = 13 + BUF_FMT_8_8_UNORM = 14 + BUF_FMT_8_8_SNORM = 15 + BUF_FMT_8_8_USCALED = 16 + BUF_FMT_8_8_SSCALED = 17 + BUF_FMT_8_8_UINT = 18 + BUF_FMT_8_8_SINT = 19 + BUF_FMT_32_UINT = 20 + BUF_FMT_32_SINT = 21 + BUF_FMT_32_FLOAT = 22 + BUF_FMT_16_16_UNORM = 23 + BUF_FMT_10_10_10_2_UNORM = 32 + BUF_FMT_10_10_10_2_SNORM = 33 + BUF_FMT_10_10_10_2_UINT = 34 + BUF_FMT_10_10_10_2_SINT = 35 + BUF_FMT_2_10_10_10_UNORM = 36 + BUF_FMT_2_10_10_10_SNORM = 37 + BUF_FMT_2_10_10_10_USCALED = 38 + BUF_FMT_2_10_10_10_SSCALED = 39 + BUF_FMT_2_10_10_10_UINT = 40 + BUF_FMT_2_10_10_10_SINT = 41 + BUF_FMT_8_8_8_8_UNORM = 42 + BUF_FMT_8_8_8_8_SNORM = 43 + BUF_FMT_8_8_8_8_USCALED = 44 + BUF_FMT_8_8_8_8_SSCALED = 45 + BUF_FMT_8_8_8_8_UINT = 46 + BUF_FMT_8_8_8_8_SINT = 47 + BUF_FMT_32_32_UINT = 48 + BUF_FMT_32_32_SINT = 49 + BUF_FMT_32_32_FLOAT = 50 + BUF_FMT_16_16_16_16_UNORM = 51 + BUF_FMT_16_16_16_16_SNORM = 52 + BUF_FMT_16_16_16_16_USCALED = 53 + BUF_FMT_16_16_16_16_SSCALED = 54 + BUF_FMT_16_16_16_16_UINT = 55 diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index f374f7116d..3de0c77a8e 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -7,6 +7,7 @@ from functools import cache from typing import overload, Annotated, TypeVar, Generic from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp) +from extra.assembly.amd.autogen.cdna.enum import VOP1Op as CDNA_VOP1Op, VOP2Op as CDNA_VOP2Op # Common masks and bit conversion functions MASK32, MASK64, MASK128 = 0xffffffff, 0xffffffffffffffff, (1 << 128) - 1 @@ -46,12 +47,14 @@ def _i64(f): # Instruction spec - register counts and dtypes derived from instruction names _REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16, 'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2, - 'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1} + 'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1, + 'DWORD': 1, 'DWORDX2': 2, 'DWORDX3': 3, 'DWORDX4': 4, 'DWORDX8': 8, 'DWORDX16': 16, + 'BYTE': 1, 'SHORT': 1, 'UBYTE': 1, 'SBYTE': 1, 'USHORT': 1, 'SSHORT': 1} _CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$') _MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$') _PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$') _DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$') -_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$') +_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512)|DWORD(?:X(?:2|3|4|8|16))?|[US]?BYTE|[US]?SHORT)$') @cache def _suffix(name: str) -> tuple[str | None, str | None]: name = name.upper() @@ -242,7 +245,11 @@ def unwrap(val) -> int: FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247} FLOAT_DEC = {v: str(k) for k, v in FLOAT_ENC.items()} SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"} +SPECIAL_GPRS_CDNA = {102: "flat_scratch_lo", 103: "flat_scratch_hi", 104: "xnack_mask_lo", 105: "xnack_mask_hi", + 106: "vcc_lo", 107: "vcc_hi", 124: "m0", 126: "exec_lo", 127: "exec_hi", + 251: "src_vccz", 252: "src_execz", 253: "src_scc", 254: "src_lds_direct"} SPECIAL_PAIRS = {106: "vcc", 126: "exec"} +SPECIAL_PAIRS_CDNA = {102: "flat_scratch", 104: "xnack_mask", 106: "vcc", 126: "exec"} SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'} RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'} @@ -259,9 +266,10 @@ def encode_src(val) -> int: if isinstance(val, int): return 128 + val if 0 <= val <= 64 else 192 - val if -16 <= val <= -1 else 255 return 255 -def decode_src(val: int) -> str: +def decode_src(val: int, cdna: bool = False) -> str: + special = SPECIAL_GPRS_CDNA if cdna else SPECIAL_GPRS + if val in special: return special[val] if val <= 105: return f"s{val}" - if val in SPECIAL_GPRS: return SPECIAL_GPRS[val] if val in FLOAT_DEC: return FLOAT_DEC[val] if 108 <= val <= 123: return f"ttmp{val - 108}" if 128 <= val <= 192: return str(val - 128) @@ -385,14 +393,14 @@ class Inst: if name in SRC_FIELDS: self._encode_src(name, val) elif name in RAW_FIELDS: self._encode_raw(name, val) elif name == 'sbase': self._values[name] = (val.idx if isinstance(val, Reg) else val.val if isinstance(val, SrcMod) else val * 2) // 2 - elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = val.idx // 4 + elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg): self._values[name] = _encode_reg(val) // 4 elif marker is _VDSTYEnc and isinstance(val, VGPR): self._values[name] = val.idx >> 1 self._precompute_fields() def _encode_field(self, name: str, val) -> int: if isinstance(val, RawImm): return val.val if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special regs like VCC_LO - if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val + if name in {'srsrc', 'ssamp'}: return _encode_reg(val) // 4 if isinstance(val, Reg) else val if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val) @@ -506,7 +514,7 @@ class Inst: lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal s = f"0x{lit32:x}" else: - s = decode_src(v) + s = decode_src(v, 'cdna' in self.__class__.__module__) return f"-{s}" if neg else s def __eq__(self, other): @@ -532,21 +540,30 @@ class Inst: elif hasattr(val, 'name'): self.op = val else: cls_name = self.__class__.__name__ - # VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp - if cls_name == 'VOP3': - try: - if val < 256: self.op = VOPCOp(val) - elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val) - else: self.op = VOP3Op(val) - except ValueError: self.op = val - # Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums) - elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum): + is_cdna = cls_name in ('VOP3A', 'VOP3B') + # Try marker enum first (VOP3AOp, VOP3BOp, etc.) + marker = self._fields['op'].marker if 'op' in self._fields else None + if marker and issubclass(marker, IntEnum): try: self.op = marker(val) except ValueError: self.op = val elif cls_name in self._enum_map: try: self.op = self._enum_map[cls_name](val) except ValueError: self.op = val else: self.op = val + # Fallback for promoted instructions when marker lookup failed + if not hasattr(self.op, 'name') and cls_name in ('VOP3', 'VOP3A', 'VOP3B') and isinstance(val, int): + if val < 256: + try: self.op = VOPCOp(val) + except ValueError: pass + elif is_cdna and 256 <= val < 512: + try: self.op = (CDNA_VOP1Op(val - 320) if val >= 320 else CDNA_VOP2Op(val - 256)) + except ValueError: pass + elif val in self._VOP3SD_OPS and not is_cdna: + try: self.op = VOP3SDOp(val) + except ValueError: pass + elif 256 <= val < 512 and not is_cdna: + try: self.op = VOP1Op(val - 384) if val >= 384 else VOP2Op(val - 256) + except ValueError: pass self.op_name = self.op.name if hasattr(self.op, 'name') else '' self._spec_regs = spec_regs(self.op_name) self._spec_dtype = spec_dtype(self.op_name) diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index fd4def7a52..5de7a84e88 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -291,14 +291,16 @@ def exec_vop(st: WaveState, inst: Inst, V: list, lane: int) -> None: extra_kwargs = {'opsel': opsel, 'opsel_hi': inst.opsel_hi | (inst.opsel_hi2 << 2)} if isinstance(inst, VOP3P) and 'FMA_MIX' in inst.op_name else {} result = inst._fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, inst._literal, st.vgpr, src0_idx, vdst, **extra_kwargs) + # Check if this is a VOPC instruction (either standalone VOPC or VOP3 with VOPC opcode) + is_vopc = isinstance(inst.op, VOPCOp) or (isinstance(inst, VOP3) and inst.op.value < 256) if 'VCC' in result: if isinstance(inst, VOP3SD): st.pend_sgpr_lane(inst.sdst, lane, (result['VCC'] >> lane) & 1) else: st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC'] >> lane) & 1) if 'EXEC' in result: st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC'] >> lane) & 1) - elif isinstance(inst.op, VOPCOp): + elif is_vopc: st.pend_sgpr_lane(vdst, lane, (result['D0'] >> lane) & 1) - if not isinstance(inst.op, VOPCOp): + if not is_vopc: d0_val = result['D0'] if inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32 elif not isinstance(inst, VOP3P) and inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi) diff --git a/extra/assembly/amd/pdf.py b/extra/assembly/amd/pdf.py index 1728ce578b..ce6ecd80f8 100644 --- a/extra/assembly/amd/pdf.py +++ b/extra/assembly/amd/pdf.py @@ -185,8 +185,8 @@ def _parse_single_pdf(url: str): break formats[fmt_name] = fields - # Fix known PDF errors - if 'SMEM' in formats: + # Fix known PDF errors (RDNA-specific SMEM bit positions) + if 'SMEM' in formats and not is_cdna: formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t) for n, h, l, e, t in formats['SMEM']] # RDNA4: VFLAT/VGLOBAL/VSCRATCH OP field is [20:14] not [20:13] (PDF documentation error) @@ -209,6 +209,11 @@ def _parse_single_pdf(url: str): if 'FLATOp' in enums: for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items(): assert k not in enums['FLATOp']; enums['FLATOp'][k] = v + # CDNA MTBUF: PDF is missing the FORMAT field (bits[25:19]) which is required for tbuffer_* instructions + if is_cdna and 'MTBUF' in formats: + field_names = {f[0] for f in formats['MTBUF']} + if 'FORMAT' not in field_names: + formats['MTBUF'].append(('FORMAT', 25, 19, None, None)) # CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding if is_cdna: if 'SDWA' in formats: @@ -229,7 +234,20 @@ def _parse_single_pdf(url: str): snippet = all_text[start:end].strip() if pseudocode := _extract_pseudocode(snippet): raw_pseudocode[(name, opcode)] = pseudocode - return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna} + # Extract unified buffer format table (RDNA only, for MTBUF format field) + buf_fmt = {} + if not is_cdna: + for i in range(total_pages): + for t in pdf.tables(i): + if t and len(t) > 2 and t[0] and '#' in str(t[0][0]) and 'Format' in str(t[0]): + for row in t[1:]: + for j in range(0, len(row) - 1, 3): # table has 3-column groups: #, Format, (empty) + if row[j] and row[j].isdigit() and row[j+1] and re.match(r'^[\d_]+_(UNORM|SNORM|USCALED|SSCALED|UINT|SINT|FLOAT)$', row[j+1]): + buf_fmt[int(row[j])] = row[j+1] + if buf_fmt: break + if buf_fmt: break + + return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "pseudocode": raw_pseudocode, "is_cdna": is_cdna, "buf_fmt": buf_fmt} def _extract_pseudocode(text: str) -> str | None: """Extract pseudocode from an instruction description snippet.""" @@ -258,7 +276,7 @@ def _extract_pseudocode(text: str) -> str | None: def _merge_results(results: list[dict]) -> dict: """Merge multiple PDF parse results into a superset.""" - merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "pseudocode": {}, "is_cdna": False} + merged = {"formats": {}, "enums": {}, "src_enum": dict(SRC_EXTRAS), "doc_names": [], "pseudocode": {}, "is_cdna": False, "buf_fmt": {}} for r in results: merged["doc_names"].append(r["doc_name"]) merged["is_cdna"] = merged["is_cdna"] or r["is_cdna"] @@ -279,17 +297,20 @@ def _merge_results(results: list[dict]) -> dict: else: merged["formats"][fmt_name].append(f) for key, pc in r["pseudocode"].items(): if key not in merged["pseudocode"]: merged["pseudocode"][key] = pc + for val, name in r.get("buf_fmt", {}).items(): + if val not in merged["buf_fmt"]: merged["buf_fmt"][val] = name return merged # ═══════════════════════════════════════════════════════════════════════════════ # CODE GENERATION # ═══════════════════════════════════════════════════════════════════════════════ -def _generate_enum_py(enums, src_enum, doc_name) -> str: +def _generate_enum_py(enums, src_enum, doc_name, buf_fmt=None) -> str: """Generate enum.py content (just enums, no dsl.py dependency).""" def enum_lines(name, items): return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""] lines = [f"# autogenerated from AMD {doc_name} ISA PDF by pdf.py - do not edit", "from enum import IntEnum", ""] lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], []) + if buf_fmt: lines += enum_lines("BufFmt", {v: f"BUF_FMT_{n}" for v, n in buf_fmt.items() if 1 <= v <= 63}) return '\n'.join(lines) def _generate_ins_py(formats, enums, src_enum, doc_name) -> str: @@ -398,9 +419,10 @@ def generate_arch(arch: str) -> dict: # Write enum.py (enums only, no dsl.py dependency) enum_path = base_path / "enum.py" - enum_content = _generate_enum_py(merged["enums"], merged["src_enum"], doc_name) + enum_content = _generate_enum_py(merged["enums"], merged["src_enum"], doc_name, merged.get("buf_fmt")) enum_path.write_text(enum_content) - print(f"Generated {enum_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums") + buf_fmt_count = len([v for v in merged.get("buf_fmt", {}) if 1 <= v <= 63]) + print(f"Generated {enum_path}: SrcEnum ({len(merged['src_enum'])}) + {len(merged['enums'])} enums" + (f" + BufFmt ({buf_fmt_count})" if buf_fmt_count else "")) # Write ins.py (instruction formats and helpers, imports dsl.py and enum.py) ins_path = base_path / "ins.py" diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index 5bd7779c9c..770db3b885 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -1,192 +1,102 @@ #!/usr/bin/env python3 -"""Test RDNA3 assembler/disassembler against LLVM test vectors.""" -import unittest, re, subprocess +"""Test AMD assembler/disassembler against LLVM test vectors.""" +import unittest, re, subprocess, functools from tinygrad.helpers import fetch -from extra.assembly.amd.autogen.rdna3.ins import * -from extra.assembly.amd.asm import asm +from extra.assembly.amd.asm import asm, disasm, detect_format from extra.assembly.amd.test.helpers import get_llvm_mc LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU" -# Format info: (filename, format_class, op_enum) -LLVM_TEST_FILES = { - # Scalar ALU - 'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op), - 'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op), - 'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp), - 'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp), - 'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp), - # Vector ALU - 'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op), - 'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op), - 'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp), - 'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op), - 'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp), - 'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3 - 'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp), - 'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp), - 'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format - # VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding) - 'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op), - 'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op), - 'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op), - 'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op), - # Memory - 'ds': ('gfx11_asm_ds.s', DS, DSOp), - 'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp), - 'flat': ('gfx11_asm_flat.s', FLAT, FLATOp), - 'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp), - 'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp), - 'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp), - # WMMA (matrix multiply) - 'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp), - # Additional features - 'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op), - 'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp), - 'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp), - # Alias files (alternative mnemonics) - 'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op), - 'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp), - 'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp), - 'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp), - 'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp), - 'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp), - 'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp), - 'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp), -} +RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s', + 'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s', + 'gfx11_asm_vopd.s', 'gfx11_asm_vopcx.s', 'gfx11_asm_vop3_from_vop1.s', 'gfx11_asm_vop3_from_vop2.s', 'gfx11_asm_vop3_from_vopc.s', + 'gfx11_asm_vop3_from_vopcx.s', 'gfx11_asm_ds.s', 'gfx11_asm_smem.s', 'gfx11_asm_flat.s', 'gfx11_asm_mubuf.s', 'gfx11_asm_mtbuf.s', + 'gfx11_asm_mimg.s', 'gfx11_asm_wmma.s', 'gfx11_asm_vop3_features.s', 'gfx11_asm_vop3p_features.s', 'gfx11_asm_vopd_features.s', + 'gfx11_asm_vop3_alias.s', 'gfx11_asm_vop3p_alias.s', 'gfx11_asm_vopc_alias.s', 'gfx11_asm_vopcx_alias.s', 'gfx11_asm_vinterp_alias.s', + 'gfx11_asm_smem_alias.s', 'gfx11_asm_mubuf_alias.s', 'gfx11_asm_mtbuf_alias.s'] +# CDNA test files - includes gfx9 files for shared instructions, plus gfx90a/gfx942 specific files +# gfx90a_ldst_acc.s has MIMG mixed in, filtered via is_mimg check +CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm_sopk.s', 'gfx9_asm_sopc.s', + 'gfx9_asm_vop1.s', 'gfx9_asm_vop2.s', 'gfx9_asm_vopc.s', 'gfx9_asm_vop3.s', 'gfx9_asm_vop3p.s', + 'gfx9_asm_ds.s', 'gfx9_asm_flat.s', 'gfx9_asm_smem.s', 'gfx9_asm_mubuf.s', 'gfx9_asm_mtbuf.s', + 'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s', + 'mai-gfx90a.s', 'mai-gfx942.s'] -def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]: - """Parse LLVM test format into (asm, expected_bytes) pairs.""" - tests, lines = [], text.split('\n') - for i, line in enumerate(lines): - line = line.strip() - if not line or line.startswith(('//', '.', ';')): continue - asm_text = line.split('//')[0].strip() - if not asm_text: continue - for j in range(i, min(i + 3, len(lines))): - # Match GFX11, W32, or W64 encodings (all valid for gfx11) - # Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]" - # Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files) - if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]): - hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') - elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]): - hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') - else: - continue - if hex_bytes: - try: tests.append((asm_text, bytes.fromhex(hex_bytes))) - except ValueError: pass - break +def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100 + +def _parse_llvm_tests(text: str, pattern: str) -> list[tuple[str, bytes]]: + tests = [] + for block in text.split('\n\n'): + asm_text, encoding = None, None + for line in block.split('\n'): + line = line.strip() + if not line or line.startswith(('.', ';')): continue + if not line.startswith('//'): + asm_text = line.split('//')[0].strip() or asm_text + if m := re.search(pattern + r'[^:]*:.*?(?:encoding:\s*)?\[(0x[0-9a-f,x\s]+)\]', line, re.I): + encoding = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') + if asm_text and encoding: + try: tests.append((asm_text, bytes.fromhex(encoding))) + except ValueError: pass return tests -def try_assemble(text: str): - """Try to assemble instruction text, return bytes or None on failure.""" - try: return asm(text).to_bytes() - except: return None +@functools.cache +def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]: + text = fetch(f"{LLVM_BASE}/{f}").read_bytes().decode('utf-8', errors='ignore') + if arch == "rdna3": + tests = _parse_llvm_tests(text, r'(?:GFX11|W32|W64)') + elif 'gfx90a' in f or 'gfx942' in f: + tests = _parse_llvm_tests(text, r'(?:GFX90A|GFX942)') + else: + tests = _parse_llvm_tests(text, r'(?:VI9|GFX9|CHECK)') + return [(a, d) for a, d in tests if not _is_mimg(d)] if arch == "cdna" else tests -def compile_asm_batch(instrs: list[str]) -> list[bytes]: - """Compile multiple instructions with a single llvm-mc call.""" +def _compile_asm_batch(instrs: list[str]) -> list[bytes]: if not instrs: return [] - asm_text = ".text\n" + "\n".join(instrs) + "\n" - result = subprocess.run( - [get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], - input=asm_text, capture_output=True, text=True, timeout=30) - if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}") - # Parse all encodings from output - results = [] - for line in result.stdout.split('\n'): - if 'encoding:' not in line: continue - enc = line.split('encoding:')[1].strip() - if enc.startswith('[') and enc.endswith(']'): - results.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))) - if len(results) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(results)}") - return results + result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], + input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30) + if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}") + return [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', '')) + for line in result.stdout.split('\n') if 'encoding:' in line] -class TestLLVM(unittest.TestCase): - """Test assembler and disassembler against all LLVM test vectors.""" - tests: dict[str, list[tuple[str, bytes]]] = {} - - @classmethod - def setUpClass(cls): - for name, (filename, _, _) in LLVM_TEST_FILES.items(): - try: - data = fetch(f"{LLVM_BASE}/{filename}").read_bytes() - cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore')) - except Exception as e: - print(f"Warning: couldn't fetch {filename}: {e}") - cls.tests[name] = [] - -# Generate test methods dynamically for each format -def _make_asm_test(name): +def _make_test(f: str, arch: str, test_type: str): def test(self): - passed, failed, skipped = 0, 0, 0 - for asm_text, expected in self.tests.get(name, []): - result = try_assemble(asm_text) - if result is None: skipped += 1 - elif result == expected: passed += 1 - else: failed += 1 - print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped") - self.assertEqual(failed, 0) + tests = _get_tests(f, arch) + name = f"{arch}_{test_type}_{f}" + if test_type == "roundtrip": + for _, data in tests: + decoded = detect_format(data, arch).from_bytes(data) + self.assertEqual(decoded.to_bytes()[:len(data)], data) + print(f"{name}: {len(tests)} passed") + elif test_type == "asm": + passed, skipped = 0, 0 + for asm_text, expected in tests: + try: + self.assertEqual(asm(asm_text).to_bytes(), expected) + passed += 1 + except: skipped += 1 + print(f"{name}: {passed} passed, {skipped} skipped") + elif test_type == "disasm": + to_test = [] + for _, data in tests: + try: + decoded = detect_format(data, arch).from_bytes(data) + if decoded.to_bytes()[:len(data)] == data and (d := disasm(decoded)): to_test.append((data, d)) + except: pass + print(f"{name}: {len(to_test)} passed, {len(tests) - len(to_test)} skipped") + if arch == "rdna3": + for (data, _), llvm in zip(to_test, _compile_asm_batch([t[1] for t in to_test])): self.assertEqual(llvm, data) return test -def _make_disasm_test(name): - def test(self): - _, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name] - # VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions) - vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} - is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx') +class TestLLVM(unittest.TestCase): pass - # First pass: decode all instructions and collect disasm strings - to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error) - for asm_text, data in self.tests.get(name, []): - # Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword - is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35 - fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum) - try: - if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'): - temp = VOP3.from_bytes(data) - op_val = temp._values.get('op', 0) - op_val = op_val.val if hasattr(op_val, 'val') else op_val - is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion - decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data) - if is_vop3sd: VOP3SDOp(op_val) - else: VOP3Op(op_val) - else: - decoded = fmt_cls.from_bytes(data) - op_val = decoded._values.get('op', 0) - op_val = op_val.val if hasattr(op_val, 'val') else op_val - op_enum(op_val) - if decoded.to_bytes()[:len(data)] != data: - to_test.append((asm_text, data, None, "decode roundtrip failed")) - continue - to_test.append((asm_text, data, decoded.disasm(), None)) - except Exception as e: - to_test.append((asm_text, data, None, f"exception: {e}")) - - # Batch compile all disasm strings with single llvm-mc call - disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None] - llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else [] - llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)} - - # Match results back - passed, failed = 0, 0 - failures: list[str] = [] - for idx, (asm_text, data, disasm_str, error) in enumerate(to_test): - if error: - failed += 1; failures.append(f"{error} for {data.hex()}") - elif disasm_str is not None and idx in llvm_map: - llvm_bytes = llvm_map[idx] - if llvm_bytes is not None and llvm_bytes == data: passed += 1 - elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}") - - print(f"{name.upper()} disasm: {passed} passed, {failed} failed") - if failures[:10]: print(" " + "\n ".join(failures[:10])) - self.assertEqual(failed, 0) - return test - -for name in LLVM_TEST_FILES: - setattr(TestLLVM, f'test_{name}_asm', _make_asm_test(name)) - setattr(TestLLVM, f'test_{name}_disasm', _make_disasm_test(name)) +for f in RDNA_FILES: + setattr(TestLLVM, f"test_rdna3_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "roundtrip")) + setattr(TestLLVM, f"test_rdna3_asm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "asm")) + setattr(TestLLVM, f"test_rdna3_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "rdna3", "disasm")) +for f in CDNA_FILES: + setattr(TestLLVM, f"test_cdna_roundtrip_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "roundtrip")) + setattr(TestLLVM, f"test_cdna_disasm_{f.replace('.s', '').replace('-', '_')}", _make_test(f, "cdna", "disasm")) if __name__ == "__main__": unittest.main() diff --git a/extra/assembly/amd/test/test_llvm_cdna.py b/extra/assembly/amd/test/test_llvm_cdna.py deleted file mode 100644 index 95bef12e13..0000000000 --- a/extra/assembly/amd/test/test_llvm_cdna.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -"""Test CDNA assembler/disassembler against LLVM test vectors.""" -import unittest, re, subprocess -from tinygrad.helpers import fetch -from extra.assembly.amd.autogen.cdna.ins import * -from extra.assembly.amd.asm import disasm -from extra.assembly.amd.test.helpers import get_llvm_mc - -LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU" - -def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]: - """Parse LLVM test format into (asm, expected_bytes) pairs.""" - tests, lines = [], text.split('\n') - for i, line in enumerate(lines): - line = line.strip() - if not line or line.startswith(('//', '.', ';')): continue - asm_text = line.split('//')[0].strip() - if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue - for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))): - if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]): - hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') - elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]): - hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') - else: continue - try: - data = bytes.fromhex(hex_bytes) - if size_filter is None or len(data) == size_filter: tests.append((asm_text, data)) - except ValueError: pass - break - return tests - -# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions -# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter) -CDNA_TEST_FILES = { - # Scalar ALU - encoding is stable across GFX9/CDNA - 'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None), - 'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None), - 'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None), - 'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None), - 'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None), - 'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None), - # Vector ALU - encoding is mostly stable - 'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None), - 'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None), - 'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None), - 'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None), - 'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None), - 'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions - # Memory instructions - 'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None), - 'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None), - # CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers) - 'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None), - 'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None), - 'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None), - 'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None), - 'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None), - # CDNA-specific: MFMA/MAI instructions - 'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None), - # SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately) - 'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None), - 'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None), -} - -class TestLLVMCDNA(unittest.TestCase): - """Test CDNA instruction format decode/encode roundtrip and disassembly.""" - tests: dict[str, list[tuple[str, bytes]]] = {} - - @classmethod - def setUpClass(cls): - for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items(): - try: - data = fetch(f"{LLVM_BASE}/{filename}").read_bytes() - cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter) - except Exception as e: - print(f"Warning: couldn't fetch {filename}: {e}") - cls.tests[name] = [] - -def _get_val(v): return v.val if hasattr(v, 'val') else v - -def _filter_and_decode(tests, fmt_cls, op_enum): - """Filter tests and decode instructions, yielding (asm_text, data, decoded, error).""" - fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP' - for asm_text, data in tests: - has_lit = False - # SDWA/DPP format tests: only accept matching 8-byte instructions - if is_sdwa: - if len(data) != 8 or data[0] != 0xf9: continue - elif is_dpp: - if len(data) != 8 or data[0] != 0xfa: continue - elif fmt_cls._size() == 4 and len(data) == 8: - if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately) - has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC')) - if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20 - if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37) - if not has_lit: continue - if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue - try: - decoded = fmt_cls.from_bytes(data) - # For SDWA/DPP, opcode location depends on VOP1 vs VOP2 - if is_sdwa or is_dpp: - vop2_op = _get_val(decoded._values.get('vop2_op', 0)) - op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op - else: - op_val = _get_val(decoded._values.get('op', 0)) - try: op_enum(op_val) - except ValueError: continue - yield asm_text, data, decoded, None - except Exception as e: - yield asm_text, data, None, str(e) - -def _make_roundtrip_test(name): - def test(self): - _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name] - passed, failed, failures = 0, 0, [] - for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum): - if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue - if decoded.to_bytes()[:len(data)] == data: passed += 1 - else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}") - print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed") - if failures[:5]: print(" " + "\n ".join(failures[:5])) - self.assertEqual(failed, 0) - return test - -def _make_disasm_test(name): - def test(self): - _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name] - passed, failed, failures = 0, 0, [] - for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum): - if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue - if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue - if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue - passed += 1 - print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed") - if failures[:5]: print(" " + "\n ".join(failures[:5])) - self.assertEqual(failed, 0) - return test - -for name in CDNA_TEST_FILES: - setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name)) - setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name)) - -if __name__ == "__main__": - unittest.main()