From 4e03b3ebef4488b5635dcf004ce4705c3bc5e5de Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 1 Jan 2026 16:45:39 +0000 Subject: [PATCH] rdna4 work --- extra/assembly/amd/asm.py | 927 ++++++++++++++-------- extra/assembly/amd/dsl.py | 4 +- extra/assembly/amd/test/test_llvm.py | 15 +- extra/assembly/amd/test/test_roundtrip.py | 259 +++--- 4 files changed, 710 insertions(+), 495 deletions(-) diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index 2abb3cbad3..8f0d511c09 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -1,29 +1,37 @@ -# RDNA3/RDNA4 assembler and disassembler +# RDNA3 assembler and disassembler from __future__ import annotations import re -from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory +from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, SRC_FIELDS, unwrap from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src from extra.assembly.amd.autogen.rdna3 import ins -from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, - VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp) +from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, LDSDIR, + VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp) + +# VOP3SD opcodes that share VOP3 encoding +VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} def _matches_encoding(word: int, cls: type[Inst]) -> bool: + """Check if word matches the encoding pattern of an instruction class.""" if cls._encoding is None: return False bf, val = cls._encoding return ((word >> bf.lo) & bf.mask()) == val +# Order matters: more specific encodings first, VOP2 last (it's a catch-all for bit31=0) _FORMATS_64 = [VOPD, VOP3P, VINTERP, VOP3, DS, FLAT, MUBUF, MTBUF, MIMG, SMEM, EXP] -_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] +_FORMATS_32 = [SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2] # SOP2/VOP2 are catch-alls def detect_format(data: bytes) -> type[Inst]: + """Detect instruction format from machine code bytes.""" assert len(data) >= 4, f"need at least 4 bytes, got {len(data)}" word = int.from_bytes(data[:4], 'little') + # Check 64-bit formats first (bits[31:30] == 0b11) if (word >> 30) == 0b11: for cls in _FORMATS_64: if _matches_encoding(word, cls): - return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls + return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in VOP3SD_OPS else cls raise ValueError(f"unknown 64-bit format word={word:#010x}") + # 32-bit formats for cls in _FORMATS_32: if _matches_encoding(word, cls): return cls raise ValueError(f"unknown 32-bit format word={word:#010x}") @@ -32,26 +40,19 @@ def detect_format(data: bytes) -> type[Inst]: # CONSTANTS # ═══════════════════════════════════════════════════════════════════════════════ +# GFX11 HWREG IDs HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO', 19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'} -HWREG_RDNA4 = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', - 7: 'HW_REG_IB_STS', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2'} +# GFX12 HWREG IDs - use names that LLVM recognizes +HWREG_GFX12 = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', + 18: 'HW_REG_EXCP_FLAG_USER', 19: 'HW_REG_TRAP_CTRL', 20: 'HW_REG_SCRATCH_BASE_LO', 21: 'HW_REG_SCRATCH_BASE_HI', + 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 29: 'HW_REG_SHADER_CYCLES_LO', 30: 'HW_REG_SHADER_CYCLES_HI'} HWREG_IDS = {v.lower(): k for k, v in HWREG.items()} MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA', 131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'} -# RDNA4 cache policy tables -_TH_LOAD = {0: None, 1: 'TH_LOAD_NT', 2: 'TH_LOAD_HT', 3: 'TH_LOAD_LU', 4: 'TH_LOAD_NT_RT', 5: 'TH_LOAD_RT_NT', 6: 'TH_LOAD_NT_HT'} -_TH_STORE = {0: None, 1: 'TH_STORE_NT', 2: 'TH_STORE_HT', 3: 'TH_STORE_LU', 4: 'TH_STORE_NT_RT', 5: 'TH_STORE_RT_NT', 6: 'TH_STORE_NT_HT'} -_TH_ATOMIC = {0: None, 1: 'TH_ATOMIC_NT', 2: 'TH_ATOMIC_RETURN'} -_SCOPE = {0: None, 1: 'SCOPE_SE', 2: 'SCOPE_SA', 3: 'SCOPE_SYS'} - -# Export target mapping -_EXP_TARGETS = {**{i: f'mrt{i}' for i in range(8)}, 8: 'mrtz', **{i+12: f'pos{i}' for i in range(5)}, - 20: 'prim', 21: 'dual_src_blend0', 22: 'dual_src_blend1'} - # ═══════════════════════════════════════════════════════════════════════════════ # HELPERS # ═══════════════════════════════════════════════════════════════════════════════ @@ -83,12 +84,15 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int: return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10) def _has(op: str, *subs) -> bool: return any(s in op for s in subs) +def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32') +def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64') def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "") -def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) +def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c) def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]" def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str: + """Format VOP3 source operand with modifiers.""" if n > 1: s = _fmt_src(v, n) elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v)) else: s = inst.lit(v) @@ -96,90 +100,83 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any return f"-{s}" if neg else s def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str: + """Format op_sel modifier string.""" if not need: return "" if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]" if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]" return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]" -def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: - base = [1, 2, 3, 3, 2, 3, 3, 4][dim] # 1d,2d,3d,cube,1d_arr,2d_arr,2d_msaa,2d_msaa_arr - grad = [1, 2, 3, 2, 1, 2, 2, 2][dim] - if 'get_resinfo' in name: return 1 - packed, unpacked = 0, 0 - if '_mip' in name: packed += 1 - elif 'sample' in name or 'gather' in name: - if '_o' in name: unpacked += 1 - if re.search(r'_c(_|$)', name): unpacked += 1 - if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2 - if '_b' in name: unpacked += 1 - if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1 - if '_cl' in name: packed += 1 - return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked - -def _collect_vaddrs(inst, count: int) -> list[int]: - vaddrs = [inst.vaddr0] - if count > 1: vaddrs.append(inst.vaddr1) - if count > 2: vaddrs.append(inst.vaddr2) - if count > 3: vaddrs.append(inst.vaddr3) - if count > 4 and hasattr(inst, 'vaddr4'): vaddrs.append(inst.vaddr4) - return vaddrs[:count] - -def _fmt_vaddr_nsa(vaddrs: list[int]) -> str: - return f"v{vaddrs[0]}" if len(vaddrs) == 1 else "[" + ", ".join(f"v{v}" for v in vaddrs) + "]" - # ═══════════════════════════════════════════════════════════════════════════════ # DISASSEMBLER # ═══════════════════════════════════════════════════════════════════════════════ +_VOP1_F64 = {VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64, VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64, VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64} + def _disasm_vop1(inst: VOP1) -> str: - name = inst.op_name.lower() - if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name - if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" + # Use architecture-specific op enum + if 'rdna4' in inst.__class__.__module__: + from extra.assembly.amd.autogen.rdna4.enum import VOP1Op as OpEnum + else: + OpEnum = VOP1Op + op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower() + if name in ('v_nop', 'v_pipeflush'): return name + if name == 'v_readfirstlane_b32': return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" parts = name.split('_') + is_f64_d = 'f64' in name and any(x in name for x in ['ceil', 'floor', 'fract', 'frexp_mant', 'rcp', 'rndne', 'rsq', 'sqrt', 'trunc', 'cvt_f64_f32', 'cvt_f64_i32', 'cvt_f64_u32']) + is_f64_s = 'f64' in name and any(x in name for x in ['ceil', 'floor', 'fract', 'frexp_mant', 'rcp', 'rndne', 'rsq', 'sqrt', 'trunc', 'cvt_f32_f64', 'cvt_i32_f64', 'cvt_u32_f64', 'frexp_exp_i32_f64']) + # v_cvt_pk_f32_bf8/fp8 output 2 VGPRs (packed f32x2) and take packed 8-bit (16-bit VGPR with .l/.h) source + is_pk_f32 = 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name) - dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" - src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) + # Only packed bf8/fp8 (cvt_pk_*) use 16-bit VGPR encoding; non-packed versions use regular VGPRs + is_16s = (parts[-1] in ('f16','i16','u16','b16') and 'sat_pk' not in name) or (parts[-1] in ('bf8', 'fp8') and 'pk' in name) + dst = _vreg(inst.vdst, 2) if is_f64_d or is_pk_f32 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" + src = _fmt_src(inst.src0, 2) if is_f64_s else _src16(inst, inst.src0) if is_16s else inst.lit(inst.src0) return f"{name}_e32 {dst}, {src}" def _disasm_vop2(inst: VOP2) -> str: - name = inst.op_name.lower() - try: is_dot2acc = inst.op == VOP2Op.V_DOT2ACC_F32_F16 - except ValueError: is_dot2acc = False - suf = "" if is_dot2acc else "_e32" - try: - if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}" - if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}" - except ValueError: pass - try: - if inst.op == VOP2Op.V_CNDMASK_B32: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, vcc_lo" - except ValueError: pass - if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" - dn, sn0, sn1 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1) - dst = _vreg(inst.vdst, dn) if dn > 1 else f"v{inst.vdst}" - src0 = _fmt_src(inst.src0, sn0) if sn0 > 1 else inst.lit(inst.src0) - src1 = _vreg(inst.vsrc1, sn1) if sn1 > 1 else f"v{inst.vsrc1}" - return f"{name}{suf} {dst}, {src0}, {src1}" + # Use architecture-specific op enum + if 'rdna4' in inst.__class__.__module__: + from extra.assembly.amd.autogen.rdna4.enum import VOP2Op as OpEnum + else: + OpEnum = VOP2Op + op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower() + suf, is16 = "" if name == 'v_dot2acc_f32_f16' else "_e32", _is16(name) and 'pk_' not in name + is64 = _is64(name) + # For shift ops with b64, src0 is 32-bit (shift amount), dst/vsrc1 are 64-bit + is_shift64 = 'lshlrev_b64' in name + # fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1 + if 'fmaak' in name: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}" + if 'fmamk' in name: return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}" + if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" + if is_shift64: return f"{name}{suf} {_vreg(inst.vdst, 2)}, {inst.lit(inst.src0)}, {_vreg(inst.vsrc1, 2)}" + if is64: return f"{name}{suf} {_vreg(inst.vdst, 2)}, {_fmt_src(inst.src0, 2)}, {_vreg(inst.vsrc1, 2)}" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if name == 'v_cndmask_b32' else "") def _disasm_vopc(inst: VOPC) -> str: - name = inst.op_name.lower() - s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0) - s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}" - return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" + # Use architecture-specific op enum + if 'rdna4' in inst.__class__.__module__: + from extra.assembly.amd.autogen.rdna4.enum import VOPCOp as OpEnum + else: + OpEnum = VOPCOp + op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower() + is64, is16 = _is64(name), _is16(name) + is_class = 'class' in name + s0 = _fmt_src(inst.src0, 2) if is64 else _src16(inst, inst.src0) if is16 else inst.lit(inst.src0) + s1 = _vreg(inst.vsrc1, 2) if is64 and not is_class else _fmt_v16(inst.vsrc1, 0, 128) if is16 else f"v{inst.vsrc1}" + return f"{name}_e32 {s0}, {s1}" if op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE} def _disasm_sopp(inst: SOPP) -> str: - name = inst.op_name.lower() - if not name: raise ValueError(f"undefined SOPP op: {inst.op}") - if inst.op in NO_ARG_SOPP: return name - if inst.op == SOPPOp.S_WAITCNT: + op, name = SOPPOp(inst.op), SOPPOp(inst.op).name.lower() + if op in NO_ARG_SOPP: return name + if op == SOPPOp.S_WAITCNT: vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""] return f"s_waitcnt {' '.join(x for x in p if x) or '0'}" - if inst.op == SOPPOp.S_DELAY_ALU: - deps = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'] - skips = ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4'] + if op == SOPPOp.S_DELAY_ALU: + deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4'] id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v) p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skips[skip]})" if skip else "", f"instid1({dep(id1)})" if id1 else ""] @@ -187,56 +184,62 @@ def _disasm_sopp(inst: SOPP) -> str: return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}" def _disasm_smem(inst: SMEM) -> str: - name = inst.op_name.lower() - if 'rdna4' in inst.__class__.__module__: return _disasm_smem_rdna4(inst) - if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name - off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset) - sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2 - sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) - if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}" - return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) - -def _disasm_smem_rdna4(inst) -> str: - name = inst.op_name.lower() - op_val = inst._values.get('op') - if not name: - name = {34: 's_atc_probe', 35: 's_atc_probe_buffer', 32: 's_gl1_inv'}.get(op_val, f's_smem_op{op_val}') - if name in ('s_gl1_inv', 's_dcache_inv'): return name - sbase_idx, sbase_count = inst.sbase * 2, 4 if 'buffer' in name else 2 - if sbase_idx == 106: sbase_str = "vcc" - elif 108 <= sbase_idx <= 123: sbase_str = _reg("ttmp", sbase_idx - 108, sbase_count) - else: sbase_str = _sreg(sbase_idx, sbase_count) - ioffset = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000 - off_str = f"0x{ioffset:x}" if ioffset >= 0 else f"-0x{-ioffset:x}" - soffset_str = decode_src(inst.soffset) - th_names = ['','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS'] - scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS'] - if 'prefetch' in name: - return f"{name} {off_str}, {soffset_str}, {inst.sdata}" if 'pc_rel' in name else f"{name} {sbase_str}, {off_str}, {soffset_str}, {inst.sdata}" - if 'atc_probe' in name: - return f"{name} {inst.sdata}, {sbase_str}, {soffset_str}" + (f" offset:{off_str}" if ioffset else "") - if inst.soffset == 124: base_str = f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_str}" - elif ioffset: base_str = f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {soffset_str} offset:{off_str}" - else: base_str = f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {soffset_str}" - mods = [] - if inst.th and inst.th < len(th_names) and th_names[inst.th]: mods.append(f"th:{th_names[inst.th]}") - if inst.scope and inst.scope < len(scope_names) and scope_names[inst.scope]: mods.append(f"scope:{scope_names[inst.scope]}") - return base_str + (" " + " ".join(mods) if mods else "") + is_rdna4 = 'rdna4' in inst.__class__.__module__ + if is_rdna4: + from extra.assembly.amd.autogen.rdna4.enum import SMEMOp as SMEMOp4 + op = SMEMOp4(inst.op) + name = op.name.lower() + if op == SMEMOp4.S_DCACHE_INV: return name + # RDNA4: s_buffer_* uses 4-SGPR descriptor, s_load/s_prefetch uses 2-SGPR + is_buffer = 'buffer' in name + sbase_idx = inst.sbase * 2 + sbase_count = 4 if is_buffer else 2 + sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) + # Format offset - ioffset is signed 24-bit, show as hex + ioff = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000 # sign extend + off_hex = f"0x{ioff & 0xffffff:x}" if ioff >= 0 else f"-0x{(-ioff) & 0xffffff:x}" + off_s = f"{decode_src(inst.soffset)} offset:{off_hex}" if inst.soffset != 124 else off_hex + # Data width from opcode + width_map = {0:1, 1:2, 2:4, 3:8, 4:16, 5:3, 8:1, 9:1, 10:1, 11:1, 16:1, 17:2, 18:4, 19:8, 20:16, 21:3, 24:1, 25:1, 26:1, 27:1} + width = width_map.get(inst.op, 1) + if 'prefetch' in name: + # Prefetch has different format: s_prefetch_* sbase, offset, soffset, length + # But we need to handle various prefetch types differently + if name == 's_prefetch_inst_pc_rel' or name == 's_prefetch_data_pc_rel': + return f"{name} {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}" + return f"{name} {sbase_str}, {off_hex}, {decode_src(inst.soffset)}, {inst.sdata}" + return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + else: + op = SMEMOp(inst.op) + name = op.name.lower() + if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name + off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset) + sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2 + sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) + if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}" + width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1) + return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) def _disasm_flat(inst: FLAT) -> str: - name = inst.op_name.lower() + name = FLATOp(inst.op).name.lower() seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192) - w = inst.dst_regs() * (2 if 'cmpswap' in name else 1) + suffix = name.split('_')[-1] + w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1) + if 'cmpswap' in name: w *= 2 + if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2) mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" + # saddr if seg == 'flat' or inst.saddr == 0x7F: saddr_s = "" elif inst.saddr == 124: saddr_s = ", off" elif seg == 'scratch': saddr_s = f", {decode_src(inst.saddr)}" elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}" elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}" else: saddr_s = f", {_sreg(inst.saddr, 2) if inst.saddr < 106 else decode_src(inst.saddr)}" + # addtid: no addr if 'addtid' in name: return f"{instr} v{inst.data if 'store' in name else inst.vdst}{saddr_s}{mods}" + # addr width addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2) data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w) if 'atomic' in name: @@ -245,17 +248,15 @@ def _disasm_flat(inst: FLAT) -> str: return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" def _disasm_ds(inst: DS) -> str: - op, name = inst.op, inst.op_name.lower() + op, name = DSOp(inst.op), DSOp(inst.op).name.lower() gds = " gds" if inst.gds else "" off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else "" off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else "" - w = inst.dst_regs() + w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1 d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}" + if op == DSOp.DS_NOP: return name if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}" - if 'bvh_stack_push4_pop1' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}" - if 'bvh_stack_push8_pop1' in name: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 8)}{off}{gds}" - if 'bvh_stack_push8_pop2' in name: return f"{name} {_vreg(inst.vdst, 2)}, {addr}, v{inst.data0}, {_vreg(inst.data1, 8)}{off}{gds}" if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}" if 'gws_' in name: return f"{name} {addr}{off}{gds}" if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} v{inst.vdst}{off}{gds}" @@ -275,113 +276,185 @@ def _disasm_ds(inst: DS) -> str: return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}" def _disasm_vop3(inst: VOP3) -> str: - op, name = inst.op, inst.op_name.lower() - if name.startswith('v_s_'): - return f"{name} {_fmt_sdst(inst.vdst, 1)}, {_fmt_src(inst.src0, 1)}" - if hasattr(op, '__class__') and op.__class__.__name__ == 'VOP3SDOp': + is_rdna4 = 'rdna4' in inst.__class__.__module__ + if is_rdna4: + from extra.assembly.amd.autogen.rdna4.enum import VOP3Op as VOP3Op4, VOP3SDOp as VOP3SDOp4 + op = VOP3SDOp4(inst.op) if inst.op in VOP3SD_OPS else VOP3Op4(inst.op) + else: + op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op) + name = op.name.lower() + + # VOP3SD (shared encoding) + if inst.op in VOP3SD_OPS: sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs - def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s - s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) - dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" - srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" - return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod) + is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32', 'mad_co_i64_i32', 'mad_co_u64_u32') + def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s + s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64) + dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}" + if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}" + if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod) + + # Detect operand sizes + is64 = _is64(name) + is64_src, is64_dst = False, False is16_d = is16_s = is16_s2 = False - if 'cvt_pk' in name: is16_s = name.endswith('16') + # v_cvt_pk_f32_bf8/fp8 outputs a VGPR pair (f32x2) from 16-bit packed input + if 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name: is64_dst = True + elif 'cvt_pk' in name: is16_s = name.endswith('16') elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name): is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16') - is16_s2 = is16_s + is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1) + is16_s2, is64 = is16_s, False elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True elif 'pack_b32' in name: is16_s = is16_s2 = True - else: is16_d = is16_s = is16_s2 = inst.is_16bit() + else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad') + + # Source counts + shift64 = 'rev' in name and '64' in name and name.startswith('v_') + ldexp64 = op == VOP3Op.V_LDEXP_F64 + trig = op == VOP3Op.V_TRIG_PREOP_F64 + sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name + s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1 + s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1 + s2n = 4 if mqsad else 2 if (is64 or sad64) else 1 + any_hi = inst.opsel != 0 - s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi) - s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi) - s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi) - dn = inst.dst_regs() + s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi) + s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi) + s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi) + + # Destination + dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1 if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1) elif dn > 1: dst = _vreg(inst.vdst, dn) elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}" else: dst = f"v{inst.vdst}" + cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4)) need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s) - if inst.op < 256: + + # RDNA4 v_s_* instructions (pseudo-scalar VOP1-like) have SGPR destination + if name.startswith('v_s_') and is_rdna4: + return f"{name} {_fmt_sdst(inst.vdst, 1)}, {s0}{cl}{om}" + + if inst.op < 256: # VOPC return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" - if inst.op < 384: - n = inst.num_srcs() - os = _opsel_str(inst.opsel, n, need_opsel, is16_d) - return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}" - if inst.op < 512: - if _has(name, 'cvt_f32_fp8', 'cvt_f32_bf8'): need_opsel = False - return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}" - n = inst.num_srcs() - if 'permlane' in name and '_var' in name: n = 2 - if _has(name, 'cvt_sr_fp8', 'cvt_sr_bf8'): n, need_opsel = 2, False - os = _opsel_str(inst.opsel, n, need_opsel, is16_d) - return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}" + if inst.op < 384: # VOP2 + os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d) + return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}" + if inst.op < 512: # VOP1 + if name in ('v_nop', 'v_pipeflush'): return f"{name}_e64" + # Handle byte_sel for non-pk fp8/bf8 conversions + if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name) and 'pk' not in name: + byte_sel = inst.opsel & 3 + os = f" byte_sel:{byte_sel}" if byte_sel else "" + elif 'cvt_pk_f32_bf8' in name or 'cvt_pk_f32_fp8' in name: + os = _opsel_str(inst.opsel, 2, need_opsel, is16_d) # 2-element for pk variants + else: + os = _opsel_str(inst.opsel, 1, need_opsel, is16_d) # 1-element for other VOP1 + return f"{name}_e64 {dst}, {s0}{os}{cl}{om}" + # Native VOP3 + is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi', + 'perm_b32', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit', + 'minimummaximum', 'maximumminimum', 'minimum3', 'maximum3') + # permlane16/permlanex16 have 3 sources, but _var variants have 2 + if 'permlane' in name and 'var' not in name: is3 = True + # Handle byte_sel for fp8/bf8 instructions (opsel encodes byte_sel, not op_sel) + # For VOP1-encoded VOP3 (op < 512): cvt_f32_fp8, cvt_f32_bf8 + # For native VOP3: cvt_sr_fp8, cvt_sr_bf8 + if ('cvt_f32_fp8' in name or 'cvt_f32_bf8' in name or 'cvt_sr_fp8' in name or 'cvt_sr_bf8' in name) and 'pk' not in name: + # For VOP1 encoding (op < 512), byte_sel is in bits[1:0] of opsel; for native VOP3, it's bits[3:2] + byte_sel = (inst.opsel & 3) if inst.op < 512 else ((inst.opsel >> 2) & 3) + os = f" byte_sel:{byte_sel}" if byte_sel else "" + else: + os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d) + return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}" def _disasm_vop3sd(inst: VOP3SD) -> str: - name = inst.op_name.lower() - def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s - s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) - dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" - srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" + op, name = VOP3SDOp(inst.op), VOP3SDOp(inst.op).name.lower() + is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32') + def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s + s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64) + dst, is2src = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}", op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32) suffix = "_e64" if name.startswith('v_') and 'co_' in name else "" - return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}" + return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{'' if is2src else f', {s2}'}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}" def _disasm_vopd(inst: VOPD) -> str: - lit = inst._literal or inst.literal - if 'rdna4' in inst.__class__.__module__: - import importlib - VOPDOpCls = importlib.import_module('extra.assembly.amd.autogen.rdna4.enum').VOPDOp + is_rdna4 = 'rdna4' in inst.__class__.__module__ + if is_rdna4: + from extra.assembly.amd.autogen.rdna4.enum import VOPDOp as VOPDOp4 + OpEnum = VOPDOp4 else: - VOPDOpCls = VOPDOp - vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOpCls(inst.opx).name.lower(), VOPDOpCls(inst.opy).name.lower() + OpEnum = VOPDOp + lit = inst._literal or inst.literal + vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), OpEnum(inst.opx).name.lower(), OpEnum(inst.opy).name.lower() def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}" -def _disasm_vop3p(inst: VOP3P) -> str: - name = inst.op_name.lower() - is_wmma, n, is_fma_mix, is_swmmac = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name, 'swmmac' in name +def _disasm_vop3p(inst: VOP3P, wave_size: int = 32) -> str: is_rdna4 = 'rdna4' in inst.__class__.__module__ - if is_wmma or is_swmmac: - if is_rdna4: - if is_swmmac: - if '16x16x32_iu4' in name: sc0, sc1 = 1, 2 - elif '16x16x64_iu4' in name or '16x16x32_iu8' in name or 'fp8' in name or 'bf8' in name: sc0, sc1 = 2, 4 - else: sc0, sc1 = 4, 8 - sc2 = 1 - dst_w = 4 if name.startswith('v_swmmac_f16') or name.startswith('v_swmmac_bf16') else 8 - else: - if '16x16x16_iu4' in name: sc0 = 1 - elif '16x16x32_iu4' in name or 'iu8' in name or 'fp8' in name or 'bf8' in name: sc0 = 2 - else: sc0 = 4 - sc1 = sc0 - sc2 = 4 if (name.startswith('v_wmma_f16') or name.startswith('v_wmma_bf16')) else 8 - dst_w = sc2 - src0, src1, src2, dst = _fmt_src(inst.src0, sc0), _fmt_src(inst.src1, sc1), _fmt_src(inst.src2, sc2), _vreg(inst.vdst, dst_w) - else: - sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8 - src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8) + if is_rdna4: + from extra.assembly.amd.autogen.rdna4.enum import VOP3POp as OpEnum else: - src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}" - opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) - if is_fma_mix: + OpEnum = VOP3POp + name = OpEnum(inst.op).name.lower() + is_wmma, is_swmmac = 'wmma' in name and 'swmmac' not in name, 'swmmac' in name + is_3src, is_fma_mix = _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name + # Wave64 uses half the register widths of wave32 for WMMA + wave_div = 2 if wave_size == 64 else 1 + if is_swmmac and is_rdna4: + # SWMMAC (sparse WMMA): src2 is a single VGPR index, not an accumulator + # Determine src0/src1/dst sizes based on instruction type + if 'f16' in name or 'bf16' in name: + if 'f16_16x16x32_f16' in name or 'bf16_16x16x32_bf16' in name: + s0c, s1c, dc = 4, 8, 4 # f16/bf16 output + else: + s0c, s1c, dc = 4, 8, 8 # f32 output + elif 'iu8' in name: s0c, s1c, dc = 2, 4, 8 + elif 'iu4' in name: + if '16x16x64' in name: s0c, s1c, dc = 2, 4, 8 + else: s0c, s1c, dc = 1, 2, 8 + elif 'fp8' in name or 'bf8' in name: s0c, s1c, dc = 2, 4, 8 + else: s0c, s1c, dc = 4, 8, 8 + s0c, s1c, dc = max(1, s0c // wave_div), max(1, s1c // wave_div), max(1, dc // wave_div) + src0, src1, src2, dst = _fmt_src(inst.src0, s0c), _fmt_src(inst.src1, s1c), _fmt_src(inst.src2, 1), _vreg(inst.vdst, dc) + elif is_wmma: + # RDNA4 WMMA uses smaller source register widths than RDNA3 + if is_rdna4: + # RDNA4 wave32 source widths: iu4->1/2, iu8->2, fp8/bf8->2, f16/bf16->4 + if 'iu4' in name: + sc = 2 if '16x16x32' in name else 1 + elif 'iu8' in name or 'fp8' in name or 'bf8' in name: sc = 2 + else: sc = 4 # f16/bf16 + # Destination width: f16/bf16 output->4, f32/i32 output->8 + dc = 4 if name.startswith('v_wmma_f16') or name.startswith('v_wmma_bf16') else 8 + else: + # RDNA3: iu4->2, iu8->4, f16/bf16->8 + sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8 + dc = 8 + sc, dc = max(1, sc // wave_div), max(1, dc // wave_div) + src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, dc), _vreg(inst.vdst, dc) + else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}" + n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2) + if is_swmmac and is_rdna4: + # SWMMAC uses index_key instead of op_sel; opsel bits encode the key value + mods = ([f"index_key:{inst.opsel & 7}"] if inst.opsel else []) + \ + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) + elif is_fma_mix: def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s) src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4) mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else []) - elif is_swmmac: - has_index_key = '16x16x64_iu4' not in name - mods = ([f"index_key:{inst.opsel & 1}"] if has_index_key and (inst.opsel & 1) else []) + \ - ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) else: - mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \ + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \ ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) - return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" + return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" def _disasm_buf(inst: MUBUF | MTBUF) -> str: - name = inst.op_name.lower() - if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name + op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op) + name = op.name.lower() + if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1) @@ -391,132 +464,113 @@ def _disasm_buf(inst: MUBUF | MTBUF) -> str: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c] return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}" +def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: + """Calculate vaddr register count for MIMG sample/gather operations.""" + # 1d,2d,3d,cube,1d_arr,2d_arr,2d_msaa,2d_msaa_arr + base = [1, 2, 3, 3, 2, 3, 3, 4][dim] # address coords + grad = [1, 2, 3, 2, 1, 2, 2, 2][dim] # gradient coords (for derivatives) + if 'get_resinfo' in name: return 1 # only mip level + packed, unpacked = 0, 0 + if '_mip' in name: packed += 1 + elif 'sample' in name or 'gather' in name: + if '_o' in name: unpacked += 1 # offset + if re.search(r'_c(_|$)', name): unpacked += 1 # compare (not _cl) + if '_d' in name: unpacked += (grad + 1) & ~1 if '_g16' in name else grad*2 # derivatives + if '_b' in name: unpacked += 1 # bias + if '_l' in name and '_cl' not in name and '_lz' not in name: packed += 1 # LOD + if '_cl' in name: packed += 1 # clamp + return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked + def _disasm_mimg(inst: MIMG) -> str: - name = inst.op_name.lower() + name = MIMGOp(inst.op).name.lower() srsrc_base = inst.srsrc * 4 srsrc_str = _sreg_or_ttmp(srsrc_base, 8) + # BVH intersect ray: special case with 4 SGPR srsrc if 'bvh' in name: vaddr = (9 if '64' in name else 8) if inst.a16 else (12 if '64' in name else 11) return f"{name} {_vreg(inst.vdata, 4)}, {_vreg(inst.vaddr, vaddr)}, {_sreg_or_ttmp(srsrc_base, 4)}{' a16' if inst.a16 else ''}" + # vdata width from dmask (gather4/msaa_load always 4), d16 packs, tfe adds 1 vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1) if inst.d16: vdata = (vdata + 1) // 2 if inst.tfe: vdata += 1 + # vaddr width dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array'] dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"dim_{inst.dim}" vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16) vaddr_str = f"v{inst.vaddr}" if vaddr == 1 else _vreg(inst.vaddr, vaddr) + # modifiers mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask and (inst.dmask != 15 or 'atomic' in name) else [] mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}") - for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"),(inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]: + for flag, mod in [(inst.unrm,"unorm"),(inst.glc,"glc"),(inst.slc,"slc"),(inst.dlc,"dlc"),(inst.r128,"r128"), + (inst.a16,"a16"),(inst.tfe,"tfe"),(inst.lwe,"lwe"),(inst.d16,"d16")]: if flag: mods.append(mod) - ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4) if 'sample' in name or 'gather' in name or 'get_lod' in name else "" + # ssamp for sample/gather/get_lod + ssamp_str = "" + if 'sample' in name or 'gather' in name or 'get_lod' in name: + ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4) return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}" -def _disasm_vsample(inst) -> str: - name = inst.op_name.lower() - if not name: raise ValueError(f"undefined VSAMPLE op: {inst.op}") - if 'msaa_load' in name: raise ValueError(f"image_msaa_load not supported in VSAMPLE for gfx1200") - dim, dim_names = inst.dim, ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array'] - dim_str = dim_names[dim] if dim < len(dim_names) else f"dim_{dim}" - vdata = 4 if 'gather4' in name else (bin(inst.dmask).count('1') or 1) - if inst.d16: vdata = (vdata + 1) // 2 - if inst.tfe: vdata += 1 - vaddr_count = _mimg_vaddr_width(name, dim, inst.a16) - if vaddr_count > 4: raise ValueError(f"{name} with dim={dim} needs {vaddr_count} vaddrs (>4, unsupported)") - vaddr_str = _fmt_vaddr_nsa(_collect_vaddrs(inst, vaddr_count)) - srsrc_str, ssamp_str = _sreg_or_ttmp(inst.rsrc, 8), _sreg_or_ttmp(inst.samp, 4) - mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else [] - mods.append(f"dim:SQ_RSRC_IMG_{dim_str.upper()}") - for flag, mod in [(inst.unrm, "unorm"), (inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16")]: - if flag: mods.append(mod) - th_val, scope_val = inst.th, inst.scope - if th_val == 3 and scope_val == 3: raise ValueError("invalid th/scope: TH_LOAD_LU with SCOPE_SYS") - if scope_val == 2 and th_val == 0: raise ValueError("invalid scope SCOPE_SA without th") - if inst.tfe and inst.d16 and th_val != 0: raise ValueError("invalid th with tfe+d16") - if (th_name := _TH_LOAD.get(th_val)): mods.append(f"th:{th_name}") - if (scope_name := _SCOPE.get(scope_val)): mods.append(f"scope:{scope_name}") - return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}, {ssamp_str} {' '.join(mods)}" - -def _disasm_vimage(inst) -> str: - name = inst.op_name.lower() - if 'bvh' in name: raise ValueError(f"BVH instruction {name} not supported") - dim, dim_names = inst.dim, ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array'] - dim_str = dim_names[dim] if dim < len(dim_names) else f"dim_{dim}" - is_resinfo, is_atomic, is_store = 'resinfo' in name, 'atomic' in name, 'store' in name - if is_atomic: vdata = (2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1) - else: vdata = 4 if 'msaa_load' in name else (bin(inst.dmask).count('1') or 1) - if inst.d16: vdata = (vdata + 1) // 2 - if inst.tfe: vdata += 1 - if is_resinfo: vaddr_count = 1 - else: - base_count = [1, 2, 3, 3, 2, 3, 3, 4][dim] if dim < 8 else 1 - total_coords = base_count + (1 if '_mip' in name else 0) - vaddr_count = (total_coords + 1) // 2 if inst.a16 else total_coords - vaddr_str = _fmt_vaddr_nsa(_collect_vaddrs(inst, vaddr_count)) - srsrc_str = _sreg_or_ttmp(inst.rsrc, 8) - mods = [f"dmask:0x3" if 'cmpswap' in name else f"dmask:0x1"] if is_atomic else [f"dmask:0x{inst.dmask:x}"] - mods.append(f"dim:SQ_RSRC_IMG_{dim_str.upper()}") - for flag, mod in [(inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.d16, "d16")]: - if flag: mods.append(mod) - th_val, scope_val = inst.th, inst.scope - if th_val == 3 and scope_val == 3 and not is_atomic: raise ValueError("invalid th/scope: TH_LOAD_LU with SCOPE_SYS") - if is_atomic and th_val > 2: raise ValueError(f"invalid th value {th_val} for atomic") - if is_store and th_val == 3: raise ValueError("invalid TH_STORE_LU for store") - if scope_val == 2 and th_val == 0: raise ValueError("invalid SCOPE_SA without th") - if inst.tfe and inst.d16 and th_val != 0: raise ValueError("invalid th with tfe+d16") - th_table = _TH_ATOMIC if is_atomic else (_TH_STORE if is_store else _TH_LOAD) - if (th_name := th_table.get(th_val)): mods.append(f"th:{th_name}") - if (scope_name := _SCOPE.get(scope_val)): mods.append(f"scope:{scope_name}") - return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str} {' '.join(mods)}" +def _sop_widths(name: str) -> tuple[int, int, int]: + """Return (dst_width, src0_width, src1_width) in register count for SOP instructions.""" + if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1 + if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1 + if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1 + if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1 + if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz + return 1, 1, 1 def _disasm_sop1(inst: SOP1) -> str: - op, name = inst.op, inst.op_name.lower() - if not name: raise ValueError(f"undefined SOP1 op: {inst.op}") - if _has(name, 'alloc_vgpr', 'sleep_var', 'barrier_signal', 'barrier_wait', 'wakeup_barrier'): - return f"{name} {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}" + op, name = SOP1Op(inst.op), SOP1Op(inst.op).name.lower() if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}" if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}" - if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" - return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}" + if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" + dn, s0n, _ = _sop_widths(name) + return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if s0n == 1 else _fmt_src(inst.ssrc0, s0n)}" def _disasm_sop2(inst: SOP2) -> str: - return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}" + name = SOP2Op(inst.op).name.lower() + dn, s0n, s1n = _sop_widths(name) + return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)}" def _disasm_sopc(inst: SOPC) -> str: - return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}" + name = SOPCOp(inst.op).name.lower() + _, s0n, s1n = _sop_widths(name) + return f"{name} {_fmt_src(inst.ssrc0, s0n)}, {_fmt_src(inst.ssrc1, s1n)}" def _disasm_sopk(inst: SOPK) -> str: - op, name = inst.op, inst.op_name.lower() - if not name: raise ValueError(f"undefined SOPK op: {inst.op}") - if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" - if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32): + # Use architecture-specific SOPK op enum + if 'rdna4' in inst.__class__.__module__: + from extra.assembly.amd.autogen.rdna4.enum import SOPKOp as OpEnum + hwreg_map = HWREG_GFX12 + else: + OpEnum = SOPKOp + hwreg_map = HWREG + op, name = OpEnum(inst.op), OpEnum(inst.op).name.lower() + if name == 's_version': return f"{name} 0x{inst.simm16:x}" + if name in ('s_setreg_b32', 's_getreg_b32'): hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 - is_rdna4 = 'rdna4' in inst.__class__.__module__ - hwreg_map = HWREG_RDNA4 if is_rdna4 else HWREG - if hid in (16, 17) or (is_rdna4 and hid not in hwreg_map): hs = f"0x{inst.simm16:x}" - else: hs = f"hwreg({hwreg_map.get(hid, str(hid))}, {hoff}, {hsz})" - return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" - return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}" + hreg_name = hwreg_map.get(hid, str(hid)) + # If offset=0 and size=32, use short form hwreg(NAME), otherwise hwreg(NAME, off, sz) + if hid in (16, 17): hs = f"0x{inst.simm16:x}" + elif hoff == 0 and hsz == 32: hs = f"hwreg({hreg_name})" + else: hs = f"hwreg({hreg_name}, {hoff}, {hsz})" + return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if name == 's_setreg_b32' else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" + dn, _, _ = _sop_widths(name) + return f"{name} {_fmt_sdst(inst.sdst, dn)}, 0x{inst.simm16:x}" def _disasm_vinterp(inst: VINTERP) -> str: + name = VINTERPOp(inst.op).name.lower() + src0 = f"-{inst.lit(inst.src0)}" if inst.neg & 1 else inst.lit(inst.src0) + src1 = f"-{inst.lit(inst.src1)}" if inst.neg & 2 else inst.lit(inst.src1) + src2 = f"-{inst.lit(inst.src2)}" if inst.neg & 4 else inst.lit(inst.src2) mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp")) - return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "") + return f"{name} v{inst.vdst}, {src0}, {src1}, {src2}" + (" " + mods if mods else "") -def _disasm_ldsdir(inst) -> str: - wait = f" wait_vdst:{inst.wait_va}" if inst.wait_va != 0 else "" - if inst.op == 1: return f"lds_direct_load v{inst.vdst}{wait}" - if inst.op == 0: return f"lds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait}" - raise ValueError(f"unknown LDSDIR op: {inst.op}") +# Export targets: mrt0-7, mrtz, pos0-4, prim, dual_src_blend0/1 +_EXP_TARGETS = {**{i: f'mrt{i}' for i in range(8)}, 8: 'mrtz', **{i+12: f'pos{i}' for i in range(5)}, 20: 'prim', 21: 'dual_src_blend0', 22: 'dual_src_blend1'} -def _disasm_vdsdir(inst) -> str: - wait_va = f" wait_va_vdst:{inst.wait_va}" if inst.wait_va != 0 else "" - wait_vm = f" wait_vm_vsrc:{inst.wait_vm}" if inst.wait_vm != 0 else "" - if inst.op == 1: return f"ds_direct_load v{inst.vdst}{wait_va}{wait_vm}" - if inst.op == 0: return f"ds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait_va}{wait_vm}" - raise ValueError(f"unknown VDSDIR op: {inst.op}") - -def _disasm_vexport(inst) -> str: +def _disasm_exp(inst) -> str: target = _EXP_TARGETS.get(inst.target, f"invalid_target_{inst.target}") en = inst.en vsrc = lambda i, v: f"v{v}" if (en >> i) & 1 else "off" @@ -525,80 +579,239 @@ def _disasm_vexport(inst) -> str: prefix = "export" if 'rdna4' in inst.__class__.__module__ else "exp" return f"{prefix} {target} {srcs}" + (" " + mods if mods else "") -def _disasm_vbuffer(inst) -> str: - name = inst.op_name.lower() - suffix = name.split('_')[-1] - base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'b8':1,'x':1,'xy':2,'xyz':3,'xyzw':4,'u32':1,'u64':2,'i32':1,'i64':2,'f32':1,'f64':2,'f16':1,'bf16':1}.get(suffix, 1) - w = (base_w + 1) // 2 if 'd16' in name else base_w - if 'cmpswap' in name: w *= 2 - if inst.tfe: w += 1 - vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off" - rsrc, soffset = _sreg_or_ttmp(inst.rsrc, 4), decode_src(inst.soffset) - th_load = ['','TH_LOAD_RT','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS'] - th_store = ['','TH_STORE_RT','TH_STORE_NT','TH_STORE_HT','','TH_STORE_NT_RT','TH_STORE_NT_HT','TH_STORE_BYPASS'] - th_atomic = ['','TH_ATOMIC_NT','','','','TH_ATOMIC_RETURN','TH_ATOMIC_RT_RETURN','TH_ATOMIC_CASCADE_NT'] - scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS'] - is_atomic, is_store = 'atomic' in name, 'store' in name and 'atomic' not in name - th_names = th_atomic if is_atomic else th_store if is_store else th_load - mods = [] - if inst.idxen: mods.append("idxen") - if inst.offen: mods.append("offen") - if inst.ioffset: mods.append(f"offset:{inst.ioffset}") - if inst.th and inst.th < len(th_names) and th_names[inst.th]: mods.append(f"th:{th_names[inst.th]}") - if inst.scope and inst.scope < len(scope_names) and scope_names[inst.scope]: mods.append(f"scope:{scope_names[inst.scope]}") - if inst.tfe: mods.append("tfe") - return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {rsrc}, {soffset}" + (" " + " ".join(mods) if mods else "") +def _disasm_ldsdir(inst) -> str: + is_rdna4 = 'rdna4' in inst.__class__.__module__ + if is_rdna4: + # RDNA4 uses ds_* prefix and wait_va_vdst/wait_vm_vsrc modifiers + wait = f" wait_va_vdst:{inst.wait_va} wait_vm_vsrc:{inst.wait_vm}" + if inst.op == 1: return f"ds_direct_load v{inst.vdst}{wait}" + if inst.op == 0: return f"ds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait}" + else: + # RDNA3 uses lds_* prefix and wait_vdst modifier + wait = f" wait_vdst:{inst.wait_va}" if inst.wait_va != 0 else "" + if inst.op == 1: return f"lds_direct_load v{inst.vdst}{wait}" + if inst.op == 0: return f"lds_param_load v{inst.vdst}, attr{inst.attr}.{['x','y','z','w'][inst.attr_chan]}{wait}" + raise ValueError(f"unknown LDSDIR op: {inst.op}") + +# ═══════════════════════════════════════════════════════════════════════════════ +# RDNA4-specific disassemblers (GFX12) +# ═══════════════════════════════════════════════════════════════════════════════ + +# th values for RDNA4 memory instructions (based on AMDGPU ISA docs and LLVM SIDefines.h) +# Load: 0=RT(default), 1=NT, 2=HT, 3=LU, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=BYPASS(only with scope=SYS) +_TH_LOAD = {0: '', 1: 'th:TH_LOAD_NT', 2: 'th:TH_LOAD_HT', 3: 'th:TH_LOAD_LU', 4: 'th:TH_LOAD_NT_RT', 5: 'th:TH_LOAD_RT_NT', 6: 'th:TH_LOAD_NT_HT', 7: 'th:TH_LOAD_BYPASS'} +# Store: 0=RT(default), 1=NT, 2=HT, 3=WB, 4=NT_RT, 5=RT_NT, 6=NT_HT, 7=NT_WB (BYPASS is th=7 + scope=SYS) +_TH_STORE = {0: '', 1: 'th:TH_STORE_NT', 2: 'th:TH_STORE_HT', 3: 'th:TH_STORE_WB', 4: 'th:TH_STORE_NT_RT', 5: 'th:TH_STORE_RT_NT', 6: 'th:TH_STORE_NT_HT', 7: 'th:TH_STORE_NT_WB'} +# Atomic: bit0=RETURN, bit1=NT, bit2=CASCADE -> 0=none, 1=RETURN, 2=NT, 3=NT_RETURN, 4=CASCADE_RT, 5=RT_RETURN(N/A), 6=CASCADE_NT, 7=N/A +_TH_ATOMIC = {0: '', 1: 'th:TH_ATOMIC_RETURN', 2: 'th:TH_ATOMIC_NT', 3: 'th:TH_ATOMIC_NT_RETURN', 4: 'th:TH_ATOMIC_CASCADE_RT', 5: 'th:TH_ATOMIC_RT_RETURN', 6: 'th:TH_ATOMIC_CASCADE_NT', 7: 'th:TH_ATOMIC_CASCADE_NT'} +_SCOPE = {0: '', 1: 'scope:SCOPE_SE', 2: 'scope:SCOPE_DEV', 3: 'scope:SCOPE_SYS'} + +def _rdna4_mem_mods(th: int, scope: int, is_store: bool, is_atomic: bool) -> str: + th_map = _TH_ATOMIC if is_atomic else _TH_STORE if is_store else _TH_LOAD + # Special case: th=3 with scope=SYS means BYPASS for load/store (otherwise th=3 means LU/WB) + if th == 3 and scope == 3 and not is_atomic: + th_s = 'th:TH_STORE_BYPASS' if is_store else 'th:TH_LOAD_BYPASS' + else: + th_s = th_map.get(th, f'th:{th}' if th else '') + scope_s = _SCOPE.get(scope, f'scope:{scope}' if scope else '') + return ' '.join(x for x in [th_s, scope_s] if x) def _disasm_vflat(inst) -> str: - name = inst.op_name.lower() + """Disassemble RDNA4 VFLAT/VGLOBAL/VSCRATCH instructions.""" + from extra.assembly.amd.autogen.rdna4.enum import VFLATOp, VGLOBALOp, VSCRATCHOp cls_name = type(inst).__name__ - seg = 'flat' if cls_name == 'VFLAT' else 'global' if cls_name == 'VGLOBAL' else 'scratch' - parts = name.split('_', 1) - instr = f"{seg}_{parts[1]}" if len(parts) > 1 else name - suffix = name.split('_')[-1] - w = {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'b8':1,'u32':1,'u64':2,'i32':1,'i64':2,'f32':1,'f64':2,'f16':1,'bf16':1}.get(suffix, 1) - if 'cmpswap' in name: w *= 2 - off_val = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000 - if seg == 'flat': saddr_s, addr_width = "", 2 - elif inst.saddr == 0x7F or (hasattr(inst, 'sve') and inst.sve == 0 and seg == 'scratch'): saddr_s, addr_width = ", off", 2 - elif inst.saddr == 124: saddr_s, addr_width = ", off", 2 - else: saddr_s, addr_width = f", {_fmt_src(inst.saddr, 2) if inst.saddr <= 105 else decode_src(inst.saddr)}", 1 - vaddr = f"v{inst.vaddr}" if addr_width == 1 else _vreg(inst.vaddr, 2) - th_load = ['','TH_LOAD_RT','TH_LOAD_NT','TH_LOAD_HT','TH_LOAD_LU','TH_LOAD_NT_RT','TH_LOAD_NT_HT','TH_LOAD_BYPASS'] - th_store = ['','TH_STORE_RT','TH_STORE_NT','TH_STORE_HT','','TH_STORE_NT_RT','TH_STORE_NT_HT','TH_STORE_BYPASS'] - th_atomic = ['','TH_ATOMIC_NT','','','','TH_ATOMIC_RETURN','TH_ATOMIC_RT_RETURN','TH_ATOMIC_CASCADE_NT'] - scope_names = ['','SCOPE_SE','SCOPE_DEV','SCOPE_SYS'] - is_atomic, is_store = 'atomic' in name, 'store' in name and 'atomic' not in name - th_names = th_atomic if is_atomic else th_store if is_store else th_load - mods = [] - if off_val: mods.append(f"offset:{off_val}") - if inst.th and inst.th < len(th_names) and th_names[inst.th]: mods.append(f"th:{th_names[inst.th]}") - if inst.scope and inst.scope < len(scope_names) and scope_names[inst.scope]: mods.append(f"scope:{scope_names[inst.scope]}") - mod_str = " " + " ".join(mods) if mods else "" - if 'store' in name and 'atomic' not in name: return f"{instr} {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}" - if 'atomic' in name: - if inst.th and inst.th >= 5: return f"{instr} {_vreg(inst.vdst, w)}, {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}" - return f"{instr} {vaddr}, {_vreg(inst.vsrc, w)}{saddr_s}{mod_str}" - return f"{instr} {_vreg(inst.vdst, w)}, {vaddr}{saddr_s}{mod_str}" + if cls_name == 'VGLOBAL': op_enum, prefix = VGLOBALOp, 'global' + elif cls_name == 'VSCRATCH': op_enum, prefix = VSCRATCHOp, 'scratch' + else: op_enum, prefix = VFLATOp, 'flat' + name = op_enum(inst.op).name.lower() + + # Data width based on instruction name suffix + suffix = name.split('_')[-1] + base_w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1) + # For cmpswap: vsrc holds cmp+data pairs (2x base), vdst is base width + vsrc_w = base_w * 2 if 'cmpswap' in name else base_w + vdst_w = base_w + + # Offset: signed 24-bit (stored as unsigned, needs sign extension) + off = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000 + off_s = f" offset:{off}" if off else "" + + # Memory modifiers + is_store, is_atomic = 'store' in name, 'atomic' in name + mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic) + + # saddr handling - VGLOBAL needs explicit "off" when saddr=124 + if inst.saddr == 124: saddr_s = ", off" if prefix == 'global' else "" + elif inst.saddr in SPECIAL_PAIRS: saddr_s = f", {SPECIAL_PAIRS[inst.saddr]}" + else: saddr_s = f", {_sreg(inst.saddr, 2)}" + + # Address width: 1 for scratch with saddr, 2 otherwise + addr_w = 1 if (prefix == 'scratch' or (inst.saddr != 124 and prefix != 'flat')) else 2 + vaddr_s = _vreg(inst.vaddr, addr_w) + vsrc_s = _vreg(inst.vsrc, vsrc_w) + vdst_s = _vreg(inst.vdst, vdst_w) + + if is_atomic: + if inst.th == 1: # TH_ATOMIC_RETURN + return f"{name} {vdst_s}, {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "") + return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "") + if is_store: return f"{name} {vaddr_s}, {vsrc_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "") + return f"{name} {vdst_s}, {vaddr_s}{saddr_s}{off_s}" + (f" {mods}" if mods else "") + +def _disasm_vbuffer(inst) -> str: + """Disassemble RDNA4 VBUFFER instructions.""" + from extra.assembly.amd.autogen.rdna4.enum import VBUFFEROp + name = VBUFFEROp(inst.op).name.lower() + + # Determine if this is a typed buffer instruction (MTBUF format) + is_format = 'format' in name + + # Data width based on instruction name + if is_format: + w = {'x': 1, 'xy': 2, 'xyz': 3, 'xyzw': 4}.get(name.split('_')[-1], 1) + if 'd16' in name: w = (w + 1) // 2 + else: + suffix = name.split('_')[-1] + w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'b8':1,'b16':1,'f16':1,'f32':1,'bf16':1}.get(suffix, 1) + if 'cmpswap' in name: w *= 2 + if inst.tfe: w += 1 + + vdata_s = _vreg(inst.vdata, w) + vaddr_s = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else (f"v{inst.vaddr}" if inst.offen or inst.idxen else "off") + # RDNA4 VBUFFER rsrc field stores the SGPR index directly (not /4 like RDNA3) + srsrc_s = _sreg_or_ttmp(inst.rsrc, 4) + soffset_s = decode_src(inst.soffset) + + off = inst.ioffset if inst.ioffset < 0x800000 else inst.ioffset - 0x1000000 + is_store, is_atomic = 'store' in name, 'atomic' in name + mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic) + + # Format field is only for MTBUF (tbuffer_*) instructions, not buffer_*_format_* instructions + # We don't output format for buffer instructions since they use implicit format + parts = [] + if inst.idxen: parts.append("idxen") + if inst.offen: parts.append("offen") + if off: parts.append(f"offset:{off}") + if mods: parts.append(mods) + if inst.tfe: parts.append("tfe") + + return f"{name} {vdata_s}, {vaddr_s}, {srsrc_s}, {soffset_s}" + (f" {' '.join(parts)}" if parts else "") + +def _disasm_vimage(inst) -> str: + """Disassemble RDNA4 VIMAGE instructions.""" + from extra.assembly.amd.autogen.rdna4.enum import VIMAGEOp + name = VIMAGEOp(inst.op).name.lower() + + # RDNA4 VIMAGE rsrc field stores the SGPR index directly (not /4 like RDNA3) + if 'bvh' in name: + # BVH intersect ray: special format with individual/range vaddr components + # Format: [node_ptr, ray_extent, ray_origin(3), ray_dir(3), ray_inv_dir(3)] + # bvh64 has 2-VGPR node_ptr, a16 removes ray_inv_dir + if 'dual' in name or 'bvh8' in name: + # dual/bvh8: [node_ptr(2), ray_extent(2), ray_origin(3), ray_dir(3), ...] + parts = [_vreg(inst.vaddr0, 2), _vreg(inst.vaddr1, 2), _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)] + if not inst.a16: parts.append(_vreg(inst.vaddr4, 1 if 'bvh8' in name else 2)) + dst_w = 10 + elif '64' in name: + parts = [_vreg(inst.vaddr0, 2), f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)] + if not inst.a16: parts.append(_vreg(inst.vaddr4, 3)) + dst_w = 4 + else: + parts = [f"v{inst.vaddr0}", f"v{inst.vaddr1}", _vreg(inst.vaddr2, 3), _vreg(inst.vaddr3, 3)] + if not inst.a16: parts.append(_vreg(inst.vaddr4, 3)) + dst_w = 4 + return f"{name} {_vreg(inst.vdata, dst_w)}, [{', '.join(parts)}], {_sreg_or_ttmp(inst.rsrc, 4)}{' a16' if inst.a16 else ''}" + + # vdata width - msaa_load always uses 4 VGPRs per channel + vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1) + if inst.d16: vdata = (vdata + 1) // 2 + if inst.tfe: vdata += 1 + + # dim names + dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array'] + dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}" + + # vaddr width calculation (RDNA4 uses vaddr0-4 for address components) + vaddr_w = _mimg_vaddr_width(name, inst.dim, inst.a16) + if vaddr_w == 1: vaddr_s = f"v{inst.vaddr0}" + elif vaddr_w == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]" + elif vaddr_w == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]" + elif vaddr_w == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]" + else: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}, v{inst.vaddr4}]" + + srsrc_s = _sreg_or_ttmp(inst.rsrc, 8) + # RDNA4 always requires dmask for size calculation (even if 0xf) + mods = [f"dmask:0x{inst.dmask:x}"] + mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}") + # Add th/scope before other modifiers, then r128, then a16/tfe/d16 (LLVM expects this order) + if inst.th or inst.scope: + is_store, is_atomic = 'store' in name, 'atomic' in name + mem_mods = _rdna4_mem_mods(inst.th, inst.scope, is_store, is_atomic) + if mem_mods: mods.append(mem_mods) + if inst.r128: mods.append("r128") + mods.extend([m for c, m in [(inst.a16, "a16"), (inst.tfe, "tfe"), (inst.d16, "d16")] if c]) + + return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s} {' '.join(mods)}" + +def _disasm_vsample(inst) -> str: + """Disassemble RDNA4 VSAMPLE instructions.""" + from extra.assembly.amd.autogen.rdna4.enum import VSAMPLEOp + name = VSAMPLEOp(inst.op).name.lower() + + # vdata width + vdata = 4 if 'gather4' in name or 'msaa_load' in name else (bin(inst.dmask).count('1') or 1) + if inst.d16: vdata = (vdata + 1) // 2 + if inst.tfe: vdata += 1 + + dim_names = ['1d', '2d', '3d', 'cube', '1d_array', '2d_array', '2d_msaa', '2d_msaa_array'] + dim = dim_names[inst.dim] if inst.dim < len(dim_names) else f"{inst.dim}" + + vaddr = _mimg_vaddr_width(name, inst.dim, inst.a16) + if vaddr == 1: vaddr_s = f"v{inst.vaddr0}" + elif vaddr == 2: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}]" + elif vaddr == 3: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}]" + elif vaddr == 4: vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, v{inst.vaddr3}]" + else: + # More than 4 vaddrs: vaddr3 becomes start of a contiguous range for the remaining coords + extra = vaddr - 3 # vaddr0-2 are individual, vaddr3 starts range of remaining + vaddr_s = f"[v{inst.vaddr0}, v{inst.vaddr1}, v{inst.vaddr2}, {_vreg(inst.vaddr3, extra)}]" + + # RDNA4 VSAMPLE rsrc/samp fields store the SGPR index directly (not /4 like RDNA3) + srsrc_s = _sreg_or_ttmp(inst.rsrc, 8) + + # msaa_load doesn't use a sampler (it's a load, not a sample), but gather4h does + uses_sampler = 'msaa_load' not in name + ssamp_s = f", {_sreg_or_ttmp(inst.samp, 4)}" if uses_sampler else "" + + mods = [f"dmask:0x{inst.dmask:x}"] if inst.dmask else [] + mods.append(f"dim:SQ_RSRC_IMG_{dim.upper()}") + if inst.unrm: mods.append("unorm") + # th/scope must come before r128, a16, tfe, lwe, d16 + if inst.th or inst.scope: + mem_mods = _rdna4_mem_mods(inst.th, inst.scope, False, False) + if mem_mods: mods.append(mem_mods) + mods.extend([m for c, m in [(inst.r128, "r128"), (inst.a16, "a16"), (inst.tfe, "tfe"), (inst.lwe, "lwe"), (inst.d16, "d16")] if c]) + + return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_s}, {srsrc_s}{ssamp_s} {' '.join(mods)}" -# Handler mappings DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p, VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf, - MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk} + MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk, EXP: _disasm_exp, LDSDIR: _disasm_ldsdir} -_DISASM_BY_NAME = { - 'VOP1': _disasm_vop1, 'VOP2': _disasm_vop2, 'VOPC': _disasm_vopc, 'VOP3': _disasm_vop3, 'VOP3SD': _disasm_vop3sd, - 'VOPD': _disasm_vopd, 'VOP3P': _disasm_vop3p, 'VINTERP': _disasm_vinterp, 'SOPP': _disasm_sopp, 'SMEM': _disasm_smem, - 'DS': _disasm_ds, 'VDS': _disasm_ds, 'FLAT': _disasm_flat, 'MUBUF': _disasm_buf, 'MTBUF': _disasm_buf, 'MIMG': _disasm_mimg, - 'SOP1': _disasm_sop1, 'SOP2': _disasm_sop2, 'SOPC': _disasm_sopc, 'SOPK': _disasm_sopk, - 'VEXPORT': _disasm_vexport, 'EXP': _disasm_vexport, 'LDSDIR': _disasm_ldsdir, 'VDSDIR': _disasm_vdsdir, - 'VBUFFER': _disasm_vbuffer, 'VFLAT': _disasm_vflat, 'VGLOBAL': _disasm_vflat, 'VSCRATCH': _disasm_vflat, - 'VSAMPLE': _disasm_vsample, 'VIMAGE': _disasm_vimage, -} +# RDNA4 uses different class names, dispatch by name for cross-arch support +_DISASM_BY_NAME = {'VOP1': _disasm_vop1, 'VOP2': _disasm_vop2, 'VOPC': _disasm_vopc, 'VOP3': _disasm_vop3, 'VOP3SD': _disasm_vop3sd, + 'VOPD': _disasm_vopd, 'VOP3P': _disasm_vop3p, 'VINTERP': _disasm_vinterp, 'SOPP': _disasm_sopp, 'SMEM': _disasm_smem, + 'DS': _disasm_ds, 'VDS': _disasm_ds, 'FLAT': _disasm_flat, 'MUBUF': _disasm_buf, 'MTBUF': _disasm_buf, 'MIMG': _disasm_mimg, + 'SOP1': _disasm_sop1, 'SOP2': _disasm_sop2, 'SOPC': _disasm_sopc, 'SOPK': _disasm_sopk, + 'EXP': _disasm_exp, 'LDSDIR': _disasm_ldsdir, 'VEXPORT': _disasm_exp, 'VDSDIR': _disasm_ldsdir, + 'VFLAT': _disasm_vflat, 'VGLOBAL': _disasm_vflat, 'VSCRATCH': _disasm_vflat, + 'VBUFFER': _disasm_vbuffer, 'VIMAGE': _disasm_vimage, 'VSAMPLE': _disasm_vsample} -def disasm(inst: Inst) -> str: +def disasm(inst: Inst, wave_size: int = 32) -> str: handler = DISASM_HANDLERS.get(type(inst)) or _DISASM_BY_NAME.get(type(inst).__name__) - if handler is None: raise KeyError(f"No disasm handler for {type(inst).__name__}") + if handler is None: raise KeyError(f"no disasm handler for {type(inst).__name__}") + # For VOP3P (includes WMMA), pass wave_size if handler supports it + if handler == _disasm_vop3p: return _disasm_vop3p(inst, wave_size) return handler(inst) # ═══════════════════════════════════════════════════════════════════════════════ @@ -607,7 +820,7 @@ def disasm(inst: Inst) -> str: SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)} -FLOATS = {str(k): k for k in FLOAT_ENC} +FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc. REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp} SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512', 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'} @@ -649,6 +862,7 @@ def _extract(text: str, pat: str, flags=re.I): def get_dsl(text: str) -> str: text, kw = text.strip(), [] + # Extract modifiers for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]: if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: kw.append('clmp=1'); text = m[1] @@ -666,11 +880,13 @@ def get_dsl(text: str) -> str: m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None if waitexp: kw.append(f'waitexp={waitexp}') + parts = text.replace(',', ' ').split() if not parts: raise ValueError("empty instruction") mn, op_str = parts[0].lower(), text[len(parts[0]):].strip() ops, args = _parse_ops(op_str), [_op2dsl(o) for o in _parse_ops(op_str)] + # s_waitcnt if mn == 's_waitcnt': vm, exp, lgkm = 0x3f, 0x7, 0x3f for p in op_str.replace(',', ' ').split(): @@ -680,6 +896,7 @@ def get_dsl(text: str) -> str: elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})" return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})" + # VOPD if '::' in text: xp, yp = text.split('::') xps, yps = xp.strip().replace(',', ' ').split(), yp.strip().replace(',', ' ').split() @@ -691,12 +908,14 @@ def get_dsl(text: str) -> str: elif 'fmamk' in yps[0].lower() and len(yo) > 3: lit, vsy1 = yo[2], yo[3] return f"VOPD(VOPDOp.{xps[0].upper()}, VOPDOp.{yps[0].upper()}, vdstx={vdx}, vdsty={vdy}, srcx0={sx0}, vsrcx1={vsx1}, srcy0={sy0}, vsrcy1={vsy1}{f', literal={lit}' if lit else ''})" + # Special instructions if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}") if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})" if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))" if mn == 's_version': return f"{mn}(simm16={args[0]})" if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})" + # SMEM if mn in SMEM_OPS: gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else "" if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()): @@ -704,9 +923,11 @@ def get_dsl(text: str) -> str: if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})" if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})" + # Buffer if mn.startswith('buffer_') and len(ops) >= 2 and ops[1].strip().lower() == 'off': return f"{mn}(vdata={args[0]}, vaddr=0, srsrc={args[2]}, soffset={f'RawImm({args[3].strip()})' if len(args) > 3 else 'RawImm(0)'})" + # FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm(124) for off/null def _saddr(a): return 'RawImm(124)' if a in ('OFF', 'NULL') else a flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}" for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'), @@ -719,6 +940,7 @@ def get_dsl(text: str) -> str: if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})" if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})" + # DS instructions if mn.startswith('ds_'): off0, off1 = (str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)) if off_val else ("0", "0") gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else "" @@ -742,10 +964,12 @@ def get_dsl(text: str) -> str: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})" + # v_fmaak/v_fmamk literal extraction lit_s = "" if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3] elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]] + # VCC ops cleanup vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'} if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]] if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '') @@ -755,6 +979,7 @@ def get_dsl(text: str) -> str: fn = mn.replace('.', '_') if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args] + # v_fma_mix*: extract inline neg/abs modifiers if 'fma_mix' in mn and neg_lo is None and neg_hi is None: inline_neg, inline_abs, clean_args = 0, 0, [args[0]] for i, op in enumerate(ops[1:4]): diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index bc8dd5b8d4..2df46af183 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -497,9 +497,9 @@ class Inst: def __hash__(self): return hash((self.__class__.__name__, tuple(sorted((k, repr(v)) for k, v in self._values.items())), self._literal)) - def disasm(self) -> str: + def disasm(self, wave_size: int = 32) -> str: from extra.assembly.amd.asm import disasm - return disasm(self) + return disasm(self, wave_size) _enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp, 'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp, diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index 1685e9ec36..37b030275a 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -32,12 +32,14 @@ RDNA4_TEST_FILES = { 'vop3_from_vop1': 'gfx12_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx12_asm_vop3_from_vop2.s', 'ds': 'gfx12_asm_ds.s', 'ds_alias': 'gfx12_asm_ds_alias.s', 'smem': 'gfx12_asm_smem.s', 'vflat': 'gfx12_asm_vflat.s', 'vflat_alias': 'gfx12_asm_vflat_alias.s', + 'vscratch': 'gfx12_asm_vflat.s', 'vscratch_alias': 'gfx12_asm_vflat_alias.s', # scratch instructions in vflat files 'vbuffer_mubuf': 'gfx12_asm_vbuffer_mubuf.s', 'vbuffer_mubuf_alias': 'gfx12_asm_vbuffer_mubuf_alias.s', 'vbuffer_mtbuf': 'gfx12_asm_vbuffer_mtbuf.s', 'vbuffer_mtbuf_alias': 'gfx12_asm_vbuffer_mtbuf_alias.s', 'vimage': 'gfx12_asm_vimage.s', 'vimage_alias': 'gfx12_asm_vimage_alias.s', 'vsample': 'gfx12_asm_vsample.s', 'vdsdir': 'gfx12_asm_vdsdir.s', 'vdsdir_alias': 'gfx12_asm_vdsdir_alias.s', 'exp': 'gfx12_asm_exp.s', 'wmma_w32': 'gfx12_asm_wmma_w32.s', 'wmma_w64': 'gfx12_asm_wmma_w64.s', - 'features': 'gfx12_asm_features.s', 'global_load_tr': 'gfx12_asm_global_load_tr.s', + 'global_load_tr': 'gfx12_asm_global_load_tr.s', + # NOTE: 'features' (gfx12_asm_features.s) tests DPP instruction variants which require separate format decoders } def parse_llvm_tests(text: str, gfx_prefix: str) -> list[tuple[str, bytes]]: @@ -108,6 +110,10 @@ class TestLLVMBase(unittest.TestCase): fmt_cls = self.formats.get(name) if fmt_cls is None: self.skipTest(f"No format class for {name}") + # Determine wave size from test name (w64 = wave64, otherwise wave32) + wave_size = 64 if 'w64' in name else 32 + mattr = f'+real-true16,+wavefrontsize{wave_size}' + to_test: list[tuple[str, bytes, str | None, str | None]] = [] for asm_text, data in self.tests.get(name, []): if len(data) > fmt_cls._size(): continue @@ -117,14 +123,14 @@ class TestLLVMBase(unittest.TestCase): if decoded.to_bytes()[:len(data)] != data: to_test.append((asm_text, data, None, "decode roundtrip failed")) continue - to_test.append((asm_text, data, decoded.disasm(), None)) + to_test.append((asm_text, data, decoded.disasm(wave_size), None)) except Exception as e: to_test.append((asm_text, data, None, f"exception: {e}")) disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None] llvm_map = {} if disasm_strs: - llvm_results = compile_asm_batch([s for _, s in disasm_strs], self.mcpu) + llvm_results = compile_asm_batch([s for _, s in disasm_strs], self.mcpu, mattr) llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)} passed, failed, failures = 0, 0, [] @@ -180,9 +186,10 @@ class TestLLVMRDNA4(TestLLVMBase): 'vbuffer_mtbuf': get('VBUFFER'), 'vbuffer_mtbuf_alias': get('VBUFFER'), 'vdsdir': get('VDSDIR'), 'vdsdir_alias': get('VDSDIR'), 'vflat': get('VFLAT'), 'vflat_alias': get('VFLAT'), + 'vscratch': get('VSCRATCH'), 'vscratch_alias': get('VSCRATCH'), 'vimage': get('VIMAGE'), 'vimage_alias': get('VIMAGE'), 'vsample': get('VSAMPLE'), 'wmma_w32': get('VOP3P'), 'wmma_w64': get('VOP3P'), - 'features': None, 'global_load_tr': get('VGLOBAL'), + 'global_load_tr': get('VGLOBAL'), } cls._load_tests(RDNA4_TEST_FILES) diff --git a/extra/assembly/amd/test/test_roundtrip.py b/extra/assembly/amd/test/test_roundtrip.py index cc9ff9c66b..89e32f1260 100644 --- a/extra/assembly/amd/test/test_roundtrip.py +++ b/extra/assembly/amd/test/test_roundtrip.py @@ -1,90 +1,37 @@ #!/usr/bin/env python3 """Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match.""" import unittest, io, sys, re, subprocess, os -from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.dsl import Inst -from extra.assembly.amd.asm import asm -from extra.assembly.amd.asm import detect_format from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump -def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]: - """Disassemble ELF binary and return list of (instruction_text, machine_code_bytes).""" - old_stdout = sys.stdout - sys.stdout = io.StringIO() - compiler.disassemble(lib) - output = sys.stdout.getvalue() - sys.stdout = old_stdout - - results = [] - for line in output.splitlines(): - if '//' not in line: continue - instr = line.split('//')[0].strip() - if not instr: continue - comment = line.split('//')[1].strip() - if ':' not in comment: continue - hex_str = comment.split(':')[1].strip().split()[0] - try: - machine_bytes = bytes.fromhex(hex_str)[::-1] # big-endian to little-endian - results.append((instr, machine_bytes)) - except ValueError: - continue - return results - -def compile_asm(instr: str, compiler=None) -> bytes: - """Compile a single instruction with llvm-mc and return the machine code bytes.""" - llvm_mc = get_llvm_mc() - result = subprocess.run( - [llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], - input=f".text\n{instr}\n", capture_output=True, text=True) - if result.returncode != 0: raise RuntimeError(f"llvm-mc failed for '{instr}': {result.stderr.strip()}") - # Parse encoding: [0x01,0x39,0x0a,0x7e] - for line in result.stdout.split('\n'): - if 'encoding:' in line: - enc = line.split('encoding:')[1].strip() - if enc.startswith('[') and enc.endswith(']'): - hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '') - return bytes.fromhex(hex_vals) - raise RuntimeError(f"no encoding found in llvm-mc output for: {instr}") - -def compile_asm_batch(instrs: list[str]) -> list[bytes]: +def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]: """Compile multiple instructions with a single llvm-mc call.""" if not instrs: return [] - llvm_mc = get_llvm_mc() - src = ".text\n" + "\n".join(instrs) + "\n" - result = subprocess.run( - [llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], - input=src, capture_output=True, text=True) + result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], + input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}") - # Parse all encodings in order encodings = [] for line in result.stdout.split('\n'): if 'encoding:' in line: enc = line.split('encoding:')[1].strip() if enc.startswith('[') and enc.endswith(']'): - hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '') - encodings.append(bytes.fromhex(hex_vals)) + encodings.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))) if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}") return encodings -def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]: +def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[str]: """Compile instructions with LLVM and get LLVM's disassembly.""" - import tempfile, os + import tempfile if not instrs: return [] - # Build assembly source with all instructions - src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" - src += "\n".join(f" {instr}" for instr in instrs) + "\n" - # Use llvm-mc to assemble to object file + src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n" with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f: obj_path = f.name try: - result = subprocess.run( - [get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path], - input=src, capture_output=True, text=True) + result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path], + input=src, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}") - # Disassemble with llvm-objdump - result = subprocess.run([get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True) + result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}") - # Parse disassembly output results: list[str] = [] for line in result.stdout.splitlines(): if '//' not in line: continue @@ -94,127 +41,143 @@ def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]: finally: os.unlink(obj_path) -class TestTinygradKernelRoundtrip(unittest.TestCase): - """Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern.""" +class TestRoundtripBase(unittest.TestCase): + """Base class for roundtrip tests.""" + mcpu: str = 'gfx1100' + arch: str = 'rdna3' + + @classmethod + def _get_modules(cls): + if cls.arch == 'rdna3': + from extra.assembly.amd.autogen.rdna3 import ins + from extra.assembly.amd.asm import detect_format, asm + else: + import extra.assembly.amd.autogen.rdna4.ins as ins + from extra.assembly.amd.asm import asm + detect_format = None # RDNA4 uses different detection + return ins, detect_format, asm def _test_kernel_roundtrip(self, op_fn): - """Generate kernel from op_fn, test: - 1. decode -> reencode matches original bytes - 2. asm(disasm()) matches LLVM output - 3. our disasm() matches LLVM's disassembly string exactly - """ + """Generate kernel from op_fn, test decode -> reencode and asm(disasm()) matches LLVM.""" from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad from tinygrad.runtime.support.compiler_amd import HIPCompiler + ins, detect_format, asm = self._get_modules() kernels, _, _ = get_kernels_from_tinygrad(op_fn) - compiler = HIPCompiler('gfx1100') + compiler = HIPCompiler(self.mcpu) - # First pass: decode all instructions and collect info - decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) + # First pass: decode all instructions + decoded_instrs: list[tuple] = [] for ki, kernel in enumerate(kernels): offset = 0 while offset < len(kernel.code): remaining = kernel.code[offset:] - fmt = detect_format(remaining) - if fmt is None: - decoded_instrs.append((ki, offset, None, None, None, False, "no format")) - offset += 4 - continue + if len(remaining) < 4: break + + # Try to detect format + if detect_format is not None: + try: + fmt = detect_format(remaining) + except ValueError: + decoded_instrs.append((ki, offset, None, None, None, False, "no format")) + offset += 4 + continue + else: + # For RDNA4, try formats in order + fmt = None + from extra.assembly.amd.autogen.rdna4.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3P, VOPC, VOPD, VDS, SMEM, VFLAT, VBUFFER, VIMAGE, VSAMPLE, VEXPORT, VDSDIR + word = int.from_bytes(remaining[:4], 'little') + for cls in [VOPD, VOP3P, VOP3, VDS, VFLAT, VBUFFER, VIMAGE, VSAMPLE, SMEM, VEXPORT, SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2, VDSDIR]: + if cls._encoding is not None: + bf, val = cls._encoding + if ((word >> bf.lo) & bf.mask()) == val: + fmt = cls + break + if fmt is None: + decoded_instrs.append((ki, offset, None, None, None, False, "no format")) + offset += 4 + continue base_size = fmt._size() - if len(remaining) < base_size: - break + if len(remaining) < base_size: break try: - decoded = fmt.from_bytes(remaining) # pass all remaining bytes so from_bytes can read literal - size = decoded.size() # actual size including literal + decoded = fmt.from_bytes(remaining) + size = decoded.size() orig_bytes = remaining[:size] reencoded = decoded.to_bytes() our_disasm = decoded.disasm() decode_ok = reencoded == orig_bytes - decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}" + decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}" decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)) except Exception as e: decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e))) size = base_size - offset += size - # Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile - asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test - disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test - + # Collect disasm strings for batched LLVM calls + asm_test_instrs: list[tuple[int, str]] = [] for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs): if our_disasm is None: continue - # Skip unknown opcodes and malformed instructions for both tests if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue asm_test_instrs.append((idx, our_disasm)) - disasm_test_instrs.append((idx, our_disasm)) # Batch compile for asm test - asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs]) + asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], self.mcpu) asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)} # Batch compile+disasm for disasm comparison test - disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], compiler) - disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)} + disasm_llvm_results = compile_and_disasm_batch([d for _, d in asm_test_instrs], self.mcpu) + disasm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, disasm_llvm_results)} - # Now evaluate results + # Evaluate results decode_passed, decode_failed, decode_skipped = 0, 0, 0 asm_passed, asm_failed, asm_skipped = 0, 0, 0 disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0 - decode_failures: list[str] = [] - asm_failures: list[str] = [] - disasm_failures: list[str] = [] + decode_failures, asm_failures, disasm_failures = [], [], [] for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs): - # Decode test - if decode_ok: - decode_passed += 1 - elif decode_err == "no format": - decode_skipped += 1 + if decode_ok: decode_passed += 1 + elif decode_err == "no format": decode_skipped += 1 else: decode_failed += 1 decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}") - # Asm test if our_disasm is None: asm_skipped += 1 + disasm_skipped += 1 elif idx in asm_llvm_map: llvm_bytes = asm_llvm_map[idx] try: our_bytes = asm(our_disasm).to_bytes() - if our_bytes[:len(llvm_bytes)] == llvm_bytes: - asm_passed += 1 + if our_bytes[:len(llvm_bytes)] == llvm_bytes: asm_passed += 1 else: asm_failed += 1 asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}") except Exception: asm_skipped += 1 + + if idx in disasm_llvm_map: + if our_disasm == disasm_llvm_map[idx]: disasm_passed += 1 + else: + disasm_failed += 1 + disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{disasm_llvm_map[idx]}'") + else: + disasm_skipped += 1 else: asm_skipped += 1 - - # Disasm comparison test - if our_disasm is None: - disasm_skipped += 1 - elif idx in disasm_llvm_map: - llvm_disasm = disasm_llvm_map[idx] - if our_disasm == llvm_disasm: - disasm_passed += 1 - else: - disasm_failed += 1 - disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'") - else: disasm_skipped += 1 - print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped") - print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped") - print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped") + print(f"{self.arch.upper()} decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped") + print(f"{self.arch.upper()} asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped") + print(f"{self.arch.upper()} disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped") self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20])) self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20])) - # Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected - # Basic unary ops +class TestRoundtripRDNA3(TestRoundtripBase): + """Roundtrip tests for RDNA3 (gfx1100).""" + mcpu, arch = 'gfx1100', 'rdna3' + def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0])) def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu()) def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp()) @@ -222,42 +185,62 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin()) def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt()) def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal()) - - # Binary ops def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0])) def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0])) def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0])) def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0])) def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0]))) - - # Reductions def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum()) def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max()) def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean()) - - # Matmul def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4)) def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16)) - - # Complex ops def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax()) def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm()) - - # Memory patterns def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous()) def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous()) def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous()) - - # Cast ops def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float()) def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float()) - - # Comparison ops def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64))) def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64))) - - # Fused ops def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0]))) +@unittest.skipUnless(os.environ.get("TEST_RDNA4"), "RDNA4 roundtrip tests require TEST_RDNA4=1 and gfx1200 hardware") +class TestRoundtripRDNA4(TestRoundtripBase): + """Roundtrip tests for RDNA4 (gfx1200).""" + mcpu, arch = 'gfx1200', 'rdna4' + + def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0])) + def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu()) + def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp()) + def test_log(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 3.0]).log()) + def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin()) + def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt()) + def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal()) + def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0])) + def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0])) + def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0])) + def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0])) + def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0]))) + def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum()) + def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max()) + def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean()) + def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4)) + def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16)) + def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax()) + def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm()) + def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous()) + def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous()) + def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous()) + def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float()) + def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float()) + def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64))) + def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64))) + def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0]))) + +# Keep old class name for backwards compatibility +TestTinygradKernelRoundtrip = TestRoundtripRDNA3 + if __name__ == "__main__": unittest.main()