From 49d1bf93d636f032316f7d6f2203f0ee38374080 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 30 Dec 2025 13:51:40 -0500 Subject: [PATCH] assembly/amd: refactor asm.py to be simpler (#13900) * assembly/amd: refactor asm.py * assembly/amd: refactor asm.py to be simpler * multiple fxns * fast * more tests pass * regen * stop decode --- extra/assembly/amd/asm.py | 1310 ++++++++---------- extra/assembly/amd/autogen/cdna/__init__.py | 54 +- extra/assembly/amd/autogen/rdna3/__init__.py | 64 +- extra/assembly/amd/dsl.py | 66 +- extra/assembly/amd/emu.py | 23 +- extra/assembly/amd/test/test_llvm.py | 14 +- extra/assembly/amd/test/test_roundtrip.py | 44 +- 7 files changed, 739 insertions(+), 836 deletions(-) diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index 3496795dc9..6b88eb336f 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -3,21 +3,65 @@ from __future__ import annotations import re from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, FLOAT_ENC, SRC_FIELDS, unwrap from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF +from extra.assembly.amd.autogen.rdna3 import VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP +from extra.assembly.amd.autogen.rdna3 import VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp +from extra.assembly.amd.autogen.rdna3 import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp + +# VOP3SD opcodes that share VOP3 encoding +VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} + +def detect_format(data: bytes) -> type[Inst]: + """Detect instruction format from machine code bytes.""" + assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}" + word = int.from_bytes(data[:4], 'little') + hi2 = (word >> 30) & 0x3 + if hi2 == 0b11: + enc = (word >> 26) & 0xf + if enc == 0b0010: return VOPD + if enc == 0b0011: return VOP3P + if enc == 0b0100: return VINTERP + if enc == 0b0101: return VOP3SD if ((word >> 16) & 0x3ff) in VOP3SD_OPS else VOP3 + if enc == 0b0110: return DS + if enc == 0b0111: return FLAT + if enc == 0b1000: return MUBUF + if enc == 0b1010: return MTBUF + if enc == 0b1100 or enc == 0b1111: return MIMG + if enc == 0b1101: return SMEM + if enc == 0b1110: return EXP + raise ValueError(f"unknown 64-bit format enc={enc:#06b} word={word:#010x}") + if hi2 == 0b10: + enc = (word >> 23) & 0x7f + if enc == 0b1111101: return SOP1 + if enc == 0b1111110: return SOPC + if enc == 0b1111111: return SOPP + return SOPK if ((word >> 28) & 0xf) == 0b1011 else SOP2 + # hi2 == 0b00 or 0b01: VOP1/VOP2/VOPC (bit 31 = 0) + assert (word >> 31) == 0, f"expected bit 31 = 0 for VOP, got word={word:#010x}" + enc = (word >> 25) & 0x7f + if enc == 0b0111110: return VOPC + if enc == 0b0111111: return VOP1 + if enc <= 0b0111101: return VOP2 + raise ValueError(f"unknown VOP format enc={enc:#09b} word={word:#010x}") + +# ═══════════════════════════════════════════════════════════════════════════════ +# CONSTANTS +# ═══════════════════════════════════════════════════════════════════════════════ -# Decoding helpers SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"} SPECIAL_DEC = {**SPECIAL_GPRS, **{v: str(k) for k, v in FLOAT_ENC.items()}} -SPECIAL_PAIRS = {106: "vcc", 126: "exec"} # Special register pairs (for 64-bit ops) -# GFX11 hwreg names (IDs 16-17 are TBA - not supported, IDs 18-19 are PERF_SNAPSHOT) -HWREG_NAMES = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC', - 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO', - 19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', - 22: 'HW_REG_XNACK_MASK', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'} -HWREG_IDS = {v.lower(): k for k, v in HWREG_NAMES.items()} # Reverse map for assembler -MSG_NAMES = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA', - 131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'} -_16BIT_TYPES = ('f16', 'i16', 'u16', 'b16') -def _is_16bit(s: str) -> bool: return any(s.endswith(x) for x in _16BIT_TYPES) +SPECIAL_PAIRS = {106: "vcc", 126: "exec"} +HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC', + 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO', + 19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK', + 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'} +HWREG_IDS = {v.lower(): k for k, v in HWREG.items()} +MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA', + 131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'} +VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} + +# ═══════════════════════════════════════════════════════════════════════════════ +# HELPERS +# ═══════════════════════════════════════════════════════════════════════════════ def decode_src(val: int) -> str: if val <= 105: return f"s{val}" @@ -28,744 +72,596 @@ def decode_src(val: int) -> str: if 256 <= val <= 511: return f"v{val - 256}" return "lit" if val == 255 else f"?{val}" -def _reg(prefix: str, base: int, cnt: int = 1) -> str: return f"{prefix}{base}" if cnt == 1 else f"{prefix}[{base}:{base+cnt-1}]" -def _sreg(base: int, cnt: int = 1) -> str: return _reg("s", base, cnt) -def _vreg(base: int, cnt: int = 1) -> str: return _reg("v", base, cnt) +def _reg(p: str, b: int, n: int = 1) -> str: return f"{p}{b}" if n == 1 else f"{p}[{b}:{b+n-1}]" +def _sreg(b: int, n: int = 1) -> str: return _reg("s", b, n) +def _vreg(b: int, n: int = 1) -> str: return _reg("v", b, n) +def _hl(v: int, hi_thresh: int = 128) -> str: return 'h' if v >= hi_thresh else 'l' -def _fmt_sdst(v: int, cnt: int = 1) -> str: - """Format SGPR destination with special register names.""" +def _fmt_sdst(v: int, n: int = 1) -> str: if v == 124: return "null" - if 108 <= v <= 123: return _reg("ttmp", v - 108, cnt) - if cnt > 1 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v] - if cnt > 1: return _sreg(v, cnt) + if 108 <= v <= 123: return _reg("ttmp", v - 108, n) + if n > 1: return SPECIAL_PAIRS.get(v) or _sreg(v, n) return {126: "exec_lo", 127: "exec_hi", 106: "vcc_lo", 107: "vcc_hi", 125: "m0"}.get(v, f"s{v}") -def _fmt_ssrc(v: int, cnt: int = 1) -> str: - """Format SGPR source with special register names and pairs.""" - if cnt == 2: - if v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v] - if v <= 105: return _sreg(v, 2) - if 108 <= v <= 123: return _reg("ttmp", v - 108, 2) +def _fmt_src(v: int, n: int = 1) -> str: + if n == 1: return decode_src(v) + if v >= 256: return _vreg(v - 256, n) + if v <= 105: return _sreg(v, n) + if n == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v] + if 108 <= v <= 123: return _reg("ttmp", v - 108, n) return decode_src(v) -def _fmt_src_n(v: int, cnt: int) -> str: - """Format source with given register count (1, 2, or 4).""" - if cnt == 1: return decode_src(v) - if v >= 256: return _vreg(v - 256, cnt) - if v <= 105: return _sreg(v, cnt) - if cnt == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v] - if 108 <= v <= 123: return _reg("ttmp", v - 108, cnt) - return decode_src(v) +def _fmt_v16(v: int, base: int = 256, hi_thresh: int = 384) -> str: + return f"v{(v - base) & 0x7f}.{_hl(v, hi_thresh)}" -def _fmt_src64(v: int) -> str: - """Format 64-bit source (VGPR pair, SGPR pair, or special pair).""" - return _fmt_src_n(v, 2) - -def _parse_sop_sizes(op_name: str) -> tuple[int, ...]: - """Parse dst and src sizes from SOP instruction name. Returns (dst_cnt, src0_cnt) or (dst_cnt, src0_cnt, src1_cnt).""" - if op_name in ('s_bitset0_b64', 's_bitset1_b64'): return (2, 1) - if op_name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return (2, 2, 1) - if op_name in ('s_bfm_b64',): return (2, 1, 1) - # SOPC: s_bitcmp0_b64, s_bitcmp1_b64 - 64-bit src0, 32-bit src1 (bit index) - if op_name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return (1, 2, 1) - if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', op_name): - return (2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1) - if m := re.search(r'_(b|i|u)(32|64)$', op_name): - sz = 2 if m.group(2) == '64' else 1 - return (sz, sz) - return (1, 1) - -# Waitcnt helpers (RDNA3 format: bits 15:10=vmcnt, bits 9:4=lgkmcnt, bits 3:0=expcnt) def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int: return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10) -def decode_waitcnt(val: int) -> tuple[int, int, int]: - return (val >> 10) & 0x3f, val & 0xf, (val >> 4) & 0x3f # vmcnt, expcnt, lgkmcnt -# VOP3SD opcodes (shared encoding with VOP3 but different field layout) -# Note: opcodes 0-255 are VOPC promoted to VOP3 - never treat as VOP3SD -VOP3SD_OPCODES = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} +def _has(op: str, *subs) -> bool: return any(s in op for s in subs) +def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32') +def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64') +def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "") +def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c) +def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]" -# Disassembler -def disasm(inst: Inst) -> str: - op_val = unwrap(inst._values.get('op', 0)) - cls_name = inst.__class__.__name__ - # VOP3 and VOP3SD share encoding - check opcode to determine which - is_vop3sd = cls_name == 'VOP3' and op_val in VOP3SD_OPCODES - try: - from extra.assembly.amd.autogen import rdna3 as autogen - if is_vop3sd: - op_name = autogen.VOP3SDOp(op_val).name.lower() - else: - op_name = getattr(autogen, f"{cls_name}Op")(op_val).name.lower() if hasattr(autogen, f"{cls_name}Op") else f"op_{op_val}" - except (ValueError, KeyError): op_name = f"op_{op_val}" - def fmt_src(v): return f"0x{inst._literal:x}" if v == 255 and inst._literal is not None else decode_src(v) +def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str: + """Format VOP3 source operand with modifiers.""" + if n > 1: s = _fmt_src(v, n) + elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v)) + else: s = inst.lit(v) + if abs_: s = f"|{s}|" + return f"-{s}" if neg else s - # VOP1 - if cls_name == 'VOP1': - vdst, src0 = unwrap(inst._values['vdst']), unwrap(inst._values['src0']) - if op_name == 'v_nop': return 'v_nop' - if op_name == 'v_pipeflush': return 'v_pipeflush' - parts = op_name.split('_') - is_16bit_dst = any(p in _16BIT_TYPES for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in _16BIT_TYPES and 'cvt' not in op_name) - is_16bit_src = parts[-1] in _16BIT_TYPES and 'sat_pk' not in op_name - _F64_OPS = ('v_ceil_f64', 'v_floor_f64', 'v_fract_f64', 'v_frexp_mant_f64', 'v_rcp_f64', 'v_rndne_f64', 'v_rsq_f64', 'v_sqrt_f64', 'v_trunc_f64') - is_f64_dst = op_name in _F64_OPS or op_name in ('v_cvt_f64_f32', 'v_cvt_f64_i32', 'v_cvt_f64_u32') - is_f64_src = op_name in _F64_OPS or op_name in ('v_cvt_f32_f64', 'v_cvt_i32_f64', 'v_cvt_u32_f64', 'v_frexp_exp_i32_f64') - if op_name == 'v_readfirstlane_b32': - return f"v_readfirstlane_b32 {decode_src(vdst)}, v{src0 - 256 if src0 >= 256 else src0}" - dst_str = _vreg(vdst, 2) if is_f64_dst else f"v{vdst & 0x7f}.{'h' if vdst >= 128 else 'l'}" if is_16bit_dst else f"v{vdst}" - src_str = _fmt_src64(src0) if is_f64_src else f"v{(src0 - 256) & 0x7f}.{'h' if src0 >= 384 else 'l'}" if is_16bit_src and src0 >= 256 else fmt_src(src0) - return f"{op_name}_e32 {dst_str}, {src_str}" +def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str: + """Format op_sel modifier string.""" + if not need: return "" + if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]" + if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]" + return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]" - # VOP2 - if cls_name == 'VOP2': - vdst, src0_raw, vsrc1 = unwrap(inst._values['vdst']), unwrap(inst._values['src0']), unwrap(inst._values['vsrc1']) - suffix = "" if op_name == "v_dot2acc_f32_f16" else "_e32" - is_16bit_op = ('_f16' in op_name or '_i16' in op_name or '_u16' in op_name) and '_f32' not in op_name and '_i32' not in op_name and 'pk_' not in op_name - if is_16bit_op: - dst_str = f"v{vdst & 0x7f}.{'h' if vdst >= 128 else 'l'}" - src0_str = f"v{(src0_raw - 256) & 0x7f}.{'h' if src0_raw >= 384 else 'l'}" if src0_raw >= 256 else fmt_src(src0_raw) - vsrc1_str = f"v{vsrc1 & 0x7f}.{'h' if vsrc1 >= 128 else 'l'}" - else: - dst_str, src0_str, vsrc1_str = f"v{vdst}", fmt_src(src0_raw), f"v{vsrc1}" - return f"{op_name}{suffix} {dst_str}, {src0_str}, {vsrc1_str}" + (", vcc_lo" if op_name == "v_cndmask_b32" else "") +# ═══════════════════════════════════════════════════════════════════════════════ +# DISASSEMBLER +# ═══════════════════════════════════════════════════════════════════════════════ - # VOPC - if cls_name == 'VOPC': - src0, vsrc1 = unwrap(inst._values['src0']), unwrap(inst._values['vsrc1']) - is_64bit = any(x in op_name for x in ('f64', 'i64', 'u64')) - is_64bit_vsrc1 = is_64bit and 'class' not in op_name - is_16bit = any(x in op_name for x in ('_f16', '_i16', '_u16')) and 'f32' not in op_name - is_cmpx = op_name.startswith('v_cmpx') # VOPCX writes to exec, no vcc destination - src0_str = _fmt_src64(src0) if is_64bit else f"v{(src0 - 256) & 0x7f}.{'h' if src0 >= 384 else 'l'}" if is_16bit and src0 >= 256 else fmt_src(src0) - vsrc1_str = _vreg(vsrc1, 2) if is_64bit_vsrc1 else f"v{vsrc1 & 0x7f}.{'h' if vsrc1 >= 128 else 'l'}" if is_16bit else f"v{vsrc1}" - return f"{op_name}_e32 {src0_str}, {vsrc1_str}" if is_cmpx else f"{op_name}_e32 vcc_lo, {src0_str}, {vsrc1_str}" +def _disasm_vop1(inst: VOP1) -> str: + op = VOP1Op(inst.op) + if op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return op.name.lower() + F64_OPS = {VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64, VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64, VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64} + is_f64_d = op in F64_OPS or op in (VOP1Op.V_CVT_F64_F32, VOP1Op.V_CVT_F64_I32, VOP1Op.V_CVT_F64_U32) + is_f64_s = op in F64_OPS or op in (VOP1Op.V_CVT_F32_F64, VOP1Op.V_CVT_I32_F64, VOP1Op.V_CVT_U32_F64, VOP1Op.V_FREXP_EXP_I32_F64) + name = op.name.lower() + parts = name.split('_') + is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name) + is_16s = parts[-1] in ('f16','i16','u16','b16') and 'sat_pk' not in name + if op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" + dst = _vreg(inst.vdst, 2) if is_f64_d else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" + src = _fmt_src(inst.src0, 2) if is_f64_s else _fmt_v16(inst.src0) if is_16s and inst.src0 >= 256 else inst.lit(inst.src0) + return f"{name}_e32 {dst}, {src}" - # SOPP - if cls_name == 'SOPP': - simm16 = unwrap(inst._values.get('simm16', 0)) - # No-operand instructions (simm16 is ignored) - no_imm_ops = ('s_endpgm', 's_barrier', 's_wakeup', 's_icache_inv', 's_ttracedata', 's_ttracedata_imm', - 's_wait_idle', 's_endpgm_saved', 's_code_end', 's_endpgm_ordered_ps_done') - if op_name in no_imm_ops: return op_name - if op_name == 's_waitcnt': - vmcnt, expcnt, lgkmcnt = decode_waitcnt(simm16) - parts = [] - if vmcnt != 0x3f: parts.append(f"vmcnt({vmcnt})") - if expcnt != 0x7: parts.append(f"expcnt({expcnt})") - if lgkmcnt != 0x3f: parts.append(f"lgkmcnt({lgkmcnt})") - return f"s_waitcnt {' '.join(parts)}" if parts else "s_waitcnt 0" - if op_name == 's_delay_alu': - dep_names = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'] - skip_names = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4'] - id0, skip, id1 = simm16 & 0xf, (simm16 >> 4) & 0x7, (simm16 >> 7) & 0xf - def dep_name(v): return dep_names[v-1] if 0 < v <= len(dep_names) else str(v) - parts = [f"instid0({dep_name(id0)})"] if id0 else [] - if skip: parts.append(f"instskip({skip_names[skip]})") - if id1: parts.append(f"instid1({dep_name(id1)})") - return f"s_delay_alu {' | '.join(p for p in parts if p)}" if parts else "s_delay_alu 0" - if op_name.startswith('s_cbranch') or op_name.startswith('s_branch'): - return f"{op_name} {simm16}" - # Most SOPP ops require immediate (s_nop, s_setkill, s_sethalt, s_sleep, s_setprio, s_sendmsg*, etc.) - return f"{op_name} 0x{simm16:x}" +def _disasm_vop2(inst: VOP2) -> str: + op = VOP2Op(inst.op) + name = op.name.lower() + suf = "" if op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32" + is16 = _is16(name) and 'pk_' not in name + # fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1 + if op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}" + if op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}" + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_fmt_v16(inst.src0) if inst.src0 >= 256 else inst.lit(inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if op == VOP2Op.V_CNDMASK_B32 else "") - # SMEM - if cls_name == 'SMEM': - if op_name in ('s_gl1_inv', 's_dcache_inv'): return op_name - sdata, sbase, soffset, offset = unwrap(inst._values['sdata']), unwrap(inst._values['sbase']), unwrap(inst._values['soffset']), unwrap(inst._values.get('offset', 0)) - glc, dlc = unwrap(inst._values.get('glc', 0)), unwrap(inst._values.get('dlc', 0)) - # Format offset: "soffset offset:X" if both, "0x{offset:x}" if only imm, or decode_src(soffset) - off_str = f"{decode_src(soffset)} offset:0x{offset:x}" if offset and soffset != 124 else f"0x{offset:x}" if offset else decode_src(soffset) - sbase_idx, sbase_cnt = sbase * 2, 4 if (8 <= op_val <= 12 or op_name == 's_atc_probe_buffer') else 2 - sbase_str = _fmt_ssrc(sbase_idx, sbase_cnt) if sbase_cnt == 2 else _sreg(sbase_idx, sbase_cnt) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_cnt) - if op_name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{op_name} {sdata}, {sbase_str}, {off_str}" - width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val, 1) - mods = [m for m in ["glc" if glc else "", "dlc" if dlc else ""] if m] - return f"{op_name} {_fmt_sdst(sdata, width)}, {sbase_str}, {off_str}" + (" " + " ".join(mods) if mods else "") +VOPC_CLASS = {VOPCOp.V_CMP_CLASS_F16, VOPCOp.V_CMP_CLASS_F32, VOPCOp.V_CMP_CLASS_F64, + VOPCOp.V_CMPX_CLASS_F16, VOPCOp.V_CMPX_CLASS_F32, VOPCOp.V_CMPX_CLASS_F64} - # FLAT - if cls_name == 'FLAT': - vdst, addr, data, saddr, offset, seg = [unwrap(inst._values.get(f, 0)) for f in ['vdst', 'addr', 'data', 'saddr', 'offset', 'seg']] - instr = f"{['flat', 'scratch', 'global'][seg] if seg < 3 else 'flat'}_{op_name.split('_', 1)[1] if '_' in op_name else op_name}" - width = {'b32':1, 'b64':2, 'b96':3, 'b128':4, 'u8':1, 'i8':1, 'u16':1, 'i16':1}.get(op_name.split('_')[-1], 1) - addr_str = _vreg(addr, 2) if saddr == 0x7F else _vreg(addr) - saddr_str = "" if saddr == 0x7F else f", {_sreg(saddr, 2)}" if saddr < 106 else ", off" if saddr == 124 else f", {decode_src(saddr)}" - off_str = f" offset:{offset}" if offset else "" - vdata_str = _vreg(data if 'store' in op_name else vdst, width) - return f"{instr} {addr_str}, {vdata_str}{saddr_str}{off_str}" if 'store' in op_name else f"{instr} {vdata_str}, {addr_str}{saddr_str}{off_str}" +def _disasm_vopc(inst: VOPC) -> str: + op = VOPCOp(inst.op) + name = op.name.lower() + is64, is16 = _is64(name), _is16(name) + s0 = _fmt_src(inst.src0, 2) if is64 else _fmt_v16(inst.src0) if is16 and inst.src0 >= 256 else inst.lit(inst.src0) + s1 = _vreg(inst.vsrc1, 2) if is64 and op not in VOPC_CLASS else _fmt_v16(inst.vsrc1, 0, 128) if is16 else f"v{inst.vsrc1}" + return f"{name}_e32 {s0}, {s1}" if op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" - # VOP3: vector ops with modifiers (can be 1, 2, or 3 sources depending on opcode range) - if cls_name == 'VOP3': - # Handle VOP3SD opcodes (same encoding, different field layout) - if is_vop3sd: - vdst = unwrap(inst._values.get('vdst', 0)) - # VOP3SD: sdst is at bits [14:8], but VOP3 decodes opsel at [14:11], abs at [10:8], clmp at [15] - # We need to reconstruct sdst from these fields - opsel_raw = unwrap(inst._values.get('opsel', 0)) - abs_raw = unwrap(inst._values.get('abs', 0)) - clmp_raw = unwrap(inst._values.get('clmp', 0)) - sdst = (clmp_raw << 7) | (opsel_raw << 3) | abs_raw - src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')] - neg = unwrap(inst._values.get('neg', 0)) - omod = unwrap(inst._values.get('omod', 0)) - omod_str = {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(omod, "") - is_f64 = 'f64' in op_name - # v_mad_i64_i32/v_mad_u64_u32: 64-bit dst and src2, 32-bit src0/src1 - is_mad64 = 'mad_i64_i32' in op_name or 'mad_u64_u32' in op_name - def fmt_sd_src(v, neg_bit, is_64bit=False): - s = _fmt_src64(v) if (is_64bit or is_f64) else fmt_src(v) - return f"-{s}" if neg_bit else s - src0_str, src1_str = fmt_sd_src(src0, neg & 1), fmt_sd_src(src1, neg & 2) - src2_str = fmt_sd_src(src2, neg & 4, is_mad64) - dst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}" - sdst_str = _fmt_sdst(sdst, 1) - # v_add_co_u32, v_sub_co_u32, v_subrev_co_u32 only use 2 sources - if op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'): - return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}" - # v_add_co_ci_u32, v_sub_co_ci_u32, v_subrev_co_ci_u32 use 3 sources (src2 is carry-in) - if op_name in ('v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'): - return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" - # v_div_scale uses 3 sources - return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" + omod_str +NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, + SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE} - vdst = unwrap(inst._values.get('vdst', 0)) - src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')] - neg, abs_, clmp = unwrap(inst._values.get('neg', 0)), unwrap(inst._values.get('abs', 0)), unwrap(inst._values.get('clmp', 0)) - opsel = unwrap(inst._values.get('opsel', 0)) - # Check if 64-bit op (needs register pairs) - is_f64 = 'f64' in op_name or 'i64' in op_name or 'u64' in op_name or 'b64' in op_name - # v_cmp_class_* has 64-bit src0 but 32-bit src1 (class mask) - is_class = 'class' in op_name - # Shift ops: v_*rev_*64 have 32-bit shift amount (src0), 64-bit value (src1) - is_shift64 = 'rev' in op_name and '64' in op_name and op_name.startswith('v_') - # v_ldexp_f64: 64-bit src0 (mantissa), 32-bit src1 (exponent) - is_ldexp64 = op_name == 'v_ldexp_f64' - # v_trig_preop_f64: 64-bit dst/src0, 32-bit src1 (exponent/scale) - is_trig_preop = op_name == 'v_trig_preop_f64' - # v_readlane_b32: destination is SGPR (despite vdst field) - is_readlane = op_name == 'v_readlane_b32' - # SAD/QSAD/MQSAD instructions have mixed sizes - # v_qsad_pk_u16_u8, v_mqsad_pk_u16_u8: 64-bit dst/src0/src2, 32-bit src1 - # v_mqsad_u32_u8: 128-bit (4 reg) dst/src2, 64-bit src0, 32-bit src1 - is_sad64 = any(x in op_name for x in ('qsad_pk', 'mqsad_pk')) - is_mqsad_u32 = 'mqsad_u32' in op_name - # Detect 16-bit and 64-bit operand sizes for various instruction patterns - if 'cvt_pk' in op_name: - is_f16_dst, is_f16_src, is_f16_src2 = False, op_name.endswith('16'), False - elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', op_name): - dst_type, src_type = m.group(1), m.group(2) - is_f16_dst, is_f16_src, is_f16_src2 = _is_16bit(dst_type), _is_16bit(src_type), _is_16bit(src_type) - is_f64_dst, is_f64_src, is_f64 = '64' in dst_type, '64' in src_type, False - elif re.match(r'v_mad_[iu]32_[iu]16', op_name): - is_f16_dst, is_f16_src, is_f16_src2 = False, True, False # 32-bit dst, 16-bit src0/src1, 32-bit src2 - elif 'pack_b32' in op_name: - is_f16_dst, is_f16_src, is_f16_src2 = False, True, True # 32-bit dst, 16-bit sources - else: - is_16bit_op = any(x in op_name for x in _16BIT_TYPES) and not any(x in op_name for x in ('dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad')) - is_f16_dst = is_f16_src = is_f16_src2 = is_16bit_op - # Check if any opsel bit is set (any operand uses .h) - if so, we need explicit .l for low-half - any_hi = opsel != 0 - def fmt_vop3_src(v, neg_bit, abs_bit, hi_bit=False, reg_cnt=1, is_16=False): - s = _fmt_src_n(v, reg_cnt) if reg_cnt > 1 else f"v{v - 256}.h" if is_16 and v >= 256 and hi_bit else f"v{v - 256}.l" if is_16 and v >= 256 and any_hi else fmt_src(v) - if abs_bit: s = f"|{s}|" - return f"-{s}" if neg_bit else s - # Determine register count for each source (check for cvt-specific 64-bit flags first) - is_src0_64 = locals().get('is_f64_src', is_f64 and not is_shift64) or is_sad64 or is_mqsad_u32 - is_src1_64 = is_f64 and not is_class and not is_ldexp64 and not is_trig_preop - src0_cnt = 2 if is_src0_64 else 1 - src1_cnt = 2 if is_src1_64 else 1 - src2_cnt = 4 if is_mqsad_u32 else 2 if (is_f64 or is_sad64) else 1 - src0_str = fmt_vop3_src(src0, neg & 1, abs_ & 1, opsel & 1, src0_cnt, is_f16_src) - src1_str = fmt_vop3_src(src1, neg & 2, abs_ & 2, opsel & 2, src1_cnt, is_f16_src) - src2_str = fmt_vop3_src(src2, neg & 4, abs_ & 4, opsel & 4, src2_cnt, is_f16_src2) - # Format destination - for 16-bit ops, use .h/.l suffix; readlane uses SGPR dest - is_dst_64 = locals().get('is_f64_dst', is_f64) or is_sad64 - dst_cnt = 4 if is_mqsad_u32 else 2 if is_dst_64 else 1 - if is_readlane: - dst_str = _fmt_sdst(vdst, 1) - elif dst_cnt > 1: - dst_str = _vreg(vdst, dst_cnt) - elif is_f16_dst: - dst_str = f"v{vdst}.h" if (opsel & 8) else f"v{vdst}.l" if any_hi else f"v{vdst}" - else: - dst_str = f"v{vdst}" - clamp_str = " clamp" if clmp else "" - omod = unwrap(inst._values.get('omod', 0)) - omod_str = {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(omod, "") - # op_sel for non-VGPR sources (when opsel bits are set but source is not a VGPR) - # For 16-bit ops with VGPR sources, opsel is encoded in .h/.l suffix - # For non-VGPR sources or non-16-bit ops, we need explicit op_sel - has_nonvgpr_opsel = (src0 < 256 and (opsel & 1)) or (src1 < 256 and (opsel & 2)) or (src2 < 256 and (opsel & 4)) - need_opsel = has_nonvgpr_opsel or (opsel and not is_f16_src) - # Helper to format opsel string based on source count - def fmt_opsel(num_src): - if not need_opsel: return "" - # When dst is .h (for 16-bit ops) and non-VGPR sources have opsel, use all 1s - if is_f16_dst and (opsel & 8): # dst is .h - return f" op_sel:[1,1,1{',1' if num_src == 3 else ''}]" - # Otherwise output actual opsel values - if num_src == 3: - return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]" - return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]" - # Determine number of sources based on opcode range: - # 0-255: VOPC promoted (comparison, 2 src, sdst) - # 256-383: VOP2 promoted (2 src) - # 384-511: VOP1 promoted (1 src) - # 512+: Native VOP3 (2 or 3 src depending on instruction) - if op_val < 256: # VOPC promoted - # VOPCX (v_cmpx_*) writes to exec, no explicit destination - if op_name.startswith('v_cmpx'): - return f"{op_name}_e64 {src0_str}, {src1_str}" - return f"{op_name}_e64 {_fmt_sdst(vdst, 1)}, {src0_str}, {src1_str}" - elif op_val < 384: # VOP2 promoted - # v_cndmask_b32 in VOP3 format has 3 sources (src2 is mask selector) - if 'cndmask' in op_name: - return f"{op_name}_e64 {dst_str}, {src0_str}, {src1_str}, {src2_str}" + fmt_opsel(3) + clamp_str + omod_str - return f"{op_name}_e64 {dst_str}, {src0_str}, {src1_str}" + fmt_opsel(2) + clamp_str + omod_str - elif op_val < 512: # VOP1 promoted - if op_name in ('v_nop', 'v_pipeflush'): return f"{op_name}_e64" - return f"{op_name}_e64 {dst_str}, {src0_str}" + fmt_opsel(1) + clamp_str + omod_str - else: # Native VOP3 - determine 2 vs 3 sources based on instruction name - # 3-source ops: fma, mad, min3, max3, med3, div_fixup, div_fmas, sad, msad, qsad, mqsad, lerp, alignbit/byte, cubeid/sc/tc/ma, bfe, bfi, perm_b32, permlane, cndmask - # Note: v_writelane_b32 is 2-src (src0, src1 with vdst as 3rd operand - read-modify-write) - is_3src = any(x in op_name for x in ('fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', - 'bfe', 'bfi', 'perm_b32', 'permlane', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', - 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit')) - if is_3src: - return f"{op_name} {dst_str}, {src0_str}, {src1_str}, {src2_str}" + fmt_opsel(3) + clamp_str + omod_str - return f"{op_name} {dst_str}, {src0_str}, {src1_str}" + fmt_opsel(2) + clamp_str + omod_str +def _disasm_sopp(inst: SOPP) -> str: + op, name = SOPPOp(inst.op), SOPPOp(inst.op).name.lower() + if op in NO_ARG_SOPP: return name + if op == SOPPOp.S_WAITCNT: + vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f + p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""] + return f"s_waitcnt {' '.join(x for x in p if x) or '0'}" + if op == SOPPOp.S_DELAY_ALU: + deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4'] + id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf + dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v) + p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skips[skip]})" if skip else "", f"instid1({dep(id1)})" if id1 else ""] + return f"s_delay_alu {' | '.join(x for x in p if x) or '0'}" + return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}" - # VOP3SD: 3-source with scalar destination (v_div_scale_*, v_add_co_u32, v_mad_*64_*32, etc.) - if cls_name == 'VOP3SD': - vdst, sdst = unwrap(inst._values.get('vdst', 0)), unwrap(inst._values.get('sdst', 0)) - src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')] - neg, omod, clmp = unwrap(inst._values.get('neg', 0)), unwrap(inst._values.get('omod', 0)), unwrap(inst._values.get('clmp', 0)) - is_f64, is_mad64 = 'f64' in op_name, 'mad_i64_i32' in op_name or 'mad_u64_u32' in op_name - def fmt_neg(v, neg_bit, is_64=False): return f"-{_fmt_src64(v) if (is_64 or is_f64) else fmt_src(v)}" if neg_bit else _fmt_src64(v) if (is_64 or is_f64) else fmt_src(v) - srcs = [fmt_neg(src0, neg & 1), fmt_neg(src1, neg & 2), fmt_neg(src2, neg & 4, is_mad64)] - dst_str, sdst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}", _fmt_sdst(sdst, 1) - clamp_str, omod_str = " clamp" if clmp else "", {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(omod, "") - is_2src = op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32') - suffix = "_e64" if op_name.startswith('v_') and 'co_' in op_name else "" - return f"{op_name}{suffix} {dst_str}, {sdst_str}, {', '.join(srcs[:2] if is_2src else srcs)}" + clamp_str + omod_str +def _disasm_smem(inst: SMEM) -> str: + op = SMEMOp(inst.op) + name = op.name.lower() + if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name + off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset) + sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2 + sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) + if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}" + width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1) + return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) - # VOPD: dual-issue instructions - if cls_name == 'VOPD': - from extra.assembly.amd.autogen import rdna3 as autogen - opx, opy, vdstx, vdsty_enc = [unwrap(inst._values.get(f, 0)) for f in ('opx', 'opy', 'vdstx', 'vdsty')] - srcx0, vsrcx1, srcy0, vsrcy1 = [unwrap(inst._values.get(f, 0)) for f in ('srcx0', 'vsrcx1', 'srcy0', 'vsrcy1')] - literal = inst._literal if hasattr(inst, '_literal') and inst._literal else unwrap(inst._values.get('literal', None)) - vdsty = (vdsty_enc << 1) | ((vdstx & 1) ^ 1) # Decode vdsty - def fmt_vopd(op, vdst, src0, vsrc1, include_lit): - try: name = autogen.VOPDOp(op).name.lower() - except (ValueError, KeyError): name = f"op_{op}" - lit_str = f", 0x{literal:x}" if include_lit and literal is not None and ('fmaak' in name or 'fmamk' in name) else "" - return f"{name} v{vdst}, {fmt_src(src0)}{lit_str}" if 'mov' in name else f"{name} v{vdst}, {fmt_src(src0)}, v{vsrc1}{lit_str}" - # fmaak/fmamk: both X and Y can use the shared literal - x_needs_lit = 'fmaak' in autogen.VOPDOp(opx).name.lower() or 'fmamk' in autogen.VOPDOp(opx).name.lower() - y_needs_lit = 'fmaak' in autogen.VOPDOp(opy).name.lower() or 'fmamk' in autogen.VOPDOp(opy).name.lower() - return f"{fmt_vopd(opx, vdstx, srcx0, vsrcx1, x_needs_lit)} :: {fmt_vopd(opy, vdsty, srcy0, vsrcy1, y_needs_lit)}" +def _disasm_flat(inst: FLAT) -> str: + name = FLATOp(inst.op).name.lower() + seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' + instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" + off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192) + suffix = name.split('_')[-1] + w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1) + if 'cmpswap' in name: w *= 2 + if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2) + mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" + # saddr + if seg == 'flat' or inst.saddr == 0x7F: saddr_s = "" + elif inst.saddr == 124: saddr_s = ", off" + elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}" + elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}" + elif 108 <= inst.saddr <= 123: saddr_s = f", {_reg('ttmp', inst.saddr - 108, 2)}" + else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}" + # addtid: no addr + if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}" + # addr width + addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2) + data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w) + if 'atomic' in name: + return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if inst.glc else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" + if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}" + return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" - # VOP3P: packed vector ops - if cls_name == 'VOP3P': - vdst, clmp = unwrap(inst._values.get('vdst', 0)), unwrap(inst._values.get('clmp', 0)) - src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')] - neg, neg_hi = unwrap(inst._values.get('neg', 0)), unwrap(inst._values.get('neg_hi', 0)) - opsel, opsel_hi, opsel_hi2 = unwrap(inst._values.get('opsel', 0)), unwrap(inst._values.get('opsel_hi', 0)), unwrap(inst._values.get('opsel_hi2', 0)) - is_wmma, is_3src = 'wmma' in op_name, any(x in op_name for x in ('fma', 'mad', 'dot', 'wmma')) - def fmt_bits(name, val, n): return f"{name}:[{','.join(str((val >> i) & 1) for i in range(n))}]" - # WMMA: f16/bf16 use 8-reg sources, iu8 uses 4-reg, iu4 uses 2-reg; all have 8-reg dst - if is_wmma: - src_cnt = 2 if 'iu4' in op_name else 4 if 'iu8' in op_name else 8 - src0_str, src1_str, src2_str = _fmt_src_n(src0, src_cnt), _fmt_src_n(src1, src_cnt), _fmt_src_n(src2, 8) - dst_str = _vreg(vdst, 8) - else: - src0_str, src1_str, src2_str = _fmt_src_n(src0, 1), _fmt_src_n(src1, 1), _fmt_src_n(src2, 1) - dst_str = f"v{vdst}" - n = 3 if is_3src else 2 - full_opsel_hi = opsel_hi | (opsel_hi2 << 2) - mods = [fmt_bits("op_sel", opsel, n)] if opsel else [] - if full_opsel_hi != (0b111 if is_3src else 0b11): mods.append(fmt_bits("op_sel_hi", full_opsel_hi, n)) - if neg: mods.append(fmt_bits("neg_lo", neg, n)) - if neg_hi: mods.append(fmt_bits("neg_hi", neg_hi, n)) - if clmp: mods.append("clamp") - mod_str = " " + " ".join(mods) if mods else "" - return f"{op_name} {dst_str}, {src0_str}, {src1_str}, {src2_str}{mod_str}" if is_3src else f"{op_name} {dst_str}, {src0_str}, {src1_str}{mod_str}" +def _disasm_ds(inst: DS) -> str: + op, name = DSOp(inst.op), DSOp(inst.op).name.lower() + gds = " gds" if inst.gds else "" + off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else "" + off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else "" + w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1 + d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}" - # VINTERP: interpolation instructions - if cls_name == 'VINTERP': - vdst = unwrap(inst._values.get('vdst', 0)) - src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')] - neg, waitexp, clmp = unwrap(inst._values.get('neg', 0)), unwrap(inst._values.get('waitexp', 0)), unwrap(inst._values.get('clmp', 0)) - def fmt_neg_vi(v, neg_bit): return f"-{v}" if neg_bit else v - srcs = [fmt_neg_vi(f"v{s - 256}" if s >= 256 else fmt_src(s), neg & (1 << i)) for i, s in enumerate([src0, src1, src2])] - mods = [m for m in [f"wait_exp:{waitexp}" if waitexp else "", "clamp" if clmp else ""] if m] - return f"{op_name} v{vdst}, {', '.join(srcs)}" + (" " + " ".join(mods) if mods else "") + if op == DSOp.DS_NOP: return name + if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}" + if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}" + if 'gws_' in name: return f"{name} {addr}{off}{gds}" + if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}" + if 'gs_reg' in name: return f"{name} {_vreg(inst.vdst, 2)}, v{inst.data0}{off}{gds}" + if '2addr' in name: + if 'load' in name: return f"{name} {_vreg(inst.vdst, w*2)}, {addr}{off2}{gds}" + if 'store' in name and 'xchg' not in name: return f"{name} {addr}, {d0}, {d1}{off2}{gds}" + return f"{name} {_vreg(inst.vdst, w*2)}, {addr}, {d0}, {d1}{off2}{gds}" + if 'load' in name: return f"{name} v{inst.vdst}{off}{gds}" if 'addtid' in name else f"{name} {dst}, {addr}{off}{gds}" + if 'store' in name and not _has(name, 'cmp', 'xchg'): + return f"{name} v{inst.data0}{off}{gds}" if 'addtid' in name else f"{name} {addr}, {d0}{off}{gds}" + if 'swizzle' in name or op == DSOp.DS_ORDERED_COUNT: return f"{name} v{inst.vdst}, {addr}{off}{gds}" + if 'permute' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}{off}{gds}" + if 'condxchg' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, {_vreg(inst.data0, 2)}{off}{gds}" + if _has(name, 'cmpstore', 'mskor', 'wrap'): + return f"{name} {dst}, {addr}, {d0}, {d1}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}, {d1}{off}{gds}" + return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}" - # MUBUF/MTBUF helpers - def _buf_vaddr(vaddr, offen, idxen): return _vreg(vaddr, 2) if offen and idxen else f"v{vaddr}" if offen or idxen else "off" - def _buf_srsrc(srsrc): srsrc_base = srsrc * 4; return _reg("ttmp", srsrc_base - 108, 4) if 108 <= srsrc_base <= 123 else _sreg(srsrc_base, 4) +def _disasm_vop3(inst: VOP3) -> str: + op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op) + name = op.name.lower() - # MUBUF: buffer load/store - if cls_name == 'MUBUF': - vdata, vaddr, srsrc, soffset = [unwrap(inst._values.get(f, 0)) for f in ('vdata', 'vaddr', 'srsrc', 'soffset')] - offset, offen, idxen = unwrap(inst._values.get('offset', 0)), unwrap(inst._values.get('offen', 0)), unwrap(inst._values.get('idxen', 0)) - glc, dlc, slc, tfe = [unwrap(inst._values.get(f, 0)) for f in ('glc', 'dlc', 'slc', 'tfe')] - if op_name in ('buffer_gl0_inv', 'buffer_gl1_inv'): return op_name - # Determine data width from op name - if 'd16' in op_name: width = 2 if any(x in op_name for x in ('xyz', 'xyzw')) else 1 - elif 'atomic' in op_name: - base_width = 2 if any(x in op_name for x in ('b64', 'u64', 'i64')) else 1 - width = base_width * 2 if 'cmpswap' in op_name else base_width - else: width = {'b32':1, 'b64':2, 'b96':3, 'b128':4, 'b16':1, 'x':1, 'xy':2, 'xyz':3, 'xyzw':4}.get(op_name.split('_')[-1], 1) - if tfe: width += 1 - mods = [m for m in ["offen" if offen else "", "idxen" if idxen else "", f"offset:{offset}" if offset else "", - "glc" if glc else "", "dlc" if dlc else "", "slc" if slc else "", "tfe" if tfe else ""] if m] - return f"{op_name} {_vreg(vdata, width)}, {_buf_vaddr(vaddr, offen, idxen)}, {_buf_srsrc(srsrc)}, {decode_src(soffset)}" + (" " + " ".join(mods) if mods else "") + # VOP3SD (shared encoding) + if inst.op in VOP3SD_OPS: + sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs + is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32') + def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s + s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64) + dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}" + if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}" + if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod) - # MTBUF: typed buffer load/store - if cls_name == 'MTBUF': - vdata, vaddr, srsrc, soffset = [unwrap(inst._values.get(f, 0)) for f in ('vdata', 'vaddr', 'srsrc', 'soffset')] - offset, tbuf_fmt, offen, idxen = [unwrap(inst._values.get(f, 0)) for f in ('offset', 'format', 'offen', 'idxen')] - glc, dlc, slc = [unwrap(inst._values.get(f, 0)) for f in ('glc', 'dlc', 'slc')] - mods = [f"format:{tbuf_fmt}"] + [m for m in ["idxen" if idxen else "", "offen" if offen else "", f"offset:{offset}" if offset else "", - "glc" if glc else "", "dlc" if dlc else "", "slc" if slc else ""] if m] - width = 2 if 'd16' in op_name and any(x in op_name for x in ('xyz', 'xyzw')) else 1 if 'd16' in op_name else {'x':1, 'xy':2, 'xyz':3, 'xyzw':4}.get(op_name.split('_')[-1], 1) - return f"{op_name} {_vreg(vdata, width)}, {_buf_vaddr(vaddr, offen, idxen)}, {_buf_srsrc(srsrc)}, {decode_src(soffset)} {' '.join(mods)}" + # Detect operand sizes + is64 = _is64(name) + is64_src, is64_dst = False, False + is16_d = is16_s = is16_s2 = False + if 'cvt_pk' in name: is16_s = name.endswith('16') + elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name): + is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16') + is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1) + is16_s2, is64 = is16_s, False + elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True + elif 'pack_b32' in name: is16_s = is16_s2 = True + else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad') - # SOP1/SOP2/SOPC/SOPK - if cls_name in ('SOP1', 'SOP2', 'SOPC', 'SOPK'): - sizes = _parse_sop_sizes(op_name) - dst_cnt, src0_cnt = sizes[0], sizes[1] - src1_cnt = sizes[2] if len(sizes) > 2 else src0_cnt - if cls_name == 'SOP1': - sdst, ssrc0 = unwrap(inst._values.get('sdst', 0)), unwrap(inst._values.get('ssrc0', 0)) - if op_name == 's_getpc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}" - if op_name in ('s_setpc_b64', 's_rfe_b64'): return f"{op_name} {_fmt_ssrc(ssrc0, 2)}" - if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}, {_fmt_ssrc(ssrc0, 2)}" - if op_name in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): - return f"{op_name} {_fmt_sdst(sdst, 2 if 'b64' in op_name else 1)}, sendmsg({MSG_NAMES.get(ssrc0, str(ssrc0))})" - ssrc0_str = fmt_src(ssrc0) if src0_cnt == 1 else _fmt_ssrc(ssrc0, src0_cnt) - return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}" - if cls_name == 'SOP2': - sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')] - ssrc0_str = fmt_src(ssrc0) if ssrc0 == 255 else _fmt_ssrc(ssrc0, src0_cnt) - ssrc1_str = fmt_src(ssrc1) if ssrc1 == 255 else _fmt_ssrc(ssrc1, src1_cnt) - return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {ssrc0_str}, {ssrc1_str}" - if cls_name == 'SOPC': - return f"{op_name} {_fmt_ssrc(unwrap(inst._values.get('ssrc0', 0)), src0_cnt)}, {_fmt_ssrc(unwrap(inst._values.get('ssrc1', 0)), src1_cnt)}" - if cls_name == 'SOPK': - sdst, simm16 = unwrap(inst._values.get('sdst', 0)), unwrap(inst._values.get('simm16', 0)) - if op_name == 's_version': return f"{op_name} 0x{simm16:x}" - if op_name in ('s_setreg_b32', 's_getreg_b32'): - hwreg_id, hwreg_offset, hwreg_size = simm16 & 0x3f, (simm16 >> 6) & 0x1f, ((simm16 >> 11) & 0x1f) + 1 - hwreg_str = f"0x{simm16:x}" if hwreg_id in (16, 17) else f"hwreg({HWREG_NAMES.get(hwreg_id, str(hwreg_id))}, {hwreg_offset}, {hwreg_size})" - return f"{op_name} {hwreg_str}, {_fmt_sdst(sdst, 1)}" if op_name == 's_setreg_b32' else f"{op_name} {_fmt_sdst(sdst, 1)}, {hwreg_str}" - return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, 0x{simm16:x}" + # Source counts + shift64 = 'rev' in name and '64' in name and name.startswith('v_') + ldexp64 = op == VOP3Op.V_LDEXP_F64 + trig = op == VOP3Op.V_TRIG_PREOP_F64 + sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name + s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1 + s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1 + s2n = 4 if mqsad else 2 if (is64 or sad64) else 1 - # Generic fallback - def fmt_field(n, v): - v = unwrap(v) - if n in SRC_FIELDS: return fmt_src(v) if v != 255 else "0xff" - if n in ('sdst', 'vdst'): return f"{'s' if n == 'sdst' else 'v'}{v}" - return f"v{v}" if n == 'vsrc1' else f"0x{v:x}" if n == 'simm16' else str(v) - ops = [fmt_field(n, inst._values.get(n, 0)) for n in inst._fields if n not in ('encoding', 'op')] - return f"{op_name} {', '.join(ops)}" if ops else op_name + any_hi = inst.opsel != 0 + s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi) + s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi) + s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi) -# Assembler -SPECIAL_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), - 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)} -FLOAT_CONSTS = {'0.5': 0.5, '-0.5': -0.5, '1.0': 1.0, '-1.0': -1.0, '2.0': 2.0, '-2.0': -2.0, '4.0': 4.0, '-4.0': -4.0} + # Destination + dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1 + if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1) + elif dn > 1: dst = _vreg(inst.vdst, dn) + elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}" + else: dst = f"v{inst.vdst}" + + cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) + nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4)) + need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s) + + if inst.op < 256: # VOPC + return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" + if inst.op < 384: # VOP2 + os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d) + return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}" + if inst.op < 512: # VOP1 + return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}" + # Native VOP3 + is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi', + 'perm_b32', 'permlane', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit') + os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d) + return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}" + +def _disasm_vop3sd(inst: VOP3SD) -> str: + op, name = VOP3SDOp(inst.op), VOP3SDOp(inst.op).name.lower() + is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32') + def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s + s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64) + dst, is2src = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}", op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32) + suffix = "_e64" if name.startswith('v_') and 'co_' in name else "" + return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{'' if is2src else f', {s2}'}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}" + +def _disasm_vopd(inst: VOPD) -> str: + lit = inst._literal or inst.literal + vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower() + def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" + return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}" + +def _disasm_vop3p(inst: VOP3P) -> str: + name = VOP3POp(inst.op).name.lower() + is_wmma, is_3src, is_fma_mix = 'wmma' in name, _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name + if is_wmma: + sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8 + src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8) + else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}" + n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2) + if is_fma_mix: + def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s) + src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4) + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else []) + else: + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \ + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) + return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" + +def _disasm_buf(inst: MUBUF | MTBUF) -> str: + op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op) + name = op.name.lower() + if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name + w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ + ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ + {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1) + if inst.tfe: w += 1 + vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off" + srsrc = _reg("ttmp", inst.srsrc*4 - 108, 4) if 108 <= inst.srsrc*4 <= 123 else _sreg(inst.srsrc*4, 4) + mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] + return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}" + +def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: + """Calculate vaddr register count for MIMG sample/gather operations.""" + # 1d,2d,3d,cube,1d_arr,2d_arr,2d_msaa,2d_msaa_arr + base = [1, 2, 3, 3, 2, 3, 3, 4][dim] # address coords + grad = [1, 2, 3, 2, 1, 2, 2, 2][dim] # gradient coords (for derivatives) + if 'get_resinfo' in name: return 1 # only mip level + packed, unpacked = 0, 0 + if '_mip' in name: packed += 1 + elif 'sample' in name or 'gather' in name: + if '_o' in name: unpacked += 1 # offset + if re.search(r'_c(_|$)', name): unpacked += 1 # compare (not _cl) + if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2 # derivatives + if '_b' in name: unpacked += 1 # bias + if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1 # LOD + if '_cl' in name: packed += 1 # clamp + return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked + +def _disasm_mimg(inst: MIMG) -> str: + name = MIMGOp(inst.op).name.lower() + srsrc_base = inst.srsrc * 4 + srsrc_str = _reg("ttmp", srsrc_base - 108, 8) if 108 <= srsrc_base <= 123 else _sreg(srsrc_base, 8) + # BVH intersect ray: special case with 4 SGPR srsrc + if 'bvh' in name: + vaddr = (9 if '64' in name else 8) if inst.a16 else (12 if '64' in name else 11) + srsrc = _reg("ttmp", srsrc_base - 108, 4) if 108 <= srsrc_base <= 123 else _sreg(srsrc_base, 4) + return f"{name} {_vreg(inst.vdata, 4)}, {_vreg(inst.vaddr, vaddr)}, {srsrc}{' a16' if inst.a16 else ''}" + # vdata width from dmask (gather4/msaa_load always 4), d16 packs, tfe adds 1 + vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1) + if inst.d16: vdata = (vdata + 1) // 2 + if inst.tfe: vdata += 1 + # vaddr width + dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array'] + dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}" + vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16) + vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr) + # modifiers + mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else [] + mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}") + for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"), + (inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]: + if flag: mods.append(mod) + # ssamp for sample/gather/get_lod + ssamp_str = "" + if 'sample' in name or 'gather' in name or 'get_lod' in name: + ssamp_base = inst.ssamp * 4 + ssamp_str = ", " + (_reg("ttmp", ssamp_base - 108, 4) if 108 <= ssamp_base <= 123 else _sreg(ssamp_base, 4)) + return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}" + +def _sop_widths(name: str) -> tuple[int, int, int]: + """Return (dst_width, src0_width, src1_width) in register count for SOP instructions.""" + if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1 + if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1 + if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1 + if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1 + if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz + return 1, 1, 1 + +def _disasm_sop1(inst: SOP1) -> str: + op, name = SOP1Op(inst.op), SOP1Op(inst.op).name.lower() + if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" + if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}" + if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}" + if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" + dn, s0n, _ = _sop_widths(name) + return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if s0n == 1 else _fmt_src(inst.ssrc0, s0n)}" + +def _disasm_sop2(inst: SOP2) -> str: + name = SOP2Op(inst.op).name.lower() + dn, s0n, s1n = _sop_widths(name) + return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)}" + +def _disasm_sopc(inst: SOPC) -> str: + name = SOPCOp(inst.op).name.lower() + _, s0n, s1n = _sop_widths(name) + return f"{name} {_fmt_src(inst.ssrc0, s0n)}, {_fmt_src(inst.ssrc1, s1n)}" + +def _disasm_sopk(inst: SOPK) -> str: + op, name = SOPKOp(inst.op), SOPKOp(inst.op).name.lower() + if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" + if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32): + hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 + hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" + return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" + dn, _, _ = _sop_widths(name) + return f"{name} {_fmt_sdst(inst.sdst, dn)}, 0x{inst.simm16:x}" + +def _disasm_vinterp(inst: VINTERP) -> str: + name = VINTERPOp(inst.op).name.lower() + src0 = f"-{inst.lit(inst.src0)}" if inst.neg & 1 else inst.lit(inst.src0) + src1 = f"-{inst.lit(inst.src1)}" if inst.neg & 2 else inst.lit(inst.src1) + src2 = f"-{inst.lit(inst.src2)}" if inst.neg & 4 else inst.lit(inst.src2) + mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp")) + return f"{name} v{inst.vdst}, {src0}, {src1}, {src2}" + (" " + mods if mods else "") + +def _disasm_generic(inst: Inst) -> str: + name = f"op_{inst.op}" + def format_field(field_name, val): + val = unwrap(val) + if field_name in SRC_FIELDS: return inst.lit(val) if val != 255 else "0xff" + return f"{'s' if field_name == 'sdst' else 'v'}{val}" if field_name in ('sdst', 'vdst') else f"v{val}" if field_name == 'vsrc1' else f"0x{val:x}" if field_name == 'simm16' else str(val) + operands = [format_field(field_name, inst._values.get(field_name, 0)) for field_name in inst._fields if field_name not in ('encoding', 'op')] + return f"{name} {', '.join(operands)}" if operands else name + +DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p, + VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf, + MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk} + +def disasm(inst: Inst) -> str: return DISASM_HANDLERS.get(type(inst), _disasm_generic)(inst) + +# ═══════════════════════════════════════════════════════════════════════════════ +# ASSEMBLER +# ═══════════════════════════════════════════════════════════════════════════════ + +SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), + 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)} +FLOATS = {'0.5': 0.5, '-0.5': -0.5, '1.0': 1.0, '-1.0': -1.0, '2.0': 2.0, '-2.0': -2.0, '4.0': 4.0, '-4.0': -4.0} REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp} - -def parse_operand(op: str) -> tuple: - op = op.strip().lower() - neg = op.startswith('-') and not op[1:2].isdigit(); op = op[1:] if neg else op - abs_ = op.startswith('|') and op.endswith('|') or op.startswith('abs(') and op.endswith(')') - op = op[1:-1] if op.startswith('|') else op[4:-1] if op.startswith('abs(') else op - hi_half = op.endswith('.h') - op = re.sub(r'\.[lh]$', '', op) - if op in FLOAT_CONSTS: return (FLOAT_CONSTS[op], neg, abs_, hi_half) - if re.match(r'^-?\d+$', op): return (int(op), neg, abs_, hi_half) - if m := re.match(r'^-?0x([0-9a-f]+)$', op): - v = -int(m.group(1), 16) if op.startswith('-') else int(m.group(1), 16) - return (v, neg, abs_, hi_half) - if op in SPECIAL_REGS: return (SPECIAL_REGS[op], neg, abs_, hi_half) - if op == 'lit': return (RawImm(255), neg, abs_, hi_half) # literal marker (actual value comes from literal word) - if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))], neg, abs_, hi_half) - if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op): - reg = REG_MAP[m.group(1)][int(m.group(2))] - reg.hi = hi_half - return (reg, neg, abs_, hi_half) - # hwreg(name, offset, size) or hwreg(name) -> simm16 encoding - if m := re.match(r'^hwreg\((\w+)(?:,\s*(\d+),\s*(\d+))?\)$', op): - name_str = m.group(1).lower() - hwreg_id = HWREG_IDS.get(name_str, int(name_str) if name_str.isdigit() else None) - if hwreg_id is None: raise ValueError(f"unknown hwreg name: {name_str}") - offset, size = int(m.group(2)) if m.group(2) else 0, int(m.group(3)) if m.group(3) else 32 - return (((size - 1) << 11) | (offset << 6) | hwreg_id, neg, abs_, hi_half) - raise ValueError(f"cannot parse operand: {op}") - SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512', 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'} -SOP1_SRC_ONLY = {'s_setpc_b64', 's_rfe_b64'} -SOP1_MSG_IMM = {'s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'} -SOPK_IMM_ONLY = {'s_version'} -SOPK_IMM_FIRST = {'s_setreg_b32'} -SOPK_UNSUPPORTED = {'s_setreg_imm32_b32'} +SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0', + 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'} -def _operand_to_dsl(op: str) -> str: - """Transform a single operand from LLVM assembly syntax to DSL expression string.""" +def _op2dsl(op: str) -> str: op = op.strip() - # Handle negation prefix - neg = False - if op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')): - neg, op = True, op[1:] - # Handle abs modifier: |x| or abs(x) - abs_ = False - if op.startswith('|') and op.endswith('|'): - abs_, op = True, op[1:-1] - elif op.startswith('abs(') and op.endswith(')'): - abs_, op = True, op[4:-1] - # Handle .h/.l suffix for 16-bit ops - hi_suffix = "" - if op.endswith('.h'): hi_suffix, op = ".h", op[:-2] - elif op.endswith('.l'): hi_suffix, op = ".l", op[:-2] - op_lower = op.lower() + neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')) + if neg: op = op[1:] + abs_ = (op.startswith('|') and op.endswith('|')) or (op.startswith('abs(') and op.endswith(')')) + if abs_: op = op[1:-1] if op.startswith('|') else op[4:-1] + hi = ".h" if op.endswith('.h') else ".l" if op.endswith('.l') else "" + if hi: op = op[:-2] + lo = op.lower() + def wrap(b): return f"{'-' if neg else ''}abs({b}){hi}" if abs_ else f"-{b}{hi}" if neg else f"{b}{hi}" + if lo in SPEC_DSL: return wrap(SPEC_DSL[lo]) + if op in FLOATS: return wrap(op) + rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'} + if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]") + if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}]") + if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op + return wrap(op) - # Helper to apply modifiers - def apply_mods(base: str) -> str: - if not neg and not abs_: return f"{base}{hi_suffix}" - if abs_: return f"{'-' if neg else ''}abs({base}){hi_suffix}" - return f"-{base}{hi_suffix}" +def _parse_ops(s: str) -> list[str]: + ops, cur, depth, pipe = [], "", 0, False + for c in s: + if c in '[(': depth += 1 + elif c in '])': depth -= 1 + elif c == '|': pipe = not pipe + if c == ',' and depth == 0 and not pipe: ops.append(cur.strip()); cur = "" + else: cur += c + if cur.strip(): ops.append(cur.strip()) + return ops - # Special registers - vcc maps to VCC_LO (64-bit alias) - special_map = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', - 'm0': 'M0', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', - 'src_scc': 'SCC'} - if op_lower in special_map: return apply_mods(special_map[op_lower]) - # Float constants - float_map = {'0.5': '0.5', '-0.5': '-0.5', '1.0': '1.0', '-1.0': '-1.0', '2.0': '2.0', '-2.0': '-2.0', '4.0': '4.0', '-4.0': '-4.0'} - if op in float_map: return apply_mods(float_map[op]) - # Register range: v[0:3], s[4:7] - if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op_lower): - prefix = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}[m.group(1)] - return apply_mods(f"{prefix}[{m.group(2)}:{m.group(3)}]") - # Single register: v0, s1, ttmp5 - if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op_lower): - prefix = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}[m.group(1)] - return apply_mods(f"{prefix}[{m.group(2)}]") - # Integer literals (decimal or hex) - use SrcMod wrapper when modifiers present - if re.match(r'^-?\d+$', op) or re.match(r'^-?0x([0-9a-fA-F]+)$', op): - if neg or abs_: - return f"SrcMod({op}, neg={neg}, abs_={abs_})" - return op - # hwreg(name, offset, size) -> pass through - if op_lower.startswith('hwreg('): return apply_mods(op) - # sendmsg(...) -> pass through - if op_lower.startswith('sendmsg('): return apply_mods(op) - # Fallback: return as-is - return apply_mods(op) - -def _parse_operands(op_str: str) -> list[str]: - """Parse comma-separated operands, respecting brackets and pipes.""" - operands, current, depth, in_pipe = [], "", 0, False - for ch in op_str: - if ch in '[(': depth += 1 - elif ch in '])': depth -= 1 - elif ch == '|': in_pipe = not in_pipe - if ch == ',' and depth == 0 and not in_pipe: - operands.append(current.strip()) - current = "" - else: - current += ch - if current.strip(): operands.append(current.strip()) - return operands - -def _unwrap_dsl(s: str) -> str: - """Unwrap a DSL expression to get the raw value for literals.""" - if re.match(r'^-?\d+$', s): return s - if re.match(r'^-?0x[0-9a-fA-F]+$', s): return s - return s +def _extract(text: str, pat: str, flags=re.I): + if m := re.search(pat, text, flags): return m, text[:m.start()] + text[m.end():] + return None, text def get_dsl(text: str) -> str: - """Transform LLVM-style assembly instruction to Python DSL expression string.""" - text = text.strip() - # Extract and remove trailing modifiers (must happen before operand parsing) - kwargs = [] - # Extract mul:N and div:N modifiers (omod) - omod_val = 0 - if m := re.search(r'\s+mul:2(?:\s|$)', text, re.I): - omod_val = 1; text = text[:m.start()] + text[m.end():] - elif m := re.search(r'\s+mul:4(?:\s|$)', text, re.I): - omod_val = 2; text = text[:m.start()] + text[m.end():] - elif m := re.search(r'\s+div:2(?:\s|$)', text, re.I): - omod_val = 3; text = text[:m.start()] + text[m.end():] - if omod_val: kwargs.append(f'omod={omod_val}') - # Extract clamp modifier - if m := re.search(r'\s+clamp(?:\s|$)', text, re.I): - kwargs.append('clmp=1') - text = text[:m.start()] + text[m.end():] - # Extract op_sel:[...] modifier - interpretation depends on format: - # VOP3: [src0, src1, dst] or [src0, src1, src2, dst] -> bits 0, 1, (2), 3 - # VOP3P/WMMA: [src0, src1, src2] -> bits 0, 1, 2 (no dst bit, 3-source ops) - opsel_explicit = None - if m := re.search(r'\s+op_sel:\[([^\]]+)\]', text, re.I): - bits = [int(x.strip()) for x in m.group(1).split(',')] - # Check if this is a VOP3P instruction (v_pk_*, v_wmma_*, v_dot*) - mnemonic = text.split()[0].lower() - is_vop3p = mnemonic.startswith(('v_pk_', 'v_wmma_', 'v_dot')) - if len(bits) == 3: - if is_vop3p: - # VOP3P: [src0, src1, src2] -> bits 0, 1, 2 - opsel_explicit = bits[0] | (bits[1] << 1) | (bits[2] << 2) - else: - # VOP3: [src0, src1, dst] -> bits 0, 1, 3 - opsel_explicit = bits[0] | (bits[1] << 1) | (bits[2] << 3) - else: - opsel_explicit = sum(b << i for i, b in enumerate(bits)) - text = text[:m.start()] + text[m.end():] - if m := re.search(r'\s+wait_exp:(\d+)', text, re.I): - kwargs.append(f'waitexp={m.group(1)}') - text = text[:m.start()] + text[m.end():] - # Extract offset:N for FLAT/GLOBAL/SCRATCH/SMEM (can be hex or decimal) - offset_val = None - if m := re.search(r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)', text, re.I): - offset_val = m.group(1) - text = text[:m.start()] + text[m.end():] - # Extract dlc modifier (before glc to avoid partial match issues) - dlc_val = None - if m := re.search(r'\s+dlc(?:\s|$)', text, re.I): - dlc_val = 1 - text = text[:m.start()] + text[m.end():] - # Extract glc modifier - glc_val = None - if m := re.search(r'\s+glc(?:\s|$)', text, re.I): - glc_val = 1 - text = text[:m.start()] + text[m.end():] - # Extract neg_lo:[...] and neg_hi:[...] for VOP3P - neg_lo_val = None - if m := re.search(r'\s+neg_lo:\[([^\]]+)\]', text, re.I): - bits = [int(x.strip()) for x in m.group(1).split(',')] - neg_lo_val = sum(b << i for i, b in enumerate(bits)) - text = text[:m.start()] + text[m.end():] - neg_hi_val = None - if m := re.search(r'\s+neg_hi:\[([^\]]+)\]', text, re.I): - bits = [int(x.strip()) for x in m.group(1).split(',')] - neg_hi_val = sum(b << i for i, b in enumerate(bits)) - text = text[:m.start()] + text[m.end():] + text, kw = text.strip(), [] + # Extract modifiers + for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]: + if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break + if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: kw.append('clmp=1'); text = m[1] + opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]') + if m: + bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower() + is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot')) + opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \ + (bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits)) + m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None + m, text = _extract(text, r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)'); off_val = m.group(1) if m else None + m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None + m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None + m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None + m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None + m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None + if waitexp: kw.append(f'waitexp={waitexp}') + parts = text.replace(',', ' ').split() if not parts: raise ValueError("empty instruction") - mnemonic, op_str = parts[0].lower(), text[len(parts[0]):].strip() - # Handle s_waitcnt specially - if mnemonic == 's_waitcnt': - vmcnt, expcnt, lgkmcnt = 0x3f, 0x7, 0x3f - for part in op_str.replace(',', ' ').split(): - if m := re.match(r'vmcnt\((\d+)\)', part): vmcnt = int(m.group(1)) - elif m := re.match(r'expcnt\((\d+)\)', part): expcnt = int(m.group(1)) - elif m := re.match(r'lgkmcnt\((\d+)\)', part): lgkmcnt = int(m.group(1)) - elif re.match(r'^0x[0-9a-f]+$|^\d+$', part): return f"s_waitcnt(simm16={int(part, 0)})" - wc = waitcnt(vmcnt, expcnt, lgkmcnt) - return f"s_waitcnt(simm16={wc})" - # Handle VOPD dual-issue: opx dst, src :: opy dst, src + mn, op_str = parts[0].lower(), text[len(parts[0]):].strip() + ops, args = _parse_ops(op_str), [_op2dsl(o) for o in _parse_ops(op_str)] + + # s_waitcnt + if mn == 's_waitcnt': + vm, exp, lgkm = 0x3f, 0x7, 0x3f + for p in op_str.replace(',', ' ').split(): + if m := re.match(r'vmcnt\((\d+)\)', p): vm = int(m.group(1)) + elif m := re.match(r'expcnt\((\d+)\)', p): exp = int(m.group(1)) + elif m := re.match(r'lgkmcnt\((\d+)\)', p): lgkm = int(m.group(1)) + elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})" + return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})" + + # VOPD if '::' in text: - x_part, y_part = text.split('::') - x_parts, y_parts = x_part.strip().replace(',', ' ').split(), y_part.strip().replace(',', ' ').split() - opx_name, opy_name = x_parts[0].upper(), y_parts[0].upper() - x_ops = [_operand_to_dsl(p) for p in x_parts[1:]] - y_ops = [_operand_to_dsl(p) for p in y_parts[1:]] - vdstx, srcx0 = x_ops[0], x_ops[1] if len(x_ops) > 1 else '0' - vsrcx1 = x_ops[2] if len(x_ops) > 2 else 'v[0]' - vdsty, srcy0 = y_ops[0], y_ops[1] if len(y_ops) > 1 else '0' - vsrcy1 = y_ops[2] if len(y_ops) > 2 else 'v[0]' - lit = None - if 'fmaak' in opx_name.lower() and len(x_ops) > 3: lit = x_ops[3] - elif 'fmamk' in opx_name.lower() and len(x_ops) > 3: lit, vsrcx1 = x_ops[2], x_ops[3] - elif 'fmaak' in opy_name.lower() and len(y_ops) > 3: lit = y_ops[3] - elif 'fmamk' in opy_name.lower() and len(y_ops) > 3: lit, vsrcy1 = y_ops[2], y_ops[3] - lit_str = f", literal={lit}" if lit else "" - return f"VOPD(VOPDOp.{opx_name}, VOPDOp.{opy_name}, vdstx={vdstx}, vdsty={vdsty}, srcx0={srcx0}, vsrcx1={vsrcx1}, srcy0={srcy0}, vsrcy1={vsrcy1}{lit_str})" - operands = _parse_operands(op_str) - dsl_args = [_operand_to_dsl(op) for op in operands] - # Handle special instructions - if mnemonic in SOPK_UNSUPPORTED: raise ValueError(f"unsupported instruction: {mnemonic}") - if mnemonic in SOP1_SRC_ONLY: return f"{mnemonic}(ssrc0={dsl_args[0]})" - if mnemonic in SOP1_MSG_IMM: return f"{mnemonic}(sdst={dsl_args[0]}, ssrc0=RawImm({_unwrap_dsl(dsl_args[1])}))" - if mnemonic in SOPK_IMM_ONLY: return f"{mnemonic}(simm16={dsl_args[0]})" - if mnemonic in SOPK_IMM_FIRST: return f"{mnemonic}(simm16={dsl_args[0]}, sdst={dsl_args[1]})" - # SMEM with immediate offset (offset in operand[2] or offset: modifier) - if mnemonic in SMEM_OPS: - glc_str = ", glc=1" if glc_val else "" - dlc_str = ", dlc=1" if dlc_val else "" - # Pure immediate offset in operand[2] - if len(operands) >= 3 and re.match(r'^-?[0-9]|^-?0x', operands[2].strip().lower()): - return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, offset={dsl_args[2]}, soffset=RawImm(124){glc_str}{dlc_str})" - # Register soffset with offset: modifier - if offset_val and len(operands) >= 3: - return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, offset={offset_val}, soffset={dsl_args[2]}{glc_str}{dlc_str})" - # Register soffset only (no offset modifier) - if len(operands) >= 3: - return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, soffset={dsl_args[2]}{glc_str}{dlc_str})" - # Buffer ops with 'off' - if mnemonic.startswith('buffer_') and len(operands) >= 2 and operands[1].strip().lower() == 'off': - soff = f"RawImm({_unwrap_dsl(dsl_args[3])})" if len(dsl_args) > 3 else "RawImm(0)" - return f"{mnemonic}(vdata={dsl_args[0]}, vaddr=0, srsrc={dsl_args[2]}, soffset={soff})" - # FLAT/GLOBAL/SCRATCH load - if (mnemonic.startswith('flat_load') or mnemonic.startswith('global_load') or mnemonic.startswith('scratch_load')) and len(dsl_args) >= 3: - off = f", offset={offset_val}" if offset_val else "" - return f"{mnemonic}(vdst={dsl_args[0]}, addr={dsl_args[1]}, saddr={dsl_args[2]}{off})" - # FLAT/GLOBAL/SCRATCH store - if (mnemonic.startswith('flat_store') or mnemonic.startswith('global_store') or mnemonic.startswith('scratch_store')) and len(dsl_args) >= 3: - off = f", offset={offset_val}" if offset_val else "" - return f"{mnemonic}(addr={dsl_args[0]}, data={dsl_args[1]}, saddr={dsl_args[2]}{off})" - # Handle v_fmaak/v_fmamk literals - lit_str = "" - if mnemonic in ('v_fmaak_f32', 'v_fmaak_f16') and len(dsl_args) == 4: - lit_str, dsl_args = f", literal={_unwrap_dsl(dsl_args[3])}", dsl_args[:3] - elif mnemonic in ('v_fmamk_f32', 'v_fmamk_f16') and len(dsl_args) == 4: - lit_str, dsl_args = f", literal={_unwrap_dsl(dsl_args[2])}", [dsl_args[0], dsl_args[1], dsl_args[3]] - # Handle v_add_co_ci_u32_e32 etc with vcc operands - strip implicit vcc sdst and carry_in, add _e32 suffix + xp, yp = text.split('::') + xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split() + xo, yo = [_op2dsl(p) for p in xps[1:]], [_op2dsl(p) for p in yps[1:]] + vdx, sx0, vsx1 = xo[0], xo[1] if len(xo) > 1 else '0', xo[2] if len(xo) > 2 else 'v[0]' + vdy, sy0, vsy1 = yo[0], yo[1] if len(yo) > 1 else '0', yo[2] if len(yo) > 2 else 'v[0]' + lit = xo[3] if 'fmaak' in xps[0].lower() and len(xo) > 3 else yo[3] if 'fmaak' in yps[0].lower() and len(yo) > 3 else None + if 'fmamk' in xps[0].lower() and len(xo) > 3: lit, vsx1 = xo[2], xo[3] + elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3] + return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})" + + # Special instructions + if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}") + if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})" + if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))" + if mn == 's_version': return f"{mn}(simm16={args[0]})" + if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})" + + # SMEM + if mn in SMEM_OPS: + gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else "" + if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()): + return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(124){gs}{ds})" + if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})" + if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})" + + # Buffer + if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off': + return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})" + + # FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null + def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a + flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}" + for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'), + ('flat_store','addr,data,saddr'), ('global_store','addr,data,saddr'), ('scratch_store','addr,data,saddr')]: + if mn.startswith(pre) and len(args) >= 2: + f0, f1, f2 = flds.split(',') + return f"{mn}({f0}={args[0]}, {f1}={args[1]}{f', {f2}={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})" + for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'): + if mn.startswith(pre): + if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})" + if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})" + + # DS instructions + if mn.startswith('ds_'): + off0, off1 = (str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)) if off_val else ("0", "0") + gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else "" + off_kw = f", offset0={off0}, offset1={off1}{gds_s}" + if mn == 'ds_nop' or mn in ('ds_gws_sema_v', 'ds_gws_sema_p', 'ds_gws_sema_release_all'): return f"{mn}({off_kw.lstrip(', ')})" + if 'gws_' in mn: return f"{mn}(addr={args[0]}{off_kw})" + if 'consume' in mn or 'append' in mn: return f"{mn}(vdst={args[0]}{off_kw})" + if 'gs_reg' in mn: return f"{mn}(vdst={args[0]}, data0={args[1]}{off_kw})" + if '2addr' in mn: + if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" + if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" + return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" + if 'load' in mn: return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" + if 'store' in mn and not _has(mn, 'cmp', 'xchg'): + return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})" + if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" + if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" + if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" + if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" + if _has(mn, 'cmpstore', 'mskor', 'wrap'): + return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" + return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})" + + # v_fmaak/v_fmamk literal extraction + lit_s = "" + if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3] + elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]] + + # VCC ops cleanup vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'} - if mnemonic.replace('_e32', '') in vcc_ops and len(dsl_args) >= 5: - mnemonic = mnemonic.replace('_e32', '') + '_e32' # Ensure _e32 suffix for VOP2 encoding - dsl_args = [dsl_args[0], dsl_args[2], dsl_args[3]] - # Handle v_add_co_ci_u32_e64 etc - strip _e64 suffix (function name doesn't have it, returns VOP3SD) - if mnemonic.replace('_e64', '') in vcc_ops and mnemonic.endswith('_e64'): - mnemonic = mnemonic.replace('_e64', '') - # v_cmp_*_e32: strip implicit vcc_lo dest - if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(dsl_args) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): - dsl_args = dsl_args[1:] - # CMPX with _e64: prepend implicit EXEC_LO (vdst=126) - if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(dsl_args) == 2: - dsl_args = ['RawImm(126)'] + dsl_args - # Build the function name - use mnemonic as-is, replacing . with _ - func_name = mnemonic.replace('.', '_') - # When explicit opsel is given, strip .h/.l from register args (opsel overrides) - if opsel_explicit is not None: - dsl_args = [re.sub(r'\.[hl]$', '', a) for a in dsl_args] - args_str = ', '.join(dsl_args) - all_kwargs = list(kwargs) - if lit_str: all_kwargs.append(lit_str.lstrip(', ')) - if opsel_explicit is not None: all_kwargs.append(f'opsel={opsel_explicit}') - if neg_lo_val is not None: all_kwargs.append(f'neg={neg_lo_val}') - if neg_hi_val is not None: all_kwargs.append(f'neg_hi={neg_hi_val}') - kwargs_str = ', '.join(all_kwargs) - if kwargs_str: - return f"{func_name}({args_str}, {kwargs_str})" if args_str else f"{func_name}({kwargs_str})" - return f"{func_name}({args_str})" + if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]] + if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '') + if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:] + if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args + + fn = mn.replace('.', '_') + if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args] + + # v_fma_mix*: extract inline neg/abs modifiers + if 'fma_mix' in mn and neg_lo is None and neg_hi is None: + inline_neg, inline_abs, clean_args = 0, 0, [args[0]] + for i, op in enumerate(ops[1:4]): + op = op.strip() + neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')) + if neg: op = op[1:] + abs_ = op.startswith('|') and op.endswith('|') + if abs_: op = op[1:-1] + if neg: inline_neg |= (1 << i) + if abs_: inline_abs |= (1 << i) + clean_args.append(_op2dsl(op)) + args = clean_args + args[4:] + if inline_neg: neg_lo = inline_neg + if inline_abs: neg_hi = inline_abs + + all_kw = list(kw) + if lit_s: all_kw.append(lit_s.lstrip(', ')) + if opsel is not None: all_kw.append(f'opsel={opsel}') + if neg_lo is not None: all_kw.append(f'neg={neg_lo}') + if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}') + if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1']) + + a_str, kw_str = ', '.join(args), ', '.join(all_kw) + return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})" def asm(text: str) -> Inst: - """Assemble LLVM-style instruction text to Inst by transforming to DSL and eval.""" - from extra.assembly.amd.autogen import rdna3 as autogen - dsl_expr = get_dsl(text) - namespace = {name: getattr(autogen, name) for name in dir(autogen) if not name.startswith('_')} - namespace.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP, - 'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, - 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF}) - try: - return eval(dsl_expr, namespace) + from extra.assembly.amd.autogen import rdna3 as ag + dsl = get_dsl(text) + ns = {n: getattr(ag, n) for n in dir(ag) if not n.startswith('_')} + ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP, + 'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF}) + try: return eval(dsl, ns) except NameError: - # Try with _e32 suffix for VOP1/VOP2/VOPC (only for v_* instructions) - if m := re.match(r'^(v_\w+)(\(.*\))$', dsl_expr): - return eval(f"{m.group(1)}_e32{m.group(2)}", namespace) + if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns) raise diff --git a/extra/assembly/amd/autogen/cdna/__init__.py b/extra/assembly/amd/autogen/cdna/__init__.py index c1c1ecaaad..416e4b9ac8 100644 --- a/extra/assembly/amd/autogen/cdna/__init__.py +++ b/extra/assembly/amd/autogen/cdna/__init__.py @@ -2209,33 +2209,33 @@ buffer_atomic_xor_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_XOR_X2) buffer_atomic_inc_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_INC_X2) buffer_atomic_dec_x2 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_DEC_X2) cdna4 = functools.partial(MUBUF, MUBUFOp.CDNA4) -scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=2) -scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=2) -scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=2) -scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=2) -scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=2) -scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=2) -scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=2) -scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=2) -scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=2) -scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=2) -scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=2) -scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=2) -scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=2) -scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=2) -scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=2) -scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=2) -scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=2) -scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=2) -scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=2) -scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=2) -scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=2) -scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=2) -scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=2) -scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=2) -scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=2) -scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=2) -scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=2) +scratch_load_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE, seg=1) +scratch_load_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE, seg=1) +scratch_load_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_USHORT, seg=1) +scratch_load_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SSHORT, seg=1) +scratch_load_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORD, seg=1) +scratch_load_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX2, seg=1) +scratch_load_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX3, seg=1) +scratch_load_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_DWORDX4, seg=1) +scratch_store_byte = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE, seg=1) +scratch_store_byte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_BYTE_D16_HI, seg=1) +scratch_store_short = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT, seg=1) +scratch_store_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_SHORT_D16_HI, seg=1) +scratch_store_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORD, seg=1) +scratch_store_dwordx2 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX2, seg=1) +scratch_store_dwordx3 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX3, seg=1) +scratch_store_dwordx4 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_DWORDX4, seg=1) +scratch_load_ubyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16, seg=1) +scratch_load_ubyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_UBYTE_D16_HI, seg=1) +scratch_load_sbyte_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16, seg=1) +scratch_load_sbyte_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SBYTE_D16_HI, seg=1) +scratch_load_short_d16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16, seg=1) +scratch_load_short_d16_hi = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_SHORT_D16_HI, seg=1) +scratch_load_lds_ubyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_UBYTE, seg=1) +scratch_load_lds_sbyte = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SBYTE, seg=1) +scratch_load_lds_ushort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_USHORT, seg=1) +scratch_load_lds_sshort = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_SSHORT, seg=1) +scratch_load_lds_dword = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_LDS_DWORD, seg=1) s_load_dword = functools.partial(SMEM, SMEMOp.S_LOAD_DWORD) s_load_dwordx2 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX2) s_load_dwordx4 = functools.partial(SMEM, SMEMOp.S_LOAD_DWORDX4) diff --git a/extra/assembly/amd/autogen/rdna3/__init__.py b/extra/assembly/amd/autogen/rdna3/__init__.py index 65c8c928cd..a1d5780754 100644 --- a/extra/assembly/amd/autogen/rdna3/__init__.py +++ b/extra/assembly/amd/autogen/rdna3/__init__.py @@ -56,6 +56,12 @@ class DSOp(IntEnum): DS_MAX_F32 = 19 DS_NOP = 20 DS_ADD_F32 = 21 + DS_GWS_SEMA_RELEASE_ALL = 24 + DS_GWS_INIT = 25 + DS_GWS_SEMA_V = 26 + DS_GWS_SEMA_BR = 27 + DS_GWS_SEMA_P = 28 + DS_GWS_BARRIER = 29 DS_STORE_B8 = 30 DS_STORE_B16 = 31 DS_ADD_RTN_U32 = 32 @@ -178,10 +184,13 @@ class FLATOp(IntEnum): FLAT_LOAD_D16_HI_B16 = 35 FLAT_STORE_D16_HI_B8 = 36 FLAT_STORE_D16_HI_B16 = 37 + GLOBAL_LOAD_ADDTID_B32 = 40 + GLOBAL_STORE_ADDTID_B32 = 41 FLAT_ATOMIC_SWAP_B32 = 51 FLAT_ATOMIC_CMPSWAP_B32 = 52 FLAT_ATOMIC_ADD_U32 = 53 FLAT_ATOMIC_SUB_U32 = 54 + FLAT_ATOMIC_CSUB_U32 = 55 FLAT_ATOMIC_MIN_I32 = 56 FLAT_ATOMIC_MIN_U32 = 57 FLAT_ATOMIC_MAX_I32 = 58 @@ -717,6 +726,7 @@ class SOPPOp(IntEnum): S_SET_INST_PREFETCH_DISTANCE = 4 S_CLAUSE = 5 S_DELAY_ALU = 7 + S_WAITCNT_DEPCTR = 8 S_WAITCNT = 9 S_WAIT_IDLE = 10 S_WAIT_EVENT = 11 @@ -1848,6 +1858,12 @@ ds_min_f32 = functools.partial(DS, DSOp.DS_MIN_F32) ds_max_f32 = functools.partial(DS, DSOp.DS_MAX_F32) ds_nop = functools.partial(DS, DSOp.DS_NOP) ds_add_f32 = functools.partial(DS, DSOp.DS_ADD_F32) +ds_gws_sema_release_all = functools.partial(DS, DSOp.DS_GWS_SEMA_RELEASE_ALL) +ds_gws_init = functools.partial(DS, DSOp.DS_GWS_INIT) +ds_gws_sema_v = functools.partial(DS, DSOp.DS_GWS_SEMA_V) +ds_gws_sema_br = functools.partial(DS, DSOp.DS_GWS_SEMA_BR) +ds_gws_sema_p = functools.partial(DS, DSOp.DS_GWS_SEMA_P) +ds_gws_barrier = functools.partial(DS, DSOp.DS_GWS_BARRIER) ds_store_b8 = functools.partial(DS, DSOp.DS_STORE_B8) ds_store_b16 = functools.partial(DS, DSOp.DS_STORE_B16) ds_add_rtn_u32 = functools.partial(DS, DSOp.DS_ADD_RTN_U32) @@ -1968,10 +1984,13 @@ flat_load_d16_hi_i8 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_I8) flat_load_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_LOAD_D16_HI_B16) flat_store_d16_hi_b8 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B8) flat_store_d16_hi_b16 = functools.partial(FLAT, FLATOp.FLAT_STORE_D16_HI_B16) +global_load_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_LOAD_ADDTID_B32) +global_store_addtid_b32 = functools.partial(FLAT, FLATOp.GLOBAL_STORE_ADDTID_B32) flat_atomic_swap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SWAP_B32) flat_atomic_cmpswap_b32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CMPSWAP_B32) flat_atomic_add_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_ADD_U32) flat_atomic_sub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_SUB_U32) +flat_atomic_csub_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_CSUB_U32) flat_atomic_min_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_I32) flat_atomic_min_u32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MIN_U32) flat_atomic_max_i32 = functools.partial(FLAT, FLATOp.FLAT_ATOMIC_MAX_I32) @@ -2226,28 +2245,28 @@ buffer_atomic_cmpswap_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_CMPSW buffer_atomic_min_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MIN_F32) buffer_atomic_max_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_MAX_F32) buffer_atomic_add_f32 = functools.partial(MUBUF, MUBUFOp.BUFFER_ATOMIC_ADD_F32) -scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=2) -scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=2) -scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=2) -scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=2) -scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=2) -scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=2) -scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=2) -scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=2) -scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=2) -scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=2) -scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=2) -scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=2) -scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=2) -scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=2) -scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=2) -scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=2) -scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=2) -scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=2) -scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=2) -scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=2) -scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=2) -scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=2) +scratch_load_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U8, seg=1) +scratch_load_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I8, seg=1) +scratch_load_u16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_U16, seg=1) +scratch_load_i16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_I16, seg=1) +scratch_load_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B32, seg=1) +scratch_load_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B64, seg=1) +scratch_load_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B96, seg=1) +scratch_load_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_B128, seg=1) +scratch_store_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B8, seg=1) +scratch_store_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B16, seg=1) +scratch_store_b32 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B32, seg=1) +scratch_store_b64 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B64, seg=1) +scratch_store_b96 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B96, seg=1) +scratch_store_b128 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_B128, seg=1) +scratch_load_d16_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_U8, seg=1) +scratch_load_d16_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_I8, seg=1) +scratch_load_d16_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_B16, seg=1) +scratch_load_d16_hi_u8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_U8, seg=1) +scratch_load_d16_hi_i8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_I8, seg=1) +scratch_load_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_LOAD_D16_HI_B16, seg=1) +scratch_store_d16_hi_b8 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B8, seg=1) +scratch_store_d16_hi_b16 = functools.partial(FLAT, SCRATCHOp.SCRATCH_STORE_D16_HI_B16, seg=1) s_load_b32 = functools.partial(SMEM, SMEMOp.S_LOAD_B32) s_load_b64 = functools.partial(SMEM, SMEMOp.S_LOAD_B64) s_load_b128 = functools.partial(SMEM, SMEMOp.S_LOAD_B128) @@ -2485,6 +2504,7 @@ s_sleep = functools.partial(SOPP, SOPPOp.S_SLEEP) s_set_inst_prefetch_distance = functools.partial(SOPP, SOPPOp.S_SET_INST_PREFETCH_DISTANCE) s_clause = functools.partial(SOPP, SOPPOp.S_CLAUSE) s_delay_alu = functools.partial(SOPP, SOPPOp.S_DELAY_ALU) +s_waitcnt_depctr = functools.partial(SOPP, SOPPOp.S_WAITCNT_DEPCTR) s_waitcnt = functools.partial(SOPP, SOPPOp.S_WAITCNT) s_wait_idle = functools.partial(SOPP, SOPPOp.S_WAIT_IDLE) s_wait_event = functools.partial(SOPP, SOPPOp.S_WAIT_EVENT) diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index ef2f3aefdc..d4ac98f431 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -6,22 +6,21 @@ from typing import overload, Annotated, TypeVar, Generic # Bit field DSL class BitField: - def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name = hi, lo, name - def __set_name__(self, owner, name): self.name, self._owner = name, owner + def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name, self._marker = hi, lo, name, None + def __set_name__(self, owner, name): + import typing + self.name, self._owner = name, owner + # Cache marker at class definition time + hints = typing.get_type_hints(owner, include_extras=True) + if name in hints: + hint = hints[name] + if typing.get_origin(hint) is Annotated: + args = typing.get_args(hint) + self._marker = args[1] if len(args) > 1 else None def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val) # type: ignore def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1 @property - def marker(self) -> type | None: - # Get marker from Annotated type hint if present - import typing - if hasattr(self, '_owner') and self.name: - hints = typing.get_type_hints(self._owner, include_extras=True) - if self.name in hints: - hint = hints[self.name] - if typing.get_origin(hint) is Annotated: - args = typing.get_args(hint) - return args[1] if len(args) > 1 else None - return None + def marker(self) -> type | None: return self._marker @overload def __get__(self, obj: None, objtype: type) -> BitField: ... @overload @@ -179,6 +178,21 @@ class Inst: raise ValueError(f"SOP1 {op_val.name} expects {expected} destination register(s), got {sdst_val.count}") if isinstance(ssrc0_val, Reg) and ssrc0_val.count != expected: raise ValueError(f"SOP1 {op_val.name} expects {expected} source register(s), got {ssrc0_val.count}") + # FLAT: set sve=1 when addr is a VGPR for scratch only + # For scratch (seg=1), sve=1 means addr VGPR is used; sve=0 means addr is "off" + # For global (seg=2) and flat (seg=0), sve is always 0 + if self.__class__.__name__ == 'FLAT' and 'sve' in self._fields: + seg_val = self._values.get('seg', 0) + if isinstance(seg_val, RawImm): seg_val = seg_val.val + addr_val = orig_args.get('addr') + if seg_val == 1 and isinstance(addr_val, VGPR): self._values['sve'] = 1 + # VOP3P: v_fma_mix* instructions (opcodes 32-34) have opsel_hi default of 0, not 7 + if self.__class__.__name__ == 'VOP3P': + op_val = orig_args.get(field_names[0]) if args else orig_args.get('op') + if hasattr(op_val, 'value'): op_val = op_val.value + if op_val in (32, 33, 34) and 'opsel_hi' not in orig_args and 'opsel_hi2' not in orig_args: + self._values['opsel_hi'] = 0 + self._values['opsel_hi2'] = 0 # Type check and encode values for name, val in list(self._values.items()): if name == 'encoding': continue @@ -340,6 +354,14 @@ class Inst: lit = f", literal={hex(self._literal)}" if self._literal is not None else "" return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in items)}{lit})" + def __getattr__(self, name: str): + if name.startswith('_'): raise AttributeError(name) + return unwrap(self._values.get(name, 0)) + + def lit(self, v: int) -> str: + from extra.assembly.amd.asm import decode_src + return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v) + def __eq__(self, other): if not isinstance(other, Inst): return NotImplemented return self.__class__ == other.__class__ and self._values == other._values and self._literal == other._literal @@ -519,10 +541,24 @@ def _parse_single_pdf(url: str) -> dict: break formats[fmt_name] = fields - # fix known PDF errors + # fix known PDF errors - assert if already present (so we know when the bug is fixed) if 'SMEM' in formats: formats['SMEM'] = [(n, 13 if n == 'DLC' else 14 if n == 'GLC' else h, 13 if n == 'DLC' else 14 if n == 'GLC' else l, e, t) for n, h, l, e, t in formats['SMEM']] + # add missing opcodes not in PDF tables (RDNA3/RDNA3.5 specific) + if doc_name in ('RDNA3', 'RDNA3.5'): + if 'SOPPOp' in enums: + assert 8 not in enums['SOPPOp'], "S_WAITCNT_DEPCTR now in PDF, remove workaround" + enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR' + if 'DSOp' in enums: + gws_ops = {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', + 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'} + for k in gws_ops: assert k not in enums['DSOp'], f"{gws_ops[k]} now in PDF, remove workaround" + enums['DSOp'].update(gws_ops) + if 'FLATOp' in enums: + flat_ops = {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'} + for k in flat_ops: assert k not in enums['FLATOp'], f"{flat_ops[k]} now in PDF, remove workaround" + enums['FLATOp'].update(flat_ops) return {"formats": formats, "enums": enums, "src_enum": src_enum, "doc_name": doc_name, "is_cdna": is_cdna} @@ -608,7 +644,7 @@ def generate(output_path: str | None = None, arch: str = "rdna3") -> dict: for cls_name, ops in sorted(enums.items()): fmt = cls_name[:-2] for op_val, name in sorted(ops.items()): - seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "") + seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=1"}.get(fmt, "") tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}") if fmt in formats or fmt in ("GLOBAL", "SCRATCH"): if fmt in ("VOP1", "VOP2", "VOPC"): diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index c99720aceb..18d50f6a25 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -3,6 +3,7 @@ from __future__ import annotations import ctypes, os from extra.assembly.amd.dsl import Inst, RawImm +from extra.assembly.amd.asm import detect_format from extra.assembly.amd.pcode import _f32, _i32, _sext, _f16, _i16, _f64, _i64 from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions from extra.assembly.amd.autogen.rdna3 import ( @@ -146,21 +147,7 @@ class WaveState: for reg, val in self._pend_sgpr.items(): self.sgpr[reg] = val self._pend_sgpr.clear() -# Instruction decode -def decode_format(word: int) -> tuple[type[Inst] | None, bool]: - hi2 = (word >> 30) & 0x3 - if hi2 == 0b11: - enc = (word >> 26) & 0xf - if enc == 0b1101: return SMEM, True - if enc == 0b0101: - op = (word >> 16) & 0x3ff - return (VOP3SD, True) if op in (288, 289, 290, 764, 765, 766, 767, 768, 769, 770) else (VOP3, True) - return {0b0011: (VOP3P, True), 0b0110: (DS, True), 0b0111: (FLAT, True), 0b0010: (VOPD, True)}.get(enc, (None, True)) - if hi2 == 0b10: - enc = (word >> 23) & 0x7f - return {0b1111101: (SOP1, False), 0b1111110: (SOPC, False), 0b1111111: (SOPP, False)}.get(enc, (SOPK, False) if ((word >> 28) & 0xf) == 0b1011 else (SOP2, False)) - enc = (word >> 25) & 0x7f - return (VOPC, False) if enc == 0b0111110 else (VOP1, False) if enc == 0b0111111 else (VOP2, False) + def _unwrap(v) -> int: return v.val if isinstance(v, RawImm) else v.value if hasattr(v, 'value') else v @@ -168,10 +155,10 @@ def decode_program(data: bytes) -> Program: result: Program = {} i = 0 while i < len(data): - word = int.from_bytes(data[i:i+4], 'little') - inst_class, is_64 = decode_format(word) + try: inst_class = detect_format(data[i:]) + except ValueError: break # stop at invalid instruction (padding/metadata after code) if inst_class is None: i += 4; continue - base_size = 8 if is_64 else 4 + base_size = inst_class._size() # Pass enough data for potential 64-bit literal (base + 8 bytes max) inst = inst_class.from_bytes(data[i:i+base_size+8]) for name, val in inst._values.items(): setattr(inst, name, _unwrap(val)) diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index 97b1cab756..955b3239a2 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -65,12 +65,18 @@ def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]: if not asm_text: continue for j in range(i, min(i + 3, len(lines))): # Match GFX11, W32, or W64 encodings (all valid for gfx11) + # Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]" + # Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files) if m := re.search(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]): hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') - if hex_bytes: - try: tests.append((asm_text, bytes.fromhex(hex_bytes))) - except ValueError: pass - break + elif m := re.search(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]): + hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '') + else: + continue + if hex_bytes: + try: tests.append((asm_text, bytes.fromhex(hex_bytes))) + except ValueError: pass + break return tests def try_assemble(text: str): diff --git a/extra/assembly/amd/test/test_roundtrip.py b/extra/assembly/amd/test/test_roundtrip.py index bf9b68d869..d2660ab140 100644 --- a/extra/assembly/amd/test/test_roundtrip.py +++ b/extra/assembly/amd/test/test_roundtrip.py @@ -4,51 +4,9 @@ import unittest, io, sys, re, subprocess, os from extra.assembly.amd.autogen.rdna3 import * from extra.assembly.amd.dsl import Inst from extra.assembly.amd.asm import asm +from extra.assembly.amd.asm import detect_format from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump -# Instruction format detection based on encoding bits -def detect_format(data: bytes) -> type[Inst] | None: - """Detect instruction format from machine code bytes.""" - if len(data) < 4: return None - word = int.from_bytes(data[:4], 'little') - enc_9bit = (word >> 23) & 0x1FF # 9-bit encoding for SOP1/SOPC/SOPP - enc_8bit = (word >> 24) & 0xFF - - # Check 9-bit encodings first (most specific) - if enc_9bit == 0x17D: return SOP1 # bits 31:23 = 101111101 - if enc_9bit == 0x17E: return SOPC # bits 31:23 = 101111110 - if enc_9bit == 0x17F: return SOPP # bits 31:23 = 101111111 - # SOPK: bits 31:28 = 1011, bits 27:23 = opcode (check after SOP1/SOPC/SOPP) - if enc_8bit in range(0xB0, 0xC0): return SOPK - # SOP2: bits 31:23 in range 0x100-0x17C (0x80-0xBE in bits 31:24, but not SOPK) - if 0x80 <= enc_8bit <= 0x9F: return SOP2 - # VOP1: bits 31:25 = 0111111 (0x3F) - if (word >> 25) == 0x3F: return VOP1 - # VOPC: bits 31:25 = 0111110 (0x3E) - if (word >> 25) == 0x3E: return VOPC - # VOP2: bits 31:30 = 00 - if (word >> 30) == 0: return VOP2 - - # Check 64-bit formats - if len(data) >= 8: - if enc_8bit in (0xD4, 0xD5, 0xD7): - # VOP3 and VOP3SD share encoding - check opcode to determine which - # VOP3SD opcodes: 288-290 (v_*_co_ci_*), 764-770 (v_div_scale_*, v_mad_*, v_*_co_u32) - op = (int.from_bytes(data[:8], 'little') >> 16) & 0x3FF - if op in {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}: return VOP3SD - return VOP3 - if enc_8bit == 0xD6: return VOP3SD - if enc_8bit == 0xCC: return VOP3P - if enc_8bit == 0xCD: return VINTERP - if enc_8bit in (0xC8, 0xC9): return VOPD - if enc_8bit == 0xF4: return SMEM - if enc_8bit == 0xD8: return DS - if enc_8bit in (0xDC, 0xDD, 0xDE, 0xDF): return FLAT - if enc_8bit in (0xE0, 0xE1, 0xE2, 0xE3): return MUBUF - if enc_8bit in (0xE8, 0xE9, 0xEA, 0xEB): return MTBUF - - return None - def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]: """Disassemble ELF binary and return list of (instruction_text, machine_code_bytes).""" old_stdout = sys.stdout