From e7b5d8a4349dbe650d24f6d14e7120c932842fe7 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 8 Jan 2026 05:09:37 -0800 Subject: [PATCH] assembly/amd: more RDNA4 asm (#14062) * rdna4 more * asm * fixes * assembly/amd: handwritten wmma failing test * passes * wmma default hacks * space * 0 skips in rdna3/rdna4 disasm * more RDNA4 tests --------- Co-authored-by: qazal --- extra/assembly/amd/asm.py | 460 +++++++++++++++++--- extra/assembly/amd/autogen/rdna4/enum.py | 7 + extra/assembly/amd/autogen/rdna4/ins.py | 7 + extra/assembly/amd/dsl.py | 45 +- extra/assembly/amd/pdf.py | 5 + extra/assembly/amd/test/test_handwritten.py | 3 + extra/assembly/amd/test/test_llvm.py | 19 +- 7 files changed, 470 insertions(+), 76 deletions(-) diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index 66e2ca11ba..9da9488e8b 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -11,7 +11,7 @@ from extra.assembly.amd.autogen.rdna3.enum import BufFmt from extra.assembly.amd.autogen.rdna4 import ins as rdna4_ins from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP2 as R4_VOP2, VOP3 as R4_VOP3, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P, VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOPC as R4_SOPC, SOPK as R4_SOPK, SOPP as R4_SOPP, - SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT) + SMEM as R4_SMEM, DS as R4_DS, VBUFFER as R4_VBUFFER, VEXPORT as R4_VEXPORT, VOPDOp as R4_VOPDOp) def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__ @@ -80,13 +80,25 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H 19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'} HWREG_IDS = {v.lower(): k for k, v in HWREG.items()} +# RDNA4 uses different hwreg names (WAVE_ prefix) +HWREG_RDNA4 = {1: 'HW_REG_WAVE_MODE', 2: 'HW_REG_WAVE_STATUS', 4: 'HW_REG_WAVE_STATE_PRIV', 5: 'HW_REG_WAVE_GPR_ALLOC', + 6: 'HW_REG_WAVE_LDS_ALLOC', 7: 'HW_REG_IB_STS', 10: 'HW_REG_PERF_SNAPSHOT_DATA', 11: 'HW_REG_PERF_SNAPSHOT_PC_LO', + 12: 'HW_REG_PERF_SNAPSHOT_PC_HI', 15: 'HW_REG_PERF_SNAPSHOT_DATA1', 16: 'HW_REG_PERF_SNAPSHOT_DATA2', + 17: 'HW_REG_WAVE_EXCP_FLAG_PRIV', 18: 'HW_REG_WAVE_EXCP_FLAG_USER', 19: 'HW_REG_WAVE_TRAP_CTRL', + 20: 'HW_REG_WAVE_SCRATCH_BASE_LO', 21: 'HW_REG_WAVE_SCRATCH_BASE_HI', 23: 'HW_REG_WAVE_HW_ID1', + 24: 'HW_REG_WAVE_HW_ID2', 26: 'HW_REG_WAVE_SCHED_MODE', 29: 'HW_REG_SHADER_CYCLES_LO', + 30: 'HW_REG_SHADER_CYCLES_HI', 31: 'HW_REG_WAVE_DVGPR_ALLOC_LO', 32: 'HW_REG_WAVE_DVGPR_ALLOC_HI'} # RDNA unified buffer format - extracted from PDF, use enum for name->value lookup BUF_FMT = {e.name: e.value for e in BufFmt} +# Extended format map for formats missing from enum (computed from observed patterns) +_BUF_FMT_EXT = {'BUF_FMT_32_32_32_32_SINT': 62, 'BUF_FMT_32_32_32_32_FLOAT': 63, 'BUF_FMT_8_FLOAT': 108} +BUF_FMT.update(_BUF_FMT_EXT) def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y] parts = [p.strip().replace('BUF_DATA_FORMAT_', '').replace('BUF_NUM_FORMAT_', '') for p in s.split(',')] return BUF_FMT.get(f'BUF_FMT_{parts[0]}_{parts[1]}') if len(parts) == 2 else None MSG = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA', - 131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'} + 131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA', + 134: 'MSG_RTN_GET_TBA_TO_PC', 135: 'MSG_RTN_GET_SE_AID_ID'} # ═══════════════════════════════════════════════════════════════════════════════ # HELPERS @@ -135,6 +147,7 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> if v == 255: s = inst.lit(v) # literal constant takes priority elif n > 1: s = _fmt_src(v, n) elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else f"v{v - 256}.l" + elif v == 253: s = "src_scc" # VOP3 sources use src_scc not scc else: s = inst.lit(v) if abs_: s = f"|{s}|" return f"-{s}" if neg else s @@ -162,12 +175,16 @@ def _disasm_vop1(inst: VOP1) -> str: # 16-bit dst: uses .h/.l suffix for RDNA (CDNA uses plain vN) parts = name.split('_') is_16d = not cdna and (any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)) - dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" - src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) + # v_cvt_pk_f32_fp8 and v_cvt_pk_f32_bf8 output to 2 VGPRs, and take 16-bit src + is_pk_fp8 = 'cvt_pk_f32_fp8' in name or 'cvt_pk_f32_bf8' in name + dregs = 2 if is_pk_fp8 else inst.dst_regs() + dst = _vreg(inst.vdst, dregs) if dregs > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" + src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0), cdna) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if not cdna and (inst.is_src_16(0) or is_pk_fp8) and 'sat_pk' not in name else inst.lit(inst.src0) return f"{name}{suf} {dst}, {src}" _VOP2_CARRY_OUT = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} # carry out only -_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out +_VOP2_CARRY_INOUT = {'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'} # carry in and out (CDNA) +_VOP2_CARRY_INOUT_RDNA = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'} # carry in and out (RDNA) def _disasm_vop2(inst: VOP2) -> str: name, cdna = inst.op_name.lower(), _is_cdna(inst) if cdna: name = _CDNA_DISASM_ALIASES.get(name, name) # apply CDNA aliases @@ -186,6 +203,15 @@ def _disasm_vop2(inst: VOP2) -> str: # CDNA carry ops output vcc after vdst if cdna and name in _VOP2_CARRY_OUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}" if cdna and name in _VOP2_CARRY_INOUT: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}" + # RDNA carry-in/out ops: v_add_co_ci_u32, etc. - format: vdst, vcc_lo, src0, vsrc1, vcc_lo + if not cdna and name in _VOP2_CARRY_INOUT_RDNA: return f"{name}{suf} v{inst.vdst}, {vcc}, {inst.lit(inst.src0)}, v{inst.vsrc1}, {vcc}" + # Handle 64-bit register operands (v_add_f64, v_mul_f64, etc.) + dn, sn0, sn1 = inst.dst_regs(), inst.src_regs(0), inst.src_regs(1) + if dn > 1 or sn0 > 1 or sn1 > 1: + dst = _vreg(inst.vdst, dn) if dn > 1 else f"v{inst.vdst}" + src0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, sn0, cdna) + src1 = _vreg(inst.vsrc1, sn1) if sn1 > 1 else f"v{inst.vsrc1}" + return f"{name} {dst}, {src0}, {src1}" return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "") def _disasm_vopc(inst: VOPC) -> str: @@ -241,19 +267,21 @@ def _disasm_smem(inst: SMEM) -> str: # soe=0, imm=1: offset is immediate # soe=0, imm=0: offset field is SGPR encoding (0-255) soe, imm = getattr(inst, 'soe', 0), getattr(inst, 'imm', 1) + is_rdna4 = 'rdna4' in inst.__class__.__module__ + offset = inst.ioffset if is_rdna4 else getattr(inst, 'offset', 0) # RDNA4 uses ioffset, others use offset if cdna: if soe and imm: - off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" # SGPR + immediate + off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{offset:x}" # SGPR + immediate elif imm: - off_s = f"0x{inst.offset:x}" # Immediate offset only - elif inst.offset < 256: - off_s = decode_src(inst.offset, cdna) # SGPR encoding in offset field + off_s = f"0x{offset:x}" # Immediate offset only + elif offset < 256: + off_s = decode_src(offset, cdna) # SGPR encoding in offset field else: off_s = decode_src(inst.soffset, cdna) - elif inst.offset and inst.soffset != 124: - off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{inst.offset:x}" - elif inst.offset: - off_s = f"0x{inst.offset:x}" + elif offset and inst.soffset != 124: + off_s = f"{decode_src(inst.soffset, cdna)} offset:0x{offset:x}" + elif offset: + off_s = f"0x{offset:x}" else: off_s = decode_src(inst.soffset, cdna) op_val = inst.op.value if hasattr(inst.op, 'value') else inst.op @@ -262,6 +290,23 @@ def _disasm_smem(inst: SMEM) -> str: sbase_idx, sbase_count = inst.sbase * 2, 4 if is_buffer else 2 sbase_str = _fmt_src(sbase_idx, sbase_count, cdna) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}" + # RDNA4 prefetch instructions: sbase, offset, soffset, sdata (cache hint) + if 'prefetch' in name: + off = getattr(inst, 'ioffset', inst.offset) # RDNA4 uses ioffset + # Handle 24-bit signed offset: 0xffffff = -1, 0x800000+ are negative + if off >= 0x800000: off = off - 0x1000000 # convert to signed + off_s = f"0x{off:x}" if off > 255 else str(off) + soff_s = decode_src(inst.soffset, cdna) if inst.soffset != 124 else "null" + if 'pc_rel' in name: + return f"{name} {off_s}, {soff_s}, {inst.sdata}" + return f"{name} {sbase_str}, {off_s}, {soff_s}, {inst.sdata}" + # RDNA4 uses th (temporal hint) and scope instead of glc/dlc + th, scope = getattr(inst, 'th', 0), getattr(inst, 'scope', 0) + if th or scope: + th_names = ['TH_LOAD_RT', 'TH_LOAD_NT', 'TH_LOAD_HT', 'TH_LOAD_LU'] + scope_names = ['SCOPE_CU', 'SCOPE_SE', 'SCOPE_DEV', 'SCOPE_SYS'] + mods = (f" th:{th_names[th]}" if th else "") + (f" scope:{scope_names[scope]}" if scope else "") + return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}{mods}" return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs(), cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc")) def _disasm_flat(inst: FLAT) -> str: @@ -310,6 +355,12 @@ def _disasm_ds(inst: DS) -> str: if op == DSOp.DS_NOP: return name if op == DSOp.DS_BVH_STACK_RTN_B32: return f"{name} v{inst.vdst}, {addr}, v{inst.data0}, {_vreg(inst.data1, 4)}{off}{gds}" + # RDNA4 BVH stack instructions: push4=4 VGPRs, push8=8 VGPRs for data1 + if 'bvh_stack_push' in name: + d1_regs = 8 if 'push8' in name else 4 + vdst_regs = 2 if 'pop2' in name else 1 + vdst_s = _vreg(inst.vdst, vdst_regs) if vdst_regs > 1 else f"v{inst.vdst}" + return f"{name} {vdst_s}, {addr}, v{inst.data0}, {_vreg(inst.data1, d1_regs)}{off}{gds}" if 'gws_sema' in name and op != DSOp.DS_GWS_SEMA_BR: return f"{name}{off}{gds}" if 'gws_' in name: return f"{name} {addr}{off}{gds}" if op in (DSOp.DS_CONSUME, DSOp.DS_APPEND): return f"{name} {rp}{inst.vdst}{off}{gds}" @@ -333,11 +384,19 @@ def _disasm_ds(inst: DS) -> str: def _disasm_vop3(inst: VOP3) -> str: op, name = inst.op, inst.op_name.lower() + # RDNA4 v_s_* scalar VOP3 instructions (output to SGPR, not VGPR) + if name.startswith('v_s_'): + src = inst.lit(inst.src0) if inst.src0 == 255 else ("src_scc" if inst.src0 == 253 else _fmt_src(inst.src0, inst.src_regs(0))) + if inst.neg & 1: src = f"-{src}" + if inst.abs & 1: src = f"|{src}|" + clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0) # RDNA4 uses 'cm', RDNA3 uses 'clmp' + return f"{name} s{inst.vdst}, {src}" + (" clamp" if clamp else "") + _omod(inst.omod) + # VOP3SD (shared encoding) if isinstance(op, VOP3SDOp): sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs def src(v, neg, n): - s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v)) + s = inst.lit(v) if v == 255 else ("src_scc" if v == 253 else (_fmt_src(v, n) if n > 1 else inst.lit(v))) return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" @@ -366,7 +425,8 @@ def _disasm_vop3(inst: VOP3) -> str: elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" else: dst = f"v{inst.vdst}" - cl, om = " clamp" if inst.clmp else "", _omod(inst.omod) + clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0) # RDNA4 uses 'cm', RDNA3 uses 'clmp' + cl, om = " clamp" if clamp else "", _omod(inst.omod) nonvgpr_opsel = (inst.src0 < 256 and (inst.opsel & 1)) or (inst.src1 < 256 and (inst.opsel & 2)) or (inst.src2 < 256 and (inst.opsel & 4)) need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s) @@ -377,26 +437,40 @@ def _disasm_vop3(inst: VOP3) -> str: os = _opsel_str(inst.opsel, n, need_opsel, is16_d) return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}" if inst.op < 512: # VOP1 - return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}" + # v_cvt_f32_bf8/fp8 use byte_sel instead of op_sel (opsel bits [1:0] map to byte_sel [0],[1] swapped) + if re.match(r'v_cvt_f32_(bf|fp)8', name) and inst.opsel: + os = f" byte_sel:{((inst.opsel & 1) << 1) | ((inst.opsel >> 1) & 1)}" + else: + os = _opsel_str(inst.opsel, 1, need_opsel, is16_d) + return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{os}{cl}{om}" # Native VOP3 n = inst.num_srcs() - os = _opsel_str(inst.opsel, n, need_opsel, is16_d) + # v_cvt_sr_*_f32 uses byte_sel instead of op_sel + if 'cvt_sr' in name and inst.opsel: + os = f" byte_sel:{inst.opsel >> 2}" + else: + os = _opsel_str(inst.opsel, n, need_opsel, is16_d) return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}" def _disasm_vop3sd(inst: VOP3SD) -> str: name = inst.op_name.lower() + # For 64-bit carry instructions (mad_co_*64*), src2 is a carry-in pair + src2_n = 2 if '_co_' in name and '64' in name else inst.src_regs(2) def src(v, neg, n): - s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v)) + s = inst.lit(v) if v == 255 else ("src_scc" if v == 253 else (_fmt_src(v, n) if n > 1 else inst.lit(v))) return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s) - s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) + s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, src2_n) dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" - suffix = "_e64" if name.startswith('v_') and 'co_' in name else "" - return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}" + # VOP3SD sdst: always single register for RDNA3/4 + clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0) + return f"{name} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if clamp else ''}{_omod(inst.omod)}" def _disasm_vopd(inst: VOPD) -> str: lit = inst._literal or inst.literal - vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower() + is_rdna4 = 'rdna4' in inst.__class__.__module__ + op_enum = R4_VOPDOp if is_rdna4 else VOPDOp + vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), op_enum(inst.opx).name.lower(), op_enum(inst.opy).name.lower() def half(n, vd, s0, vs1): if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}" # fmamk: dst = src0 * K + vsrc1, fmaak: dst = src0 * vsrc1 + K @@ -405,22 +479,49 @@ def _disasm_vopd(inst: VOPD) -> str: return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}" return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}" +def _swmmac_regs(name: str) -> tuple[int, int, int, int]: + """Return (dst, src0, src1, src2) register counts for SWMMAC instructions.""" + # v_swmmac_DTYPE_MxNxK_ATYPE[_BTYPE] + if 'f16_16x16x32' in name or 'bf16_16x16x32' in name: return (4, 4, 8, 1) # f16/bf16 output + if 'f32_16x16x32_f16' in name or 'f32_16x16x32_bf16' in name: return (8, 4, 8, 1) # f32 from f16/bf16 + if 'i32_16x16x32_iu4' in name: return (8, 1, 2, 1) + if 'i32_16x16x64_iu4' in name: return (8, 2, 4, 1) + if 'i32_16x16x32_iu8' in name or 'f32_16x16x32_fp8' in name or 'f32_16x16x32_bf8' in name: return (8, 2, 4, 1) + return (8, 8, 8, 8) # default + def _disasm_vop3p(inst: VOP3P) -> str: name = inst.op_name.lower() - is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name + is_wmma, is_swmmac, n, is_fma_mix = 'wmma' in name, 'swmmac' in name, inst.num_srcs(), 'fma_mix' in name def get_src(v, sc): return inst.lit(v) if v == 255 else _fmt_src(v, sc) - if is_wmma: - sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8 - src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 8), _vreg(inst.vdst, 8) + if is_swmmac: + dn, s0n, s1n, s2n = _swmmac_regs(name) + src0, src1, src2, dst = get_src(inst.src0, s0n), get_src(inst.src1, s1n), get_src(inst.src2, s2n), _vreg(inst.vdst, dn) + elif is_wmma: + # Regular WMMA src0/src1 sizes based on type and dimensions + # RDNA4: 16x16x16_iu4=1, 16x16x32_iu4=2, 16x16x16_iu8=2, fp8/bf8=2, f16/bf16=4 + # RDNA3: 16x16x16_iu4=2, 16x16x16_iu8=4, f16/bf16=8 (2x RDNA4 for non-SWMMAC), dst always 8 + is_rdna4_wmma = 'rdna4' in inst.__class__.__module__ + sc = 1 if '16x16x16_iu4' in name else 2 if ('iu4' in name or 'iu8' in name or 'fp8' in name or 'bf8' in name) else 4 + if not is_rdna4_wmma: sc *= 2 # RDNA3 uses 2x register count + dc = 8 if not is_rdna4_wmma else (4 if ('f16_16x16' in name or 'bf16_16x16' in name) and 'f32' not in name else 8) # RDNA3 always 8 + src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, dc), _vreg(inst.vdst, dc) else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}" opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) + clamp = inst.cm if 'cm' in inst._fields else getattr(inst, 'clmp', 0) # RDNA4 uses 'cm', RDNA3 uses 'clmp' if is_fma_mix: def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s) src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4) - mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else []) + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if clamp else []) + elif is_swmmac: + # SWMMAC uses index_key instead of op_sel + mods = ([f"index_key:{inst.opsel}"] if inst.opsel else []) + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + \ + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else []) else: - mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \ - ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) + # VOP3P default opsel_hi is 7 (all high halves) for all ops including WMMA + # Note: LLVM doesn't accept op_sel_hi modifier for WMMA, so we can only roundtrip WMMA with default opsel_hi=7 + opsel_hi_default = 7 if n == 3 else 3 + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != opsel_hi_default else []) + \ + ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if clamp else []) return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" def _disasm_buf(inst: MUBUF | MTBUF) -> str: @@ -511,11 +612,21 @@ def _disasm_sop1(inst: SOP1) -> str: if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}" if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}" if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" + # RDNA4 source-only ops (sdst=NULL) + sop1_src_only = ('S_ALLOC_VGPR', 'S_SLEEP_VAR', 'S_BARRIER_SIGNAL', 'S_BARRIER_SIGNAL_ISFIRST', 'S_BARRIER_INIT', 'S_BARRIER_JOIN') + if inst.op_name in sop1_src_only: return f"{name} {src}" return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {src}" def _disasm_sop2(inst: SOP2) -> str: - cdna = _is_cdna(inst) - return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs(), cdna)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna)}" + cdna, name = _is_cdna(inst), inst.op_name.lower() + lit = getattr(inst, '_literal', None) + s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0), cdna) + s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1), cdna) + dst = _fmt_sdst(inst.sdst, inst.dst_regs(), cdna) + # s_fmamk: dst = src0 * K + src1, s_fmaak: dst = src0 * src1 + K + if 'fmamk' in name and lit is not None: return f"{name} {dst}, {s0}, 0x{lit:x}, {s1}" + if 'fmaak' in name and lit is not None: return f"{name} {dst}, {s0}, {s1}, 0x{lit:x}" + return f"{name} {dst}, {s0}, {s1}" def _disasm_sopc(inst: SOPC) -> str: cdna = _is_cdna(inst) @@ -525,15 +636,21 @@ def _disasm_sopc(inst: SOPC) -> str: def _disasm_sopk(inst: SOPK) -> str: op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst) + is_rdna4 = 'rdna4' in inst.__class__.__module__ + hw = HWREG # For RDNA4, just use numeric ID since LLVM-18 doesn't know WAVE_ prefixed names + def fmt_hwreg(hid, hoff, hsz): + if hid not in hw: return f"0x{inst.simm16:x}" # unknown hwreg ID, output raw hex + # For RDNA4, use numeric ID instead of name (LLVM-18 doesn't support WAVE_ prefixed names) + hr_name = str(hid) if is_rdna4 else hw[hid] + return f"hwreg({hr_name})" if hoff == 0 and hsz == 32 else f"hwreg({hr_name}, {hoff}, {hsz})" # s_setreg_imm32_b32 has a 32-bit literal value if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32): hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 - hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" - return f"{name} {hs}, 0x{inst._literal:x}" + return f"{name} {fmt_hwreg(hid, hoff, hsz)}, 0x{inst._literal:x}" if not cdna and op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')): hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 - hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" + hs = fmt_hwreg(hid, hoff, hsz) return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1, cdna)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1, cdna)}, {hs}" if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END): return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}" @@ -543,9 +660,49 @@ def _disasm_vinterp(inst: VINTERP) -> str: mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp")) return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "") +EXP_TARGETS = {0: 'mrt0', 1: 'mrt1', 2: 'mrt2', 3: 'mrt3', 4: 'mrt4', 5: 'mrt5', 6: 'mrt6', 7: 'mrt7', + 8: 'mrtz', 9: 'null', 12: 'pos0', 13: 'pos1', 14: 'pos2', 15: 'pos3', 16: 'pos4', + 32: 'param0', 33: 'param1', 34: 'param2', 35: 'param3', 36: 'param4', 37: 'param5'} +def _disasm_vexport(inst) -> str: + tgt = EXP_TARGETS.get(inst.target, f'{inst.target}') + srcs = [f'v{getattr(inst, f"vsrc{i}")}' if inst.en & (1 << i) else 'off' for i in range(4)] + mods = _mods((inst.done, "done"), (inst.row, "row_en")) + return f"export {tgt} {', '.join(srcs)}" + (" " + mods if mods else "") + +def _disasm_vbuffer(inst) -> str: + name = inst.op_name.lower().replace('buffer_', 'buffer_').replace('tbuffer_', 'tbuffer_') + # Calculate vdata register count like MUBUF: xyzw=4, xyz=3, xy=2, atomic cmpswap needs double + w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ + ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ + {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], inst.dst_regs()) + if getattr(inst, 'tfe', 0): w += 1 + vdata = _vreg(inst.vdata, w) if w else f'v{inst.vdata}' + vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else (f'v{inst.vaddr}' if inst.offen or inst.idxen else 'off') + # RDNA4 VBUFFER: rsrc is direct SGPR index (not divided by 4), ttmp for 108+ + srsrc = f'ttmp[{inst.rsrc - 108}:{inst.rsrc - 108 + 3}]' if inst.rsrc >= 108 else f's[{inst.rsrc}:{inst.rsrc + 3}]' + soff = decode_src(inst.soffset) if inst.soffset >= 106 else f's{inst.soffset}' + # TBUFFER format handling - format:1 is default and omitted, valid formats use format:[BUF_FMT_*], invalid use format:N + fmt = getattr(inst, 'format', 0) + fmt_names = {e.value: e.name for e in BufFmt} + fmt_s = f" format:[{fmt_names[fmt]}]" if fmt > 1 and fmt in fmt_names else (f" format:{fmt}" if fmt > 1 else "") + # RDNA4 th (temporal hint) and scope modifiers - different mappings for load/store/atomic + if 'atomic' in name: th_names = {1: 'TH_ATOMIC_RETURN', 6: 'TH_ATOMIC_CASCADE_NT'} + elif 'store' in name: th_names = {3: 'TH_STORE_BYPASS', 6: 'TH_STORE_NT_HT'} + else: th_names = {3: 'TH_LOAD_BYPASS', 6: 'TH_LOAD_NT_HT'} + scope_names = {1: 'SCOPE_SE', 2: 'SCOPE_DEV', 3: 'SCOPE_SYS'} + # Modifier order: format idxen offen offset th scope + mods = _mods((inst.idxen, "idxen"), (inst.offen, "offen"), (inst.ioffset, f"offset:{inst.ioffset}"), + (inst.th in th_names, f"th:{th_names.get(inst.th, '')}"), (inst.scope in scope_names, f"scope:{scope_names.get(inst.scope, '')}")) + return f"{name} {vdata}, {vaddr}, {srsrc}, {soff}{fmt_s}" + (" " + mods if mods else "") + DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p, VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf, - MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk} + MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk, + # RDNA4 + R4_VOP1: _disasm_vop1, R4_VOP2: _disasm_vop2, R4_VOPC: _disasm_vopc, R4_VOP3: _disasm_vop3, R4_VOP3SD: _disasm_vop3sd, + R4_VOPD: _disasm_vopd, R4_VOP3P: _disasm_vop3p, R4_VINTERP: _disasm_vinterp, R4_SOPP: _disasm_sopp, R4_SMEM: _disasm_smem, + R4_DS: _disasm_ds, R4_SOP1: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOPC: _disasm_sopc, R4_SOPK: _disasm_sopk, + R4_VEXPORT: _disasm_vexport, R4_VBUFFER: _disasm_vbuffer} def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst) @@ -557,8 +714,10 @@ SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), ' 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)} FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc. REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp} -SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512', - 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512', +SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b96', 's_load_b128', 's_load_b256', 's_load_b512', + 's_load_i8', 's_load_u8', 's_load_i16', 's_load_u16', + 's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b96', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512', + 's_buffer_load_i8', 's_buffer_load_u8', 's_buffer_load_i16', 's_buffer_load_u16', 's_atc_probe', 's_atc_probe_buffer'} SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'} @@ -567,6 +726,8 @@ def _op2dsl(op: str) -> str: op = op.strip() neg = op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')) if neg: op = op[1:] + # Handle neg(value) syntax + if op.startswith('neg(') and op.endswith(')'): neg = True; op = op[4:-1] abs_ = (op.startswith('|') and op.endswith('|')) or (op.startswith('abs(') and op.endswith(')')) if abs_: op = op[1:-1] if op.startswith('|') else op[4:-1] hi = ".h" if op.endswith('.h') else ".l" if op.endswith('.l') else "" @@ -630,6 +791,23 @@ _ALIASES = { 'v_dot2c_f32_f16': 'v_dot2acc_f32_f16', # More VOP3 aliases 'v_fma_legacy_f32': 'v_fma_dx9_zero_f32', + # DS aliases (RDNA4: ds_read_* -> ds_load_*, ds_write_* -> ds_store_*) + 'ds_read_b32': 'ds_load_b32', 'ds_read_b64': 'ds_load_b64', 'ds_read_b96': 'ds_load_b96', 'ds_read_b128': 'ds_load_b128', + 'ds_read_i8': 'ds_load_i8', 'ds_read_u8': 'ds_load_u8', 'ds_read_i16': 'ds_load_i16', 'ds_read_u16': 'ds_load_u16', + 'ds_read_i8_d16': 'ds_load_i8_d16', 'ds_read_u8_d16': 'ds_load_u8_d16', 'ds_read_i8_d16_hi': 'ds_load_i8_d16_hi', 'ds_read_u8_d16_hi': 'ds_load_u8_d16_hi', + 'ds_read_u16_d16': 'ds_load_u16_d16', 'ds_read_u16_d16_hi': 'ds_load_u16_d16_hi', + 'ds_read2_b32': 'ds_load_2addr_b32', 'ds_read2_b64': 'ds_load_2addr_b64', + 'ds_read2st64_b32': 'ds_load_2addr_stride64_b32', 'ds_read2st64_b64': 'ds_load_2addr_stride64_b64', + 'ds_read_addtid_b32': 'ds_load_addtid_b32', 'ds_write_addtid_b32': 'ds_store_addtid_b32', + 'ds_write_b32': 'ds_store_b32', 'ds_write_b64': 'ds_store_b64', 'ds_write_b96': 'ds_store_b96', 'ds_write_b128': 'ds_store_b128', + 'ds_write_b8': 'ds_store_b8', 'ds_write_b16': 'ds_store_b16', + 'ds_write_b8_d16_hi': 'ds_store_b8_d16_hi', 'ds_write_b16_d16_hi': 'ds_store_b16_d16_hi', + 'ds_write2_b32': 'ds_store_2addr_b32', 'ds_write2_b64': 'ds_store_2addr_b64', + 'ds_write2st64_b32': 'ds_store_2addr_stride64_b32', 'ds_write2st64_b64': 'ds_store_2addr_stride64_b64', + # DS wrxchg aliases (RDNA4: ds_wrxchg* -> ds_storexchg*) + 'ds_wrxchg_rtn_b32': 'ds_storexchg_rtn_b32', 'ds_wrxchg_rtn_b64': 'ds_storexchg_rtn_b64', + 'ds_wrxchg2_rtn_b32': 'ds_storexchg_2addr_rtn_b32', 'ds_wrxchg2_rtn_b64': 'ds_storexchg_2addr_rtn_b64', + 'ds_wrxchg2st64_rtn_b32': 'ds_storexchg_2addr_stride64_rtn_b32', 'ds_wrxchg2st64_rtn_b64': 'ds_storexchg_2addr_stride64_rtn_b64', } def _apply_alias(text: str) -> str: @@ -639,18 +817,23 @@ def _apply_alias(text: str) -> str: if m in _ALIASES: return _ALIASES[m] + text[len(m):] return text -def get_dsl(text: str) -> str: +def get_dsl(text: str, arch: str = "rdna3") -> str: text, kw = _apply_alias(text.strip()), [] # Extract modifiers for pat, val in [(r'\s+mul:2(?:\s|$)', 1), (r'\s+mul:4(?:\s|$)', 2), (r'\s+div:2(?:\s|$)', 3)]: if (m := _extract(text, pat))[0]: kw.append(f'omod={val}'); text = m[1]; break - if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: kw.append('clmp=1'); text = m[1] + # For RDNA4, use 'cm' for clamp in VOP3P, otherwise use 'clmp' + clamp_found = False + if (m := _extract(text, r'\s+clamp(?:\s|$)'))[0]: clamp_found = True; text = m[1] opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]') if m: bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower() - is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot')) + is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot', 'v_fma_mix')) opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \ (bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits)) + # Extract op_sel_hi for VOP3P - encodes to opsel_hi (2 bits) and opsel_hi2 (1 bit) + opsel_hi_val, m, text = None, *_extract(text, r'\s+op_sel_hi:\[([^\]]+)\]') + if m: opsel_hi_val = [int(x.strip()) for x in m.group(1).split(',')] m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None m, text = _extract(text, r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)'); off_val = m.group(1) if m else None m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None @@ -663,7 +846,20 @@ def get_dsl(text: str) -> str: m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None + m, text = _extract(text, r'\s+byte_sel:(\d+)'); byte_sel = int(m.group(1)) if m else None + # DS instruction offsets + m, text = _extract(text, r'\s+offset0:(\d+)'); ds_off0 = int(m.group(1)) if m else None + m, text = _extract(text, r'\s+offset1:(\d+)'); ds_off1 = int(m.group(1)) if m else None + # WMMA/SWMMAC modifiers + m, text = _extract(text, r'\s+index_key:(\d+)'); index_key = int(m.group(1)) if m else None if waitexp: kw.append(f'waitexp={waitexp}') + # byte_sel encodes to opsel bits [13:12] for cvt_sr/cvt_pk instructions + if byte_sel is not None: + if opsel is None: opsel = 0 + opsel |= (byte_sel << 2) # byte_sel goes to opsel[3:2] + if ds_off0 is not None: kw.append(f'offset0={ds_off0}') + if ds_off1 is not None: kw.append(f'offset1={ds_off1}') + if index_key is not None: kw.append(f'opsel={index_key}') # SWMMAC index_key is encoded in opsel field parts = text.replace(',', ' ').split() if not parts: raise ValueError("empty instruction") @@ -694,20 +890,60 @@ def get_dsl(text: str) -> str: # Special instructions if mn == 's_setreg_imm32_b32': raise ValueError(f"unsupported: {mn}") + # SOP1 instructions with no dest (sdst=NULL=0x80): s_alloc_vgpr, s_barrier_*, s_sleep_var + sop1_no_dest = ('s_alloc_vgpr', 's_barrier_init', 's_barrier_join', 's_barrier_signal', 's_barrier_signal_isfirst', 's_sleep_var') + if mn in sop1_no_dest: + return f"{mn}(sdst=RawImm(128), ssrc0={args[0]})" if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})" if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))" if mn == 's_version': return f"{mn}(simm16={args[0]})" if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})" + # Export instructions (RDNA4 VEXPORT) + if mn == 'export': + # Target names: mrt0-7 (0-7), mrtz (8), pos0-3 (12-15) + target_map = {**{f'mrt{i}': i for i in range(8)}, 'mrtz': 8, **{f'pos{i}': 12+i for i in range(4)}} + # Extract done modifier first + m, exp_str = _extract(op_str, r'\s+done(?:\s|$)') + done_val = 1 if m else 0 + # Parse: target vsrc0, vsrc1, vsrc2, vsrc3 + exp_parts = exp_str.replace(',', ' ').split() + target_name = exp_parts[0].lower().strip() + target = target_map.get(target_name, 0) + # Parse vsrc0-3, "off" means disabled (use v0 but don't set en bit) + vsrcs, en = [], 0 + for i, o in enumerate(exp_parts[1:5]): + o = o.strip().lower() + if o == 'off': vsrcs.append('v[0]') + else: vsrcs.append(_op2dsl(o)); en |= (1 << i) + return f"VEXPORT(target={target}, en={en}, vsrc0={vsrcs[0]}, vsrc1={vsrcs[1]}, vsrc2={vsrcs[2]}, vsrc3={vsrcs[3]}, done={done_val})" + # SMEM if mn in SMEM_OPS: gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else "" - if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()): - return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(124){gs}{ds})" - if off_val and len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}{gs}{ds})" - if len(ops) >= 3: return f"{mn}(sdata={args[0]}, sbase={args[1]}, soffset={args[2]}{gs}{ds})" + # RDNA4 uses ioffset/th/scope, RDNA3 uses offset/glc/dlc + off_field = "ioffset" if arch == "rdna4" else "offset" + th_s, scope_s, smem_str = "", "", op_str + if arch == "rdna4": + # Extract th (temporal hint) and scope modifiers for RDNA4 SMEM + m, smem_str = _extract(op_str, r'\s+th:TH_(\w+)') + th_val = {'LOAD_RT': 0, 'LOAD_NT': 1, 'LOAD_HT': 2, 'LOAD_LU': 3, 'STORE_RT': 0, 'STORE_NT': 1, 'STORE_HT': 2, 'STORE_LU': 3}.get(m.group(1), 0) if m else None + m, smem_str = _extract(smem_str, r'\s+scope:SCOPE_(\w+)') + scope_val = {'CU': 0, 'SE': 1, 'DEV': 2, 'SYS': 3}.get(m.group(1), 0) if m else None + if scope_val is None: # Try numeric scope format + m, smem_str = _extract(smem_str, r'\s+scope:(0?x?[0-9a-fA-F]+)') + scope_val = int(m.group(1), 0) if m else None + th_s = f", th={th_val}" if th_val else "" + scope_s = f", scope={scope_val}" if scope_val else "" + # Re-parse operands after extracting modifiers + smem_ops = _parse_ops(smem_str) + smem_args = [_op2dsl(o) for o in smem_ops] + if len(smem_ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', smem_ops[2].strip().lower()): + return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, {off_field}={smem_ops[2].strip()}, soffset=RawImm(124){gs}{ds}{th_s}{scope_s})" + if off_val and len(smem_ops) >= 3: return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, {off_field}={off_val}, soffset={smem_args[2]}{gs}{ds}{th_s}{scope_s})" + if len(smem_ops) >= 3: return f"{mn}(sdata={smem_args[0]}, sbase={smem_args[1]}, soffset={smem_args[2]}{gs}{ds}{th_s}{scope_s})" - # Buffer (MUBUF/MTBUF) instructions + # Buffer (MUBUF/MTBUF/VBUFFER) instructions if mn.startswith(('buffer_', 'tbuffer_')): is_tbuf = mn.startswith('tbuffer_') # Parse format value for tbuffer @@ -717,7 +953,50 @@ def get_dsl(text: str) -> str: else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val) # Handle special no-arg buffer ops if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()" - # Build modifiers string + # RDNA4 uses VBUFFER with different field names and th/scope instead of glc/dlc/slc + if arch == "rdna4": + # Extract th and scope modifiers - RDNA4 temporal hints (from ISA docs) + # Load: RT=0, NT=1, HT=2, BYPASS=3, LU=4, RT_NT=5, NT_HT=6, RT_WB=7 + # Store: RT=0, NT=1, HT=2, BYPASS=3, LU=4, RT_NT=5, NT_HT=6 + # Atomic: RT=0, NT=1, RETURN=1, NT_RETURN=3, RT_RETURN=1, CASCADE_RT=6, CASCADE_NT=6 + m, buf_text = _extract(op_str, r'\s+th:TH_(\w+)') + th_val = {'LOAD_RT': 0, 'LOAD_NT': 1, 'LOAD_HT': 2, 'LOAD_BYPASS': 3, 'LOAD_LU': 4, 'LOAD_RT_NT': 5, 'LOAD_NT_HT': 6, 'LOAD_RT_WB': 7, + 'STORE_RT': 0, 'STORE_NT': 1, 'STORE_HT': 2, 'STORE_BYPASS': 3, 'STORE_LU': 4, 'STORE_RT_NT': 5, 'STORE_NT_HT': 6, + 'ATOMIC_RT': 0, 'ATOMIC_NT': 1, 'ATOMIC_RETURN': 1, 'ATOMIC_RT_RETURN': 1, 'ATOMIC_NT_RETURN': 3, 'ATOMIC_CASCADE_RT': 6, 'ATOMIC_CASCADE_NT': 6}.get(m.group(1), 0) if m else 0 + m, buf_text = _extract(buf_text, r'\s+scope:SCOPE_(\w+)') + scope_val = {'CU': 0, 'SE': 1, 'DEV': 2, 'SYS': 3}.get(m.group(1), 0) if m else 0 + # Re-parse operands from cleaned text + buf_ops = _parse_ops(buf_text) + buf_args = [_op2dsl(o) for o in buf_ops] + # Build VBUFFER modifier string + vbuf_mods = "".join([f", ioffset={off_val}" if off_val else "", ", offen=1" if offen else "", ", idxen=1" if idxen else "", + f", th={th_val}" if th_val else "", f", scope={scope_val}" if scope_val else "", + ", tfe=1" if tfe else ""]) + if is_tbuf and fmt_num is not None: vbuf_mods = f", format={fmt_num}" + vbuf_mods + elif is_tbuf: vbuf_mods = ", format=1" + vbuf_mods # default format for tbuffer + else: vbuf_mods = ", format=1" + vbuf_mods # VBUFFER needs format=1 by default + # Determine vaddr value (v[0] for 'off', actual register otherwise) + vaddr_idx = 1 + if len(buf_ops) > vaddr_idx and buf_ops[vaddr_idx].strip().lower() == 'off': vaddr_val = "v[0]" + else: vaddr_val = buf_args[vaddr_idx] if len(buf_args) > vaddr_idx else "v[0]" + # rsrc and soffset indices + rsrc_idx, soff_idx = (2, 3) if len(buf_ops) > 1 else (1, 2) + # RDNA4 VBUFFER rsrc is raw SGPR index (not divided by 4), extract base index from s[N:N+3] or ttmp[N:N+3] + rsrc_raw = buf_ops[rsrc_idx].strip() if len(buf_ops) > rsrc_idx else "s[0:3]" + if m := re.match(r's\[(\d+):\d+\]', rsrc_raw.lower()): rsrc_val = m.group(1) + elif m := re.match(r's(\d+)', rsrc_raw.lower()): rsrc_val = m.group(1) + elif m := re.match(r'ttmp\[(\d+):\d+\]', rsrc_raw.lower()): rsrc_val = str(108 + int(m.group(1))) + elif m := re.match(r'ttmp(\d+)', rsrc_raw.lower()): rsrc_val = str(108 + int(m.group(1))) + else: rsrc_val = "0" + # soffset: RDNA4 VBUFFER uses raw SGPR index (0-127), wrap in RawImm to bypass encode_src + soff_raw = buf_ops[soff_idx].strip() if len(buf_ops) > soff_idx else "0" + soff_lower = soff_raw.lower() + if soff_lower == 'm0': soff_val = "RawImm(125)" + elif soff_lower in ('null', 'off'): soff_val = "RawImm(124)" + elif m := re.match(r's(\d+)', soff_lower): soff_val = f"RawImm({m.group(1)})" + else: soff_val = f"RawImm({soff_raw})" + return f"{mn}(vdata={buf_args[0]}, vaddr={vaddr_val}, rsrc={rsrc_val}, soffset={soff_val}{vbuf_mods})" + # RDNA3 MUBUF/MTBUF handling buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "", ", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""]) if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods @@ -747,7 +1026,13 @@ def get_dsl(text: str) -> str: # DS instructions if mn.startswith('ds_'): - off0, off1 = (str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff)) if off_val else ("0", "0") + # Use ds_off0/ds_off1 if extracted, otherwise parse from combined offset:N + if ds_off0 is not None or ds_off1 is not None: + off0, off1 = str(ds_off0 or 0), str(ds_off1 or 0) + elif off_val: + off0, off1 = str(int(off_val, 0) & 0xff), str((int(off_val, 0) >> 8) & 0xff) + else: + off0, off1 = "0", "0" gds_s = ", gds=1" if 'gds' in text.lower().split()[-1:] else "" off_kw = f", offset0={off0}, offset1={off1}{gds_s}" if mn == 'ds_nop' or mn in ('ds_gws_sema_v', 'ds_gws_sema_p', 'ds_gws_sema_release_all'): return f"{mn}({off_kw.lstrip(', ')})" @@ -773,16 +1058,35 @@ def get_dsl(text: str) -> str: lit_s = "" if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3] elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]] + # s_fmaak/s_fmamk literal extraction (SOP2) + elif mn in ('s_fmaak_f32',) and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3] + elif mn in ('s_fmamk_f32',) and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]] + # v_cndmask_b32 with vcc_lo: strip the vcc_lo operand (implicit for VOP2) + elif mn in ('v_cndmask_b32', 'v_cndmask_b32_e32') and len(args) == 4 and ops[3].strip().lower() in ('vcc_lo', 'vcc'): + mn, args = 'v_cndmask_b32_e32', args[:3] - # VCC ops cleanup + # Special register name to encoding map (used for carry ops and v_cmp) + _SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127} + # VCC ops cleanup - v_add_co_ci_u32 etc. with carry-in/out vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'} - if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]] + if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: + # Check if carry-in is vcc_lo - if so, use VOP2, otherwise use VOP3SD + carry_in = ops[4].strip().lower() if len(ops) > 4 else 'vcc_lo' + carry_out = ops[1].strip().lower() if len(ops) > 1 else 'vcc_lo' + if carry_in in ('vcc_lo', 'vcc') and carry_out in ('vcc_lo', 'vcc'): + mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]] + else: + # Need VOP3SD format for non-vcc carry operands + mn_base = mn.replace('_e32', '').replace('_e64', '') + # sdst = carry-out, src2 = carry-in + sdst = _SGPR_NAMES.get(carry_out, 124) if carry_out in _SGPR_NAMES else (int(carry_out[1:]) if carry_out.startswith('s') and carry_out[1:].isdigit() else 124) + src2 = _SGPR_NAMES.get(carry_in, 0) if carry_in in _SGPR_NAMES else (int(carry_in[1:]) if carry_in.startswith('s') and carry_in[1:].isdigit() else 0) + return f"{mn_base}(vdst={args[0]}, sdst=RawImm({sdst}), src0={args[2]}, src1={args[3]}, src2=RawImm({src2}))" if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '') if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:] if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args - # v_cmp_*_e64 has SGPR destination in vdst field - encode as RawImm - _SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127} - if mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64') and len(args) >= 1: + # v_cmp_*_e64, v_s_*, v_readlane_b32, v_readfirstlane_b32 have SGPR destination in vdst field - encode as RawImm + if ((mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64')) or mn.startswith('v_s_') or mn in ('v_readlane_b32', 'v_readfirstlane_b32')) and len(args) >= 1: dst = ops[0].strip().lower() if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})' elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})' @@ -815,15 +1119,59 @@ def get_dsl(text: str) -> str: if neg_lo is not None: all_kw.append(f'neg={neg_lo}') if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}') if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1']) + # VOP3P packed ops: handle op_sel_hi explicitly or use defaults (7 = all high halves), except fma_mix ops which default to 0 + vop3p_ops = {'v_pk_', 'v_dot2', 'v_dot4', 'v_dot8', 'v_wmma', 'v_swmmac'} + is_vop3p = any(mn.startswith(p) for p in vop3p_ops) + is_fma_mix = 'fma_mix' in mn + if opsel_hi_val is not None: + # Explicit op_sel_hi: encode bits 0,1 to opsel_hi, bit 2 to opsel_hi2 + # For 2-element op_sel_hi (2-op instructions), opsel_hi2 defaults to 1 unless fma_mix + opsel_hi_enc = opsel_hi_val[0] | (opsel_hi_val[1] << 1) if len(opsel_hi_val) >= 2 else opsel_hi_val[0] + opsel_hi2_enc = opsel_hi_val[2] if len(opsel_hi_val) >= 3 else (0 if is_fma_mix else 1) + all_kw.extend([f'opsel_hi={opsel_hi_enc}', f'opsel_hi2={opsel_hi2_enc}']) + elif is_vop3p and not is_fma_mix: + all_kw.extend(['opsel_hi=3', 'opsel_hi2=1']) + # Add clamp keyword - use 'cm' for RDNA4 (VOP3/VOP3P use cm), otherwise 'clmp' + if clamp_found: + if arch == 'rdna4': all_kw.append('cm=1') + else: all_kw.append('clmp=1') a_str, kw_str = ', '.join(args), ', '.join(all_kw) return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})" -def asm(text: str) -> Inst: - dsl = get_dsl(text) - ns = {n: getattr(ins, n) for n in dir(ins) if not n.startswith('_')} +def _hwreg(id_, offset=0, size=32): return id_ | (offset << 6) | ((size - 1) << 11) +def _sendmsg(id_, op=0, stream=0): return id_ | (op << 4) | (stream << 8) + +# Hardware register name to ID mapping (RDNA3/generic) +_HWREG_NAMES = {'HW_REG_MODE': 1, 'HW_REG_STATUS': 2, 'HW_REG_TRAPSTS': 3, 'HW_REG_HW_ID': 4, 'HW_REG_GPR_ALLOC': 5, + 'HW_REG_LDS_ALLOC': 6, 'HW_REG_IB_STS': 7, 'HW_REG_PC_LO': 8, 'HW_REG_PC_HI': 9, 'HW_REG_INST_DW0': 10, 'HW_REG_INST_DW1': 11, + 'HW_REG_IB_DBG0': 12, 'HW_REG_IB_DBG1': 13, 'HW_REG_FLUSH_IB': 14, 'HW_REG_SH_MEM_BASES': 15, 'HW_REG_SQ_SHADER_TBA_LO': 16, + 'HW_REG_SQ_SHADER_TBA_HI': 17, 'HW_REG_SQ_SHADER_TMA_LO': 18, 'HW_REG_SQ_SHADER_TMA_HI': 19, 'HW_REG_FLAT_SCR_LO': 20, + 'HW_REG_FLAT_SCR_HI': 21, 'HW_REG_XNACK_MASK': 22, 'HW_REG_HW_ID1': 23, 'HW_REG_HW_ID2': 24, 'HW_REG_POPS_PACKER': 25, + 'HW_REG_PERF_SNAPSHOT_DATA': 26, 'HW_REG_PERF_SNAPSHOT_PC_LO': 27, 'HW_REG_PERF_SNAPSHOT_PC_HI': 28, 'HW_REG_SHADER_CYCLES': 29, + 'HW_REG_SHADER_CYCLES_HI': 30, 'HW_REG_WAVE_MODE': 31, 'HW_REG_WAVE_SCRATCH_BASE': 32} +# RDNA4 hwreg mappings (derived from HWREG_RDNA4) +_HWREG_NAMES_RDNA4 = {v: k for k, v in HWREG_RDNA4.items()} +_SENDMSG_NAMES = {'MSG_INTERRUPT': 1, 'MSG_GS': 2, 'MSG_GS_DONE': 3, 'MSG_SAVEWAVE': 4, 'MSG_STALL_WAVE_GEN': 5, + 'MSG_HALT_WAVES': 6, 'MSG_ORDERED_PS_DONE': 7, 'MSG_EARLY_PRIM_DEALLOC': 8, 'MSG_GS_ALLOC_REQ': 9, 'MSG_GET_DOORBELL': 10, + 'MSG_GET_DDID': 11, 'MSG_HS_TESSFACTOR': 2, 'MSG_DEALLOC_VGPRS': 10, 'MSG_RTN_GET_DOORBELL': 128, 'MSG_RTN_GET_DDID': 129, + 'MSG_RTN_GET_TMA': 130, 'MSG_RTN_GET_REALTIME': 131, 'MSG_RTN_SAVE_WAVE': 132, 'MSG_RTN_GET_TBA': 133, + 'MSG_RTN_GET_TBA_TO_PC': 134, 'MSG_RTN_GET_SE_AID_ID': 135} + +def asm(text: str, arch: str = "rdna3") -> Inst: + dsl = get_dsl(text, arch) + if arch == "rdna4": + ns = {n: getattr(rdna4_ins, n) for n in dir(rdna4_ins) if not n.startswith('_')} + hwreg_names = _HWREG_NAMES_RDNA4 + else: + ns = {n: getattr(ins, n) for n in dir(ins) if not n.startswith('_')} + hwreg_names = _HWREG_NAMES + # Helper for hwreg() that handles both numeric and named IDs + def hwreg(id_, offset=0, size=32): return _hwreg(hwreg_names.get(id_, id_) if isinstance(id_, str) else id_, offset, size) + def sendmsg(id_, op=0, stream=0): return _sendmsg(_SENDMSG_NAMES.get(id_, id_) if isinstance(id_, str) else id_, op, stream) ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP, - 'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF}) + 'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF, + 'hwreg': hwreg, 'sendmsg': sendmsg, **{k: k for k in hwreg_names}, **{k: k for k in _SENDMSG_NAMES}}) try: return eval(dsl, ns) except NameError: if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns) diff --git a/extra/assembly/amd/autogen/rdna4/enum.py b/extra/assembly/amd/autogen/rdna4/enum.py index 1c5e921131..d8fd760ca2 100644 --- a/extra/assembly/amd/autogen/rdna4/enum.py +++ b/extra/assembly/amd/autogen/rdna4/enum.py @@ -243,6 +243,8 @@ class SMEMOp(IntEnum): S_BUFFER_LOAD_I16 = 26 S_BUFFER_LOAD_U16 = 27 S_DCACHE_INV = 33 + S_ATC_PROBE = 34 + S_ATC_PROBE_BUFFER = 35 S_PREFETCH_INST = 36 S_PREFETCH_INST_PC_REL = 37 S_PREFETCH_DATA = 38 @@ -318,6 +320,8 @@ class SOP1Op(IntEnum): S_BARRIER_SIGNAL = 78 S_BARRIER_SIGNAL_ISFIRST = 79 S_GET_BARRIER_STATE = 80 + S_BARRIER_INIT = 81 + S_BARRIER_JOIN = 82 S_ALLOC_VGPR = 83 S_SLEEP_VAR = 88 S_CEIL_F32 = 96 @@ -486,6 +490,7 @@ class SOPPOp(IntEnum): S_ROUND_MODE = 17 S_DENORM_MODE = 18 S_BARRIER_WAIT = 20 + S_BARRIER_LEAVE = 21 S_CODE_END = 31 S_BRANCH = 32 S_CBRANCH_SCC0 = 33 @@ -502,6 +507,8 @@ class SOPPOp(IntEnum): S_SENDMSGHALT = 55 S_INCPERFLEVEL = 56 S_DECPERFLEVEL = 57 + S_TTRACEDATA = 58 + S_TTRACEDATA_IMM = 59 S_ICACHE_INV = 60 S_WAIT_LOADCNT = 64 S_WAIT_STORECNT = 65 diff --git a/extra/assembly/amd/autogen/rdna4/ins.py b/extra/assembly/amd/autogen/rdna4/ins.py index 99da6270f1..5bfcb8db03 100644 --- a/extra/assembly/amd/autogen/rdna4/ins.py +++ b/extra/assembly/amd/autogen/rdna4/ins.py @@ -379,6 +379,8 @@ s_buffer_load_u8 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_U8) s_buffer_load_i16 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_I16) s_buffer_load_u16 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_U16) s_dcache_inv = functools.partial(SMEM, SMEMOp.S_DCACHE_INV) +s_atc_probe = functools.partial(SMEM, SMEMOp.S_ATC_PROBE) +s_atc_probe_buffer = functools.partial(SMEM, SMEMOp.S_ATC_PROBE_BUFFER) s_prefetch_inst = functools.partial(SMEM, SMEMOp.S_PREFETCH_INST) s_prefetch_inst_pc_rel = functools.partial(SMEM, SMEMOp.S_PREFETCH_INST_PC_REL) s_prefetch_data = functools.partial(SMEM, SMEMOp.S_PREFETCH_DATA) @@ -452,6 +454,8 @@ s_sendmsg_rtn_b64 = functools.partial(SOP1, SOP1Op.S_SENDMSG_RTN_B64) s_barrier_signal = functools.partial(SOP1, SOP1Op.S_BARRIER_SIGNAL) s_barrier_signal_isfirst = functools.partial(SOP1, SOP1Op.S_BARRIER_SIGNAL_ISFIRST) s_get_barrier_state = functools.partial(SOP1, SOP1Op.S_GET_BARRIER_STATE) +s_barrier_init = functools.partial(SOP1, SOP1Op.S_BARRIER_INIT) +s_barrier_join = functools.partial(SOP1, SOP1Op.S_BARRIER_JOIN) s_alloc_vgpr = functools.partial(SOP1, SOP1Op.S_ALLOC_VGPR) s_sleep_var = functools.partial(SOP1, SOP1Op.S_SLEEP_VAR) s_ceil_f32 = functools.partial(SOP1, SOP1Op.S_CEIL_F32) @@ -612,6 +616,7 @@ s_trap = functools.partial(SOPP, SOPPOp.S_TRAP) s_round_mode = functools.partial(SOPP, SOPPOp.S_ROUND_MODE) s_denorm_mode = functools.partial(SOPP, SOPPOp.S_DENORM_MODE) s_barrier_wait = functools.partial(SOPP, SOPPOp.S_BARRIER_WAIT) +s_barrier_leave = functools.partial(SOPP, SOPPOp.S_BARRIER_LEAVE) s_code_end = functools.partial(SOPP, SOPPOp.S_CODE_END) s_branch = functools.partial(SOPP, SOPPOp.S_BRANCH) s_cbranch_scc0 = functools.partial(SOPP, SOPPOp.S_CBRANCH_SCC0) @@ -628,6 +633,8 @@ s_sendmsg = functools.partial(SOPP, SOPPOp.S_SENDMSG) s_sendmsghalt = functools.partial(SOPP, SOPPOp.S_SENDMSGHALT) s_incperflevel = functools.partial(SOPP, SOPPOp.S_INCPERFLEVEL) s_decperflevel = functools.partial(SOPP, SOPPOp.S_DECPERFLEVEL) +s_ttracedata = functools.partial(SOPP, SOPPOp.S_TTRACEDATA) +s_ttracedata_imm = functools.partial(SOPP, SOPPOp.S_TTRACEDATA_IMM) s_icache_inv = functools.partial(SOPP, SOPPOp.S_ICACHE_INV) s_wait_loadcnt = functools.partial(SOPP, SOPPOp.S_WAIT_LOADCNT) s_wait_storecnt = functools.partial(SOPP, SOPPOp.S_WAIT_STORECNT) diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index af2af3b393..d714398f7b 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -92,6 +92,7 @@ _SPECIAL_REGS = { 'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1), 'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2), 'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4), + 'V_CVT_PK_F32_BF8': (2, 1, 1, 1), 'V_CVT_PK_F32_FP8': (2, 1, 1, 1), } _SPECIAL_DTYPE = { 'V_LSHLREV_B64': ('B64', 'U32', 'B64', None), 'V_LSHRREV_B64': ('B64', 'U32', 'B64', None), 'V_ASHRREV_I64': ('I64', 'U32', 'I64', None), @@ -136,8 +137,8 @@ def spec_is_16bit(name: str) -> bool: def spec_is_64bit(name: str) -> bool: return bool(_F64_RE.search(name.upper())) _3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI', 'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN', - 'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'} -_2SRC = {'FMAC'} # FMAC uses dst as implicit accumulator, so only 2 explicit sources + 'MINMAX', 'MINIMUMMAXIMUM', 'MAXIMUMMINIMUM', 'MINIMUM3', 'MAXIMUM3', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'} +_2SRC = {'FMAC', 'PERMLANE16_VAR', 'PERMLANEX16_VAR'} # FMAC uses dst as implicit accumulator, _VAR permlane only 2 sources def spec_num_srcs(name: str) -> int: name = name.upper() if any(k in name for k in _2SRC): return 2 @@ -271,12 +272,17 @@ RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst' def _encode_reg(val: Reg) -> int: return (108 if isinstance(val, TTMP) else 0) + val.idx -def _is_inline_const(v: int) -> bool: return 0 <= v <= 127 or 128 <= v <= 208 or 240 <= v <= 255 +def _is_encoded_src(v: int) -> bool: return 106 <= v <= 127 or 128 <= v <= 208 or 240 <= v <= 255 # Special regs (106-127) or inline const def encode_src(val) -> int: if isinstance(val, VGPR): return 256 + _encode_reg(val) if isinstance(val, Reg): return _encode_reg(val) - if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val if _is_inline_const(val.val) else 255 + if isinstance(val, SrcMod) and not isinstance(val, Reg): + v = val.val + if _is_encoded_src(v): return v # Already encoded (special reg 106-127 or inline const 128-208 or float 240-255) + if isinstance(v, int) and 0 <= v <= 64: return 128 + v # Encode as inline constant + if isinstance(v, int) and -16 <= v <= -1: return 192 - v + return 255 # Literal if hasattr(val, 'value'): return val.value # IntEnum if isinstance(val, float): return 128 if val == 0.0 else FLOAT_ENC.get(val, 255) if isinstance(val, int): return 128 + val if 0 <= val <= 64 else 192 - val if -16 <= val <= -1 else 255 @@ -359,19 +365,30 @@ class Inst: def _validate(self, orig_args: dict): """Format-specific validation. Override in subclass or check by class name.""" cls_name, op = self.__class__.__name__, orig_args.get('op') - if hasattr(op, 'value'): op = op.value - # SMEM: register count must match opcode - if cls_name == 'SMEM' and op is not None: - expected = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op) + op_val = op.value if hasattr(op, 'value') else op + op_name = op.name if hasattr(op, 'name') else None + # SMEM: register count must match opcode (derive from name: b32=1, b64=2, b96=3, b128=4, b256=8, b512=16, i8/u8/i16/u16=1) + if cls_name == 'SMEM' and op_name: + expected = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16, 'I8': 1, 'U8': 1, 'I16': 1, 'U16': 1}.get(op_name.split('_')[-1]) sdata = orig_args.get('sdata') if expected and isinstance(sdata, Reg) and sdata.count != expected: - raise ValueError(f"SMEM op {op} expects {expected} registers, got {sdata.count}") - # SOP1: b32=1 reg, b64=2 regs + raise ValueError(f"SMEM op {op_name} expects {expected} registers, got {sdata.count}") + # SOP1: derive expected register sizes from op name (e.g., S_MOV_B64 -> dst=2, src=2; S_CTZ_I32_B64 -> dst=1, src=2) if cls_name == 'SOP1' and hasattr(orig_args.get('op'), 'name'): - expected = 2 if orig_args['op'].name.endswith('_B64') else 1 - for fld in ('sdst', 'ssrc0'): + op_name = orig_args['op'].name + # Special cases: BITSET takes bit index (1 reg) regardless of dst size + if 'BITSET' in op_name: + dst_size = 2 if op_name.endswith('_B64') else 1 + src_size = 1 # bit index is always 1 reg + else: + # Extract sizes from name: last suffix is src type, second-to-last (if exists) is dst type + sizes = {'B32': 1, 'I32': 1, 'U32': 1, 'B64': 2, 'I64': 2, 'U64': 2, 'B128': 4, 'B256': 8, 'B512': 16} + parts = op_name.split('_') + src_size = sizes.get(parts[-1], 1) if parts[-1] in sizes else 1 + dst_size = sizes.get(parts[-2], src_size) if len(parts) >= 2 and parts[-2] in sizes else src_size + for fld, expected in [('sdst', dst_size), ('ssrc0', src_size)]: if isinstance(orig_args.get(fld), Reg) and orig_args[fld].count != expected: - raise ValueError(f"SOP1 {orig_args['op'].name} expects {expected} register(s) for {fld}, got {orig_args[fld].count}") + raise ValueError(f"SOP1 {op_name} expects {expected} register(s) for {fld}, got {orig_args[fld].count}") def __init__(self, *args, literal: int | None = None, **kwargs): self._values, self._literal = dict(self._defaults), None @@ -411,7 +428,9 @@ class Inst: if cls_name == 'VOP3P': op = orig_args.get('op') if hasattr(op, 'value'): op = op.value + # fma_mix ops (32-34) default to opsel_hi=0, WMMA ops (64-69) default to opsel_hi=7 to match LLVM if op in (32, 33, 34) and 'opsel_hi' not in orig_args: self._values['opsel_hi'] = self._values['opsel_hi2'] = 0 + if op in range(64, 70) and 'opsel_hi' not in orig_args: self._values['opsel_hi'], self._values['opsel_hi2'] = 3, 1 # Encode all fields for name, val in list(self._values.items()): diff --git a/extra/assembly/amd/pdf.py b/extra/assembly/amd/pdf.py index 0d8cabd629..85d611b44b 100644 --- a/extra/assembly/amd/pdf.py +++ b/extra/assembly/amd/pdf.py @@ -290,6 +290,11 @@ if __name__ == "__main__": 'DS': {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}, 'FLAT': {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}} for fmt, ops in fixes.items(): enums[fmt] = merge_dicts([enums[fmt], ops]) + if arch == 'rdna4': + fixes = {'SMEM': {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}, + 'SOP1': {81: 'S_BARRIER_INIT', 82: 'S_BARRIER_JOIN'}, + 'SOPP': {21: 'S_BARRIER_LEAVE', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}} + for fmt, ops in fixes.items(): enums[fmt] = merge_dicts([enums[fmt], ops]) if arch in ('rdna3', 'rdna4'): # RDNA SMEM: PDF says DLC=[14], GLC=[16] but hardware uses DLC=[13], GLC=[14] if 'SMEM' in formats: diff --git a/extra/assembly/amd/test/test_handwritten.py b/extra/assembly/amd/test/test_handwritten.py index 1348764c51..0eaadac97e 100644 --- a/extra/assembly/amd/test/test_handwritten.py +++ b/extra/assembly/amd/test/test_handwritten.py @@ -21,6 +21,9 @@ class TestIntegration(unittest.TestCase): self.assertEqual(repr(self.inst), repr(reasm)) print(desc) + def test_wmma(self): + self.inst = v_wmma_f32_16x16x16_f16(v[0:7], v[189:192], v[140:143], v[0:7]) + def test_load_b128(self): self.inst = s_load_b128(s[4:7], s[0:1], NULL, 0) diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index 7139060a57..db0e273b4e 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -5,7 +5,7 @@ from tinygrad.helpers import fetch from extra.assembly.amd.asm import asm, disasm, detect_format from extra.assembly.amd.test.helpers import get_llvm_mc -LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU" +LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/llvmorg-21.1.0/llvm/test/MC/AMDGPU" RDNA_FILES = ['gfx11_asm_sop1.s', 'gfx11_asm_sop2.s', 'gfx11_asm_sopp.s', 'gfx11_asm_sopk.s', 'gfx11_asm_sopc.s', 'gfx11_asm_vop1.s', 'gfx11_asm_vop2.s', 'gfx11_asm_vopc.s', 'gfx11_asm_vop3.s', 'gfx11_asm_vop3p.s', 'gfx11_asm_vinterp.s', @@ -22,10 +22,12 @@ CDNA_FILES = ['gfx9_asm_sop1.s', 'gfx9_asm_sop2.s', 'gfx9_asm_sopp.s', 'gfx9_asm 'gfx90a_ldst_acc.s', 'gfx90a_asm_features.s', 'flat-scratch-gfx942.s', 'gfx942_asm_features.s', 'mai-gfx90a.s', 'mai-gfx942.s'] # RDNA4 (gfx12) test files - excludes alias/err/fake16/dpp files, and vimage/vsample (not supported) -# NOTE: vflat excluded - PDF has wrong OP field bits [20:13] vs hardware [21:14], and missing seg field +# NOTE: vflat/vdsdir excluded - not implemented; features.s has mixed formats RDNA4_FILES = ['gfx12_asm_sop1.s', 'gfx12_asm_sop2.s', 'gfx12_asm_sopp.s', 'gfx12_asm_sopk.s', 'gfx12_asm_sopc.s', - 'gfx12_asm_vop1.s', 'gfx12_asm_vop2.s', 'gfx12_asm_vopc.s', 'gfx12_asm_vop3.s', 'gfx12_asm_vop3p.s', - 'gfx12_asm_vopd.s', 'gfx12_asm_ds.s', 'gfx12_asm_smem.s', + 'gfx12_asm_vop1.s', 'gfx12_asm_vop2.s', 'gfx12_asm_vopc.s', 'gfx12_asm_vopcx.s', 'gfx12_asm_vop3.s', 'gfx12_asm_vop3c.s', + 'gfx12_asm_vop3cx.s', 'gfx12_asm_vop3p.s', 'gfx12_asm_vop3_from_vop1.s', 'gfx12_asm_vop3_from_vop2.s', + 'gfx12_asm_vop3p_features.s', 'gfx12_asm_vopd.s', 'gfx12_asm_vopd_features.s', + 'gfx12_asm_ds.s', 'gfx12_asm_smem.s', 'gfx12_asm_vbuffer_mubuf.s', 'gfx12_asm_vbuffer_mtbuf.s', 'gfx12_asm_wmma_w32.s', 'gfx12_asm_exp.s'] def _is_mimg(data: bytes) -> bool: return (int.from_bytes(data[:4], 'little') >> 26) & 0x3f == 0b111100 @@ -82,7 +84,7 @@ def _make_test(f: str, arch: str, test_type: str): passed, skipped = 0, 0 for asm_text, expected in tests: try: - self.assertEqual(asm(asm_text).to_bytes(), expected) + self.assertEqual(asm(asm_text, arch).to_bytes(), expected) passed += 1 except: skipped += 1 print(f"{name}: {passed} passed, {skipped} skipped") @@ -91,10 +93,13 @@ def _make_test(f: str, arch: str, test_type: str): for _, data in tests: try: decoded = detect_format(data, arch).from_bytes(data) - if decoded.to_bytes()[:len(data)] == data and (d := disasm(decoded)): to_test.append((data, d)) + # Skip if roundtrip fails, disasm fails, or op_name is missing (disasm starts with space) + if decoded.to_bytes()[:len(data)] == data and (d := disasm(decoded)) and not d.startswith(' '): to_test.append((data, d)) except: pass - print(f"{name}: {len(to_test)} passed, {len(tests) - len(to_test)} skipped") + skipped = len(tests) - len(to_test) + print(f"{name}: {len(to_test)} passed, {skipped} skipped") if arch in ("rdna3", "rdna4"): + self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0") for (data, _), llvm in zip(to_test, _compile_asm_batch([t[1] for t in to_test], arch)): self.assertEqual(llvm, data) return test