diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index ed92f2fcfb..de2e2da6c8 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -1,4 +1,4 @@ -# RDNA3 assembler and disassembler +# RDNA3/CDNA assembler and disassembler from __future__ import annotations import re from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory @@ -8,6 +8,8 @@ from extra.assembly.amd.autogen.rdna3 import ins from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp) +def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__ + def _matches_encoding(word: int, cls: type[Inst]) -> bool: """Check if word matches the encoding pattern of an instruction class.""" if cls._encoding is None: return False @@ -81,10 +83,11 @@ def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c) def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]" -def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str: +def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> str: """Format VOP3 source operand with modifiers.""" - if n > 1: s = _fmt_src(v, n) - elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v)) + if v == 255: s = inst.lit(v) # literal constant takes priority + elif n > 1: s = _fmt_src(v, n) + elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else f"v{v - 256}.l" else: s = inst.lit(v) if abs_: s = f"|{s}|" return f"-{s}" if neg else s @@ -92,9 +95,11 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str: """Format op_sel modifier string.""" if not need: return "" - if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]" - if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]" - return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]" + # For VOP1 (n=1): op_sel:[src0_hi, dst_hi], for VOP2 (n=2): op_sel:[src0_hi, src1_hi, dst_hi], for VOP3 (n=3): op_sel:[src0_hi, src1_hi, src2_hi, dst_hi] + dst_hi = (opsel >> 3) & 1 + if n == 1: return f" op_sel:[{opsel & 1},{dst_hi}]" + if n == 2: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{dst_hi}]" + return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{dst_hi}]" # ═══════════════════════════════════════════════════════════════════════════════ # DISASSEMBLER @@ -108,30 +113,41 @@ def _disasm_vop1(inst: VOP1) -> str: parts = name.split('_') is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" - src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) + src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) return f"{name}_e32 {dst}, {src}" def _disasm_vop2(inst: VOP2) -> str: - name = inst.op_name.lower() - suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32" + name, cdna = inst.op_name.lower(), _is_cdna(inst) + suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32" + lit = getattr(inst, '_literal', None) + is16 = not cdna and inst.is_16bit() # fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1 - if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}" - if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}" - if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" - return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "") + if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)): + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}" + if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)): + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}" + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" + vcc = "vcc" if cdna else "vcc_lo" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "") def _disasm_vopc(inst: VOPC) -> str: - name = inst.op_name.lower() - s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0) + name, cdna = inst.op_name.lower(), _is_cdna(inst) + if cdna: + s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) + return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}" + s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0) s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}" return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" -NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, - SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE} +NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, + SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA} def _disasm_sopp(inst: SOPP) -> str: name = inst.op_name.lower() if inst.op in NO_ARG_SOPP: return name + if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}" if inst.op == SOPPOp.S_WAITCNT: vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""] @@ -154,12 +170,13 @@ def _disasm_smem(inst: SMEM) -> str: return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) def _disasm_flat(inst: FLAT) -> str: - name = inst.op_name.lower() + name, cdna = inst.op_name.lower(), _is_cdna(inst) seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192) w = inst.dst_regs() * (2 if 'cmpswap' in name else 1) - mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" + if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}" + else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" # saddr if seg == 'flat' or inst.saddr == 0x7F: saddr_s = "" elif inst.saddr == 124: saddr_s = ", off" @@ -172,8 +189,9 @@ def _disasm_flat(inst: FLAT) -> str: # addr width addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2) data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w) + glc_or_sc0 = inst.sc0 if cdna else inst.glc if 'atomic' in name: - return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if inst.glc else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" + return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}" return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" @@ -211,7 +229,9 @@ def _disasm_vop3(inst: VOP3) -> str: # VOP3SD (shared encoding) if isinstance(op, VOP3SDOp): sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs - def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s + def src(v, neg, n): + s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v)) + return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" @@ -225,18 +245,18 @@ def _disasm_vop3(inst: VOP3) -> str: is16_s2 = is16_s elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True elif 'pack_b32' in name: is16_s = is16_s2 = True + elif 'sat_pk' in name: is16_d = True # v_sat_pk_* writes to 16-bit dest but takes 32-bit src else: is16_d = is16_s = is16_s2 = inst.is_16bit() - any_hi = inst.opsel != 0 - s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi) - s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi) - s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi) + s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s) + s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s) + s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2) # Destination dn = inst.dst_regs() if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1) elif dn > 1: dst = _vreg(inst.vdst, dn) - elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}" + elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" else: dst = f"v{inst.vdst}" cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) @@ -244,7 +264,7 @@ def _disasm_vop3(inst: VOP3) -> str: need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s) if inst.op < 256: # VOPC - return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" + return f"{name}_e64 {s0}, {s1}{cl}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}{cl}" if inst.op < 384: # VOP2 n = inst.num_srcs() os = _opsel_str(inst.opsel, n, need_opsel, is16_d) @@ -258,7 +278,9 @@ def _disasm_vop3(inst: VOP3) -> str: def _disasm_vop3sd(inst: VOP3SD) -> str: name = inst.op_name.lower() - def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s + def src(v, neg, n): + s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v)) + return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" @@ -268,16 +290,22 @@ def _disasm_vop3sd(inst: VOP3SD) -> str: def _disasm_vopd(inst: VOPD) -> str: lit = inst._literal or inst.literal vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower() - def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" + def half(n, vd, s0, vs1): + if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}" + # fmamk: dst = src0 * K + vsrc1, fmaak: dst = src0 * vsrc1 + K + if 'fmamk' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, 0x{lit:x}, v{vs1}" + if 'fmaak' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}, 0x{lit:x}" + return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}" return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}" def _disasm_vop3p(inst: VOP3P) -> str: name = inst.op_name.lower() is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name + def get_src(v, sc): return inst.lit(v) if v == 255 else _fmt_src(v, sc) if is_wmma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8 - src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8) - else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}" + src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 8), _vreg(inst.vdst, 8) + else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}" opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) if is_fma_mix: def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s) @@ -289,15 +317,17 @@ def _disasm_vop3p(inst: VOP3P) -> str: return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" def _disasm_buf(inst: MUBUF | MTBUF) -> str: - name = inst.op_name.lower() - if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name + name, cdna = inst.op_name.lower(), _is_cdna(inst) + if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name + if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1) - if inst.tfe: w += 1 + if hasattr(inst, 'tfe') and inst.tfe: w += 1 vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off" srsrc = _sreg_or_ttmp(inst.srsrc*4, 4) - mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] + if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c] + else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}" def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: @@ -348,25 +378,36 @@ def _disasm_mimg(inst: MIMG) -> str: def _disasm_sop1(inst: SOP1) -> str: op, name = inst.op, inst.op_name.lower() - if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" - if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}" - if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}" - if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" - return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}" + src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0)) + if not _is_cdna(inst): + if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" + if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}" + if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}" + if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" + return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}" def _disasm_sop2(inst: SOP2) -> str: return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}" def _disasm_sopc(inst: SOPC) -> str: - return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}" + s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0)) + s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1)) + return f"{inst.op_name.lower()} {s0}, {s1}" def _disasm_sopk(inst: SOPK) -> str: - op, name = inst.op, inst.op_name.lower() - if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" - if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32): + op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst) + # s_setreg_imm32_b32 has a 32-bit literal value + if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32): hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" - return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" + return f"{name} {hs}, 0x{inst._literal:x}" + if not cdna and op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" + if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')): + hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 + hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" + return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" + if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END): + return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}" return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}" def _disasm_vinterp(inst: VINTERP) -> str: @@ -388,7 +429,8 @@ SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), ' FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc. REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp} SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512', - 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'} + 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512', + 's_atc_probe', 's_atc_probe_buffer'} SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'} @@ -579,3 +621,69 @@ def asm(text: str) -> Inst: except NameError: if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns) raise + +# ═══════════════════════════════════════════════════════════════════════════════ +# CDNA DISASSEMBLER SUPPORT +# ═══════════════════════════════════════════════════════════════════════════════ + +try: + from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P, + SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS, + FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op) + + def _cdna_src(inst, v, neg, abs_=0, n=1): + s = inst.lit(v) if v == 255 else _fmt_src(v, n) + if abs_: s = f"|{s}|" + return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) + + def _disasm_vop3a(inst) -> str: + name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod) + s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2)) + dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" + if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" + suf = "_e64" if inst.op.value < 512 else "" + return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}") + + def _disasm_vop3b(inst) -> str: + name, n = inst.op_name.lower(), inst.num_srcs() + s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4) + dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else "" + cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) + return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}" + + def _disasm_cdna_vop3p(inst) -> str: + name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower() + get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc) + if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16) + else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}" + opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \ + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) + return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" + + _SEL = {0: 'BYTE_0', 1: 'BYTE_1', 2: 'BYTE_2', 3: 'BYTE_3', 4: 'WORD_0', 5: 'WORD_1', 6: 'DWORD'} + _UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'} + _DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"} + + def _disasm_sdwa(inst) -> str: + try: name = CDNA_VOP1Op(inst.vop_op).name.lower() + except ValueError: name = f"vop1_op_{inst.vop_op}" + src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0) + mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6] + return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "") + + def _disasm_dpp(inst) -> str: + try: name = CDNA_VOP1Op(inst.vop_op).name.lower() + except ValueError: name = f"vop1_op_{inst.vop_op}" + src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl + dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}") + mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl] + return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods) + + # Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones + DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc, + CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp, + CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf, + VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, SDWA: _disasm_sdwa, DPP: _disasm_dpp}) +except ImportError: + pass diff --git a/extra/assembly/amd/autogen/cdna/ins.py b/extra/assembly/amd/autogen/cdna/ins.py index 671a4a81ae..0451136caf 100644 --- a/extra/assembly/amd/autogen/cdna/ins.py +++ b/extra/assembly/amd/autogen/cdna/ins.py @@ -7,20 +7,18 @@ import functools # instruction formats class DPP(Inst64): - encoding = bits[31:26] == 0b110110 - src1_sel = bits[58:56] - src1_sext = bits[59] - src1_neg = bits[60] - src1_abs = bits[61] - s1 = bits[63] - offset0 = bits[7:0] - offset1 = bits[15:8] - op = bits[24:17] - acc = bits[25] - addr:VGPRField = bits[39:32] - data0:VGPRField = bits[47:40] - data1:VGPRField = bits[55:48] - vdst:VGPRField = bits[63:56] + encoding = bits[8:0] == 0b11111010 + vop_op = bits[16:9] + vdst:VGPRField = bits[24:17] + vop2_op = bits[31:25] + src0:Src = bits[39:32] + dpp_ctrl = bits[48:40] + bound_ctrl = bits[51] + src0_neg = bits[52] + src0_abs = bits[53] + src1_neg = bits[54] + src1_abs = bits[55] + bank_mask = bits[59:56] row_mask = bits[63:60] class DS(Inst64): @@ -82,6 +80,10 @@ class MUBUF(Inst64): acc = bits[55] class SDWA(Inst64): + encoding = bits[8:0] == 0b11111001 + vop_op = bits[16:9] + vdst:VGPRField = bits[24:17] + vop2_op = bits[31:25] src0:Src = bits[39:32] dst_sel = bits[42:40] dst_u = bits[44:43] @@ -97,9 +99,6 @@ class SDWA(Inst64): src1_neg = bits[60] src1_abs = bits[61] s1 = bits[63] - sdst:SGPRField = bits[46:40] - sd = bits[47] - row_mask = bits[63:60] class SDWAB(Inst64): src0:Src = bits[39:32] diff --git a/extra/assembly/amd/autogen/rdna3/enum.py b/extra/assembly/amd/autogen/rdna3/enum.py index fc56dccda7..90bb9dff1c 100644 --- a/extra/assembly/amd/autogen/rdna3/enum.py +++ b/extra/assembly/amd/autogen/rdna3/enum.py @@ -488,6 +488,8 @@ class SMEMOp(IntEnum): S_BUFFER_LOAD_B512 = 12 S_GL1_INV = 32 S_DCACHE_INV = 33 + S_ATC_PROBE = 34 + S_ATC_PROBE_BUFFER = 35 class SOP1Op(IntEnum): S_MOV_B32 = 0 @@ -710,6 +712,8 @@ class SOPKOp(IntEnum): S_SETREG_B32 = 18 S_SETREG_IMM32_B32 = 19 S_CALL_B64 = 20 + S_SUBVECTOR_LOOP_BEGIN = 22 + S_SUBVECTOR_LOOP_END = 23 S_WAITCNT_VSCNT = 24 S_WAITCNT_VMCNT = 25 S_WAITCNT_EXPCNT = 26 @@ -751,6 +755,8 @@ class SOPPOp(IntEnum): S_SENDMSGHALT = 55 S_INCPERFLEVEL = 56 S_DECPERFLEVEL = 57 + S_TTRACEDATA = 58 + S_TTRACEDATA_IMM = 59 S_ICACHE_INV = 60 S_BARRIER = 61 diff --git a/extra/assembly/amd/autogen/rdna3/ins.py b/extra/assembly/amd/autogen/rdna3/ins.py index 1dd7c9f893..0c486bb4c7 100644 --- a/extra/assembly/amd/autogen/rdna3/ins.py +++ b/extra/assembly/amd/autogen/rdna3/ins.py @@ -692,6 +692,8 @@ s_buffer_load_b256 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B256) s_buffer_load_b512 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B512) s_gl1_inv = functools.partial(SMEM, SMEMOp.S_GL1_INV) s_dcache_inv = functools.partial(SMEM, SMEMOp.S_DCACHE_INV) +s_atc_probe = functools.partial(SMEM, SMEMOp.S_ATC_PROBE) +s_atc_probe_buffer = functools.partial(SMEM, SMEMOp.S_ATC_PROBE_BUFFER) s_mov_b32 = functools.partial(SOP1, SOP1Op.S_MOV_B32) s_mov_b64 = functools.partial(SOP1, SOP1Op.S_MOV_B64) s_cmov_b32 = functools.partial(SOP1, SOP1Op.S_CMOV_B32) @@ -906,6 +908,8 @@ s_getreg_b32 = functools.partial(SOPK, SOPKOp.S_GETREG_B32) s_setreg_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_B32) s_setreg_imm32_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_IMM32_B32) s_call_b64 = functools.partial(SOPK, SOPKOp.S_CALL_B64) +s_subvector_loop_begin = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_BEGIN) +s_subvector_loop_end = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_END) s_waitcnt_vscnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VSCNT) s_waitcnt_vmcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VMCNT) s_waitcnt_expcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_EXPCNT) @@ -945,6 +949,8 @@ s_sendmsg = functools.partial(SOPP, SOPPOp.S_SENDMSG) s_sendmsghalt = functools.partial(SOPP, SOPPOp.S_SENDMSGHALT) s_incperflevel = functools.partial(SOPP, SOPPOp.S_INCPERFLEVEL) s_decperflevel = functools.partial(SOPP, SOPPOp.S_DECPERFLEVEL) +s_ttracedata = functools.partial(SOPP, SOPPOp.S_TTRACEDATA) +s_ttracedata_imm = functools.partial(SOPP, SOPPOp.S_TTRACEDATA_IMM) s_icache_inv = functools.partial(SOPP, SOPPOp.S_ICACHE_INV) s_barrier = functools.partial(SOPP, SOPPOp.S_BARRIER) v_interp_p10_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F32) diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 0e34374fd8..f374f7116d 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -458,10 +458,16 @@ class Inst: @classmethod def from_bytes(cls, data: bytes): + import typing inst = cls.from_int(int.from_bytes(data[:cls._size()], 'little')) op_val = inst._values.get('op', 0) - has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56) - has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70)) + # Check for instructions that always have a literal constant (FMAMK/FMAAK/MADMK/MADAK, SETREG_IMM32) + op_name = '' + if cls.__name__ in ('VOP2', 'SOP2', 'SOPK') and 'op' in (hints := typing.get_type_hints(cls, include_extras=True)): + if typing.get_origin(hints['op']) is typing.Annotated: + try: op_name = typing.get_args(hints['op'])[1](op_val).name + except (ValueError, TypeError): pass + has_literal = any(x in op_name for x in ('FMAMK', 'FMAAK', 'MADMK', 'MADAK', 'SETREG_IMM32')) # VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2) opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0) has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2))) @@ -475,7 +481,7 @@ class Inst: lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little') # Find which source has literal (255) and check its register count lit_src_is_64 = False - for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]: + for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]: if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: lit_src_is_64 = inst.src_regs(idx) == 2 break @@ -495,7 +501,12 @@ class Inst: return unwrap(self._values.get(name, 0)) def lit(self, v: int, neg: bool = False) -> str: - s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v) + if v == 255 and self._literal is not None: + # For 64-bit sources, literal is stored shifted - extract the 32-bit value + lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal + s = f"0x{lit32:x}" + else: + s = decode_src(v) return f"-{s}" if neg else s def __eq__(self, other): @@ -528,6 +539,10 @@ class Inst: elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val) else: self.op = VOP3Op(val) except ValueError: self.op = val + # Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums) + elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum): + try: self.op = marker(val) + except ValueError: self.op = val elif cls_name in self._enum_map: try: self.op = self._enum_map[cls_name](val) except ValueError: self.op = val diff --git a/extra/assembly/amd/pdf.py b/extra/assembly/amd/pdf.py index 7079c2b07b..1728ce578b 100644 --- a/extra/assembly/amd/pdf.py +++ b/extra/assembly/amd/pdf.py @@ -194,13 +194,30 @@ def _parse_single_pdf(url: str): if fmt_name in formats: formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]] if doc_name in ('RDNA3', 'RDNA3.5'): - if 'SOPPOp' in enums: assert 8 not in enums['SOPPOp']; enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR' + if 'SOPPOp' in enums: + for k, v in {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}.items(): + assert k not in enums['SOPPOp']; enums['SOPPOp'][k] = v + if 'SOPKOp' in enums: + for k, v in {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'}.items(): + assert k not in enums['SOPKOp']; enums['SOPKOp'][k] = v + if 'SMEMOp' in enums: + for k, v in {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}.items(): + assert k not in enums['SMEMOp']; enums['SMEMOp'][k] = v if 'DSOp' in enums: for k, v in {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}.items(): assert k not in enums['DSOp']; enums['DSOp'][k] = v if 'FLATOp' in enums: for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items(): assert k not in enums['FLATOp']; enums['FLATOp'][k] = v + # CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding + if is_cdna: + if 'SDWA' in formats: + formats['SDWA'] = [('ENCODING', 8, 0, 0xf9, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None)] + \ + [f for f in formats['SDWA'] if f[0] not in ('ENCODING', 'SDST', 'SD', 'ROW_MASK')] + if 'DPP' in formats: + formats['DPP'] = [('ENCODING', 8, 0, 0xfa, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None), + ('SRC0', 39, 32, None, 'Src'), ('DPP_CTRL', 48, 40, None, None), ('BOUND_CTRL', 51, 51, None, None), ('SRC0_NEG', 52, 52, None, None), ('SRC0_ABS', 53, 53, None, None), + ('SRC1_NEG', 54, 54, None, None), ('SRC1_ABS', 55, 55, None, None), ('BANK_MASK', 59, 56, None, None), ('ROW_MASK', 63, 60, None, None)] # Extract pseudocode for instructions all_text = '\n'.join(pdf.text(i) for i in range(instr_start, instr_end)) diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index ca6010fdcf..5bd7779c9c 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -131,28 +131,19 @@ def _make_asm_test(name): def _make_disasm_test(name): def test(self): - _, fmt_cls, op_enum = LLVM_TEST_FILES[name] + _, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name] # VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions) vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx') - undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}} # First pass: decode all instructions and collect disasm strings to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error) - skipped = 0 for asm_text, data in self.tests.get(name, []): - if len(data) > fmt_cls._size(): continue - temp_inst = fmt_cls.from_bytes(data) - temp_op = temp_inst._values.get('op', 0) - temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op - if temp_op in undocumented.get(name, set()): skipped += 1; continue - if name == 'sopp': - simm16 = temp_inst._values.get('simm16', 0) - simm16 = simm16.val if hasattr(simm16, 'val') else simm16 - sopp_no_imm = {48, 54, 53, 55, 60, 61, 62} - if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue + # Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword + is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35 + fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum) try: - if fmt_cls.__name__ in ('VOP3', 'VOP3SD'): + if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'): temp = VOP3.from_bytes(data) op_val = temp._values.get('op', 0) op_val = op_val.val if hasattr(op_val, 'val') else op_val @@ -188,7 +179,7 @@ def _make_disasm_test(name): if llvm_bytes is not None and llvm_bytes == data: passed += 1 elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}") - print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else "")) + print(f"{name.upper()} disasm: {passed} passed, {failed} failed") if failures[:10]: print(" " + "\n ".join(failures[:10])) self.assertEqual(failed, 0) return test diff --git a/extra/assembly/amd/test/test_llvm_cdna.py b/extra/assembly/amd/test/test_llvm_cdna.py new file mode 100644 index 0000000000..95bef12e13 --- /dev/null +++ b/extra/assembly/amd/test/test_llvm_cdna.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +"""Test CDNA assembler/disassembler against LLVM test vectors.""" +import unittest, re, subprocess +from tinygrad.helpers import fetch +from extra.assembly.amd.autogen.cdna.ins import * +from extra.assembly.amd.asm import disasm +from extra.assembly.amd.test.helpers import get_llvm_mc + +LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU" + +def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]: + """Parse LLVM test format into (asm, expected_bytes) pairs.""" + tests, lines = [], text.split('\n') + for i, line in enumerate(lines): + line = line.strip() + if not line or line.startswith(('//', '.', ';')): continue + asm_text = line.split('//')[0].strip() + if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue + for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))): + if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]): + hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') + elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]): + hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') + else: continue + try: + data = bytes.fromhex(hex_bytes) + if size_filter is None or len(data) == size_filter: tests.append((asm_text, data)) + except ValueError: pass + break + return tests + +# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions +# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter) +CDNA_TEST_FILES = { + # Scalar ALU - encoding is stable across GFX9/CDNA + 'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None), + 'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None), + 'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None), + 'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None), + 'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None), + 'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None), + # Vector ALU - encoding is mostly stable + 'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None), + 'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None), + 'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None), + 'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None), + 'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None), + 'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions + # Memory instructions + 'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None), + 'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None), + # CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers) + 'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None), + 'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None), + 'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None), + 'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None), + 'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None), + # CDNA-specific: MFMA/MAI instructions + 'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None), + # SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately) + 'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None), + 'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None), +} + +class TestLLVMCDNA(unittest.TestCase): + """Test CDNA instruction format decode/encode roundtrip and disassembly.""" + tests: dict[str, list[tuple[str, bytes]]] = {} + + @classmethod + def setUpClass(cls): + for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items(): + try: + data = fetch(f"{LLVM_BASE}/{filename}").read_bytes() + cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter) + except Exception as e: + print(f"Warning: couldn't fetch {filename}: {e}") + cls.tests[name] = [] + +def _get_val(v): return v.val if hasattr(v, 'val') else v + +def _filter_and_decode(tests, fmt_cls, op_enum): + """Filter tests and decode instructions, yielding (asm_text, data, decoded, error).""" + fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP' + for asm_text, data in tests: + has_lit = False + # SDWA/DPP format tests: only accept matching 8-byte instructions + if is_sdwa: + if len(data) != 8 or data[0] != 0xf9: continue + elif is_dpp: + if len(data) != 8 or data[0] != 0xfa: continue + elif fmt_cls._size() == 4 and len(data) == 8: + if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately) + has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC')) + if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20 + if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37) + if not has_lit: continue + if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue + try: + decoded = fmt_cls.from_bytes(data) + # For SDWA/DPP, opcode location depends on VOP1 vs VOP2 + if is_sdwa or is_dpp: + vop2_op = _get_val(decoded._values.get('vop2_op', 0)) + op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op + else: + op_val = _get_val(decoded._values.get('op', 0)) + try: op_enum(op_val) + except ValueError: continue + yield asm_text, data, decoded, None + except Exception as e: + yield asm_text, data, None, str(e) + +def _make_roundtrip_test(name): + def test(self): + _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name] + passed, failed, failures = 0, 0, [] + for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum): + if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue + if decoded.to_bytes()[:len(data)] == data: passed += 1 + else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}") + print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed") + if failures[:5]: print(" " + "\n ".join(failures[:5])) + self.assertEqual(failed, 0) + return test + +def _make_disasm_test(name): + def test(self): + _, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name] + passed, failed, failures = 0, 0, [] + for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum): + if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue + if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue + if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue + passed += 1 + print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed") + if failures[:5]: print(" " + "\n ".join(failures[:5])) + self.assertEqual(failed, 0) + return test + +for name in CDNA_TEST_FILES: + setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name)) + setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name)) + +if __name__ == "__main__": + unittest.main()