diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index 584bb354b3..ed92f2fcfb 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -1,15 +1,12 @@ # RDNA3 assembler and disassembler from __future__ import annotations import re -from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, SRC_FIELDS, unwrap +from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src from extra.assembly.amd.autogen.rdna3 import ins from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP, - VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp) - -# VOP3SD opcodes that share VOP3 encoding -VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} + VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp) def _matches_encoding(word: int, cls: type[Inst]) -> bool: """Check if word matches the encoding pattern of an instruction class.""" @@ -29,7 +26,7 @@ def detect_format(data: bytes) -> type[Inst]: if (word >> 30) == 0b11: for cls in _FORMATS_64: if _matches_encoding(word, cls): - return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in VOP3SD_OPS else cls + return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls raise ValueError(f"unknown 64-bit format word={word:#010x}") # 32-bit formats for cls in _FORMATS_32: @@ -79,8 +76,6 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int: return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10) def _has(op: str, *subs) -> bool: return any(s in op for s in subs) -def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32') -def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64') def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "") def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c) @@ -105,50 +100,43 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str: # DISASSEMBLER # ═══════════════════════════════════════════════════════════════════════════════ -_VOP1_F64 = {VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64, VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64, VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64} - def _disasm_vop1(inst: VOP1) -> str: - op, name = VOP1Op(inst.op), VOP1Op(inst.op).name.lower() - if op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name - if op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" - parts, is_f64_d = name.split('_'), op in _VOP1_F64 or op in (VOP1Op.V_CVT_F64_F32, VOP1Op.V_CVT_F64_I32, VOP1Op.V_CVT_F64_U32) - is_f64_s = op in _VOP1_F64 or op in (VOP1Op.V_CVT_F32_F64, VOP1Op.V_CVT_I32_F64, VOP1Op.V_CVT_U32_F64, VOP1Op.V_FREXP_EXP_I32_F64) + name = inst.op_name.lower() + if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name + if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" + # 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding) + parts = name.split('_') is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name) - is_16s = parts[-1] in ('f16','i16','u16','b16') and 'sat_pk' not in name - dst = _vreg(inst.vdst, 2) if is_f64_d else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" - src = _fmt_src(inst.src0, 2) if is_f64_s else _src16(inst, inst.src0) if is_16s else inst.lit(inst.src0) + dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}" + src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0) return f"{name}_e32 {dst}, {src}" def _disasm_vop2(inst: VOP2) -> str: - op, name = VOP2Op(inst.op), VOP2Op(inst.op).name.lower() - suf, is16 = "" if op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32", _is16(name) and 'pk_' not in name + name = inst.op_name.lower() + suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32" # fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1 - if op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}" - if op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}" - if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" - return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if op == VOP2Op.V_CNDMASK_B32 else "") - -VOPC_CLASS = {VOPCOp.V_CMP_CLASS_F16, VOPCOp.V_CMP_CLASS_F32, VOPCOp.V_CMP_CLASS_F64, - VOPCOp.V_CMPX_CLASS_F16, VOPCOp.V_CMPX_CLASS_F32, VOPCOp.V_CMPX_CLASS_F64} + if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}" + if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}" + if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}" + return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "") def _disasm_vopc(inst: VOPC) -> str: - op, name = VOPCOp(inst.op), VOPCOp(inst.op).name.lower() - is64, is16 = _is64(name), _is16(name) - s0 = _fmt_src(inst.src0, 2) if is64 else _src16(inst, inst.src0) if is16 else inst.lit(inst.src0) - s1 = _vreg(inst.vsrc1, 2) if is64 and op not in VOPC_CLASS else _fmt_v16(inst.vsrc1, 0, 128) if is16 else f"v{inst.vsrc1}" - return f"{name}_e32 {s0}, {s1}" if op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" + name = inst.op_name.lower() + s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0) + s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}" + return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}" NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV, SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE} def _disasm_sopp(inst: SOPP) -> str: - op, name = SOPPOp(inst.op), SOPPOp(inst.op).name.lower() - if op in NO_ARG_SOPP: return name - if op == SOPPOp.S_WAITCNT: + name = inst.op_name.lower() + if inst.op in NO_ARG_SOPP: return name + if inst.op == SOPPOp.S_WAITCNT: vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""] return f"s_waitcnt {' '.join(x for x in p if x) or '0'}" - if op == SOPPOp.S_DELAY_ALU: + if inst.op == SOPPOp.S_DELAY_ALU: deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4'] id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v) @@ -157,25 +145,20 @@ def _disasm_sopp(inst: SOPP) -> str: return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}" def _disasm_smem(inst: SMEM) -> str: - op = SMEMOp(inst.op) - name = op.name.lower() - if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name + name = inst.op_name.lower() + if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset) - sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2 + sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2 sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count) if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}" - width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1) - return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) + return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc")) def _disasm_flat(inst: FLAT) -> str: - name = FLATOp(inst.op).name.lower() + name = inst.op_name.lower() seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192) - suffix = name.split('_')[-1] - w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1) - if 'cmpswap' in name: w *= 2 - if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2) + w = inst.dst_regs() * (2 if 'cmpswap' in name else 1) mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" # saddr if seg == 'flat' or inst.saddr == 0x7F: saddr_s = "" @@ -195,11 +178,11 @@ def _disasm_flat(inst: FLAT) -> str: return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" def _disasm_ds(inst: DS) -> str: - op, name = DSOp(inst.op), DSOp(inst.op).name.lower() + op, name = inst.op, inst.op_name.lower() gds = " gds" if inst.gds else "" off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else "" off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else "" - w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1 + w = inst.dst_regs() d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}" if op == DSOp.DS_NOP: return name @@ -223,49 +206,34 @@ def _disasm_ds(inst: DS) -> str: return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}" def _disasm_vop3(inst: VOP3) -> str: - op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op) - name = op.name.lower() + op, name = inst.op, inst.op_name.lower() # VOP3SD (shared encoding) - if inst.op in VOP3SD_OPS: + if isinstance(op, VOP3SDOp): sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs - is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32') - def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s - s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64) - dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}" - if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}" - if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" - return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod) + def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s + s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) + dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" + srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" + return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod) - # Detect operand sizes - is64 = _is64(name) - is64_src, is64_dst = False, False + # Detect 16-bit operand sizes (for .h/.l suffix handling) is16_d = is16_s = is16_s2 = False if 'cvt_pk' in name: is16_s = name.endswith('16') elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name): is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16') - is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1) - is16_s2, is64 = is16_s, False + is16_s2 = is16_s elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True elif 'pack_b32' in name: is16_s = is16_s2 = True - else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad') - - # Source counts - shift64 = 'rev' in name and '64' in name and name.startswith('v_') - ldexp64 = op == VOP3Op.V_LDEXP_F64 - trig = op == VOP3Op.V_TRIG_PREOP_F64 - sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name - s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1 - s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1 - s2n = 4 if mqsad else 2 if (is64 or sad64) else 1 + else: is16_d = is16_s = is16_s2 = inst.is_16bit() any_hi = inst.opsel != 0 - s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi) - s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi) - s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi) + s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi) + s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi) + s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi) # Destination - dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1 + dn = inst.dst_regs() if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1) elif dn > 1: dst = _vreg(inst.vdst, dn) elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}" @@ -278,24 +246,24 @@ def _disasm_vop3(inst: VOP3) -> str: if inst.op < 256: # VOPC return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}" if inst.op < 384: # VOP2 - os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d) - return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}" + n = inst.num_srcs() + os = _opsel_str(inst.opsel, n, need_opsel, is16_d) + return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}" if inst.op < 512: # VOP1 return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}" # Native VOP3 - is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi', - 'perm_b32', 'permlane', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit') - os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d) - return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}" + n = inst.num_srcs() + os = _opsel_str(inst.opsel, n, need_opsel, is16_d) + return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}" def _disasm_vop3sd(inst: VOP3SD) -> str: - op, name = VOP3SDOp(inst.op), VOP3SDOp(inst.op).name.lower() - is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32') - def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s - s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64) - dst, is2src = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}", op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32) + name = inst.op_name.lower() + def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s + s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2)) + dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}" + srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}" suffix = "_e64" if name.startswith('v_') and 'co_' in name else "" - return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{'' if is2src else f', {s2}'}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}" + return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}" def _disasm_vopd(inst: VOPD) -> str: lit = inst._literal or inst.literal @@ -304,26 +272,25 @@ def _disasm_vopd(inst: VOPD) -> str: return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}" def _disasm_vop3p(inst: VOP3P) -> str: - name = VOP3POp(inst.op).name.lower() - is_wmma, is_3src, is_fma_mix = 'wmma' in name, _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name + name = inst.op_name.lower() + is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name if is_wmma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8 src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8) else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}" - n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2) + opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2) if is_fma_mix: def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s) src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4) mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else []) else: - mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \ + mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \ ([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else []) - return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" + return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}" def _disasm_buf(inst: MUBUF | MTBUF) -> str: - op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op) - name = op.name.lower() - if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name + name = inst.op_name.lower() + if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \ ((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \ {'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1) @@ -351,7 +318,7 @@ def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int: return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked def _disasm_mimg(inst: MIMG) -> str: - name = MIMGOp(inst.op).name.lower() + name = inst.op_name.lower() srsrc_base = inst.srsrc * 4 srsrc_str = _sreg_or_ttmp(srsrc_base, 8) # BVH intersect ray: special case with 4 SGPR srsrc @@ -379,66 +346,38 @@ def _disasm_mimg(inst: MIMG) -> str: ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4) return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}" -def _sop_widths(name: str) -> tuple[int, int, int]: - """Return (dst_width, src0_width, src1_width) in register count for SOP instructions.""" - if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1 - if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1 - if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1 - if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1 - if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz - return 1, 1, 1 - def _disasm_sop1(inst: SOP1) -> str: - op, name = SOP1Op(inst.op), SOP1Op(inst.op).name.lower() + op, name = inst.op, inst.op_name.lower() if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}" if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}" if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}" - if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" - dn, s0n, _ = _sop_widths(name) - return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if s0n == 1 else _fmt_src(inst.ssrc0, s0n)}" + if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})" + return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}" def _disasm_sop2(inst: SOP2) -> str: - name = SOP2Op(inst.op).name.lower() - dn, s0n, s1n = _sop_widths(name) - return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)}" + return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}" def _disasm_sopc(inst: SOPC) -> str: - name = SOPCOp(inst.op).name.lower() - _, s0n, s1n = _sop_widths(name) - return f"{name} {_fmt_src(inst.ssrc0, s0n)}, {_fmt_src(inst.ssrc1, s1n)}" + return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}" def _disasm_sopk(inst: SOPK) -> str: - op, name = SOPKOp(inst.op), SOPKOp(inst.op).name.lower() + op, name = inst.op, inst.op_name.lower() if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}" if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32): hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1 hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})" return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}" - dn, _, _ = _sop_widths(name) - return f"{name} {_fmt_sdst(inst.sdst, dn)}, 0x{inst.simm16:x}" + return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}" def _disasm_vinterp(inst: VINTERP) -> str: - name = VINTERPOp(inst.op).name.lower() - src0 = f"-{inst.lit(inst.src0)}" if inst.neg & 1 else inst.lit(inst.src0) - src1 = f"-{inst.lit(inst.src1)}" if inst.neg & 2 else inst.lit(inst.src1) - src2 = f"-{inst.lit(inst.src2)}" if inst.neg & 4 else inst.lit(inst.src2) mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp")) - return f"{name} v{inst.vdst}, {src0}, {src1}, {src2}" + (" " + mods if mods else "") - -def _disasm_generic(inst: Inst) -> str: - name = f"op_{inst.op}" - def format_field(field_name, val): - val = unwrap(val) - if field_name in SRC_FIELDS: return inst.lit(val) if val != 255 else "0xff" - return f"{'s' if field_name == 'sdst' else 'v'}{val}" if field_name in ('sdst', 'vdst') else f"v{val}" if field_name == 'vsrc1' else f"0x{val:x}" if field_name == 'simm16' else str(val) - operands = [format_field(field_name, inst._values.get(field_name, 0)) for field_name in inst._fields if field_name not in ('encoding', 'op')] - return f"{name} {', '.join(operands)}" if operands else name + return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "") DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p, VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf, MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk} -def disasm(inst: Inst) -> str: return DISASM_HANDLERS.get(type(inst), _disasm_generic)(inst) +def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst) # ═══════════════════════════════════════════════════════════════════════════════ # ASSEMBLER diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index a1bea30722..fdbc7b462c 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -4,6 +4,8 @@ from __future__ import annotations import struct, math from enum import IntEnum from typing import overload, Annotated, TypeVar, Generic +from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op, + SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp) # Common masks and bit conversion functions MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff @@ -28,6 +30,79 @@ def _i64(f): try: return struct.unpack(" 0 else 0xfff0000000000000 +# Instruction spec - register counts and dtypes derived from instruction names +import re +_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16, + 'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2, + 'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1} +def _suffix(name: str) -> tuple[str | None, str | None]: + name = name.upper() + if m := re.search(r'CVT_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2) + if m := re.search(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$', name): return m.group(1), m.group(2) + if m := re.search(r'PACK_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2) + # Generic dst_src pattern: S_BCNT0_I32_B64, S_BITREPLICATE_B64_B32, V_FREXP_EXP_I32_F64, etc. + if m := re.search(r'_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2) + if m := re.search(r'_([FIUB](?:32|64|16|8|96|128|256|512))$', name): return m.group(1), m.group(1) + return None, None +_SPECIAL_REGS = { + 'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1), + 'S_LSHL_B64': (2, 2, 1, 1), 'S_LSHR_B64': (2, 2, 1, 1), 'S_ASHR_I64': (2, 2, 1, 1), + 'S_BFE_U64': (2, 2, 1, 1), 'S_BFE_I64': (2, 2, 1, 1), 'S_BFM_B64': (2, 1, 1, 1), + 'S_BITSET0_B64': (2, 1, 1, 1), 'S_BITSET1_B64': (2, 1, 1, 1), + 'S_BITCMP0_B64': (1, 2, 1, 1), 'S_BITCMP1_B64': (1, 2, 1, 1), + 'V_LDEXP_F64': (2, 2, 1, 1), 'V_TRIG_PREOP_F64': (2, 2, 1, 1), + 'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1), + 'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1), + 'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1), + 'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2), + 'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4), +} +_SPECIAL_DTYPE = { + 'V_LSHLREV_B64': ('B64', 'U32', 'B64', None), 'V_LSHRREV_B64': ('B64', 'U32', 'B64', None), 'V_ASHRREV_I64': ('I64', 'U32', 'I64', None), + 'S_LSHL_B64': ('B64', 'B64', 'U32', None), 'S_LSHR_B64': ('B64', 'B64', 'U32', None), 'S_ASHR_I64': ('I64', 'I64', 'U32', None), + 'S_BFE_U64': ('U64', 'U64', 'U32', None), 'S_BFE_I64': ('I64', 'I64', 'U32', None), + 'S_BFM_B64': ('B64', 'U32', 'U32', None), 'S_BITSET0_B64': ('B64', 'U32', None, None), 'S_BITSET1_B64': ('B64', 'U32', None, None), + 'S_BITCMP0_B64': ('SCC', 'B64', 'U32', None), 'S_BITCMP1_B64': ('SCC', 'B64', 'U32', None), + 'V_LDEXP_F64': ('F64', 'F64', 'I32', None), 'V_TRIG_PREOP_F64': ('F64', 'F64', 'U32', None), + 'V_CMP_CLASS_F64': ('VCC', 'F64', 'U32', None), 'V_CMPX_CLASS_F64': ('EXEC', 'F64', 'U32', None), + 'V_CMP_CLASS_F32': ('VCC', 'F32', 'U32', None), 'V_CMPX_CLASS_F32': ('EXEC', 'F32', 'U32', None), + 'V_CMP_CLASS_F16': ('VCC', 'F16', 'U32', None), 'V_CMPX_CLASS_F16': ('EXEC', 'F16', 'U32', None), + 'V_MAD_U64_U32': ('U64', 'U32', 'U32', 'U64'), 'V_MAD_I64_I32': ('I64', 'I32', 'I32', 'I64'), + 'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), + 'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'), +} +def spec_regs(name: str) -> tuple[int, int, int, int]: + name = name.upper() + if name in _SPECIAL_REGS: return _SPECIAL_REGS[name] + if 'SAD' in name and 'U8' in name and 'QSAD' not in name and 'MQSAD' not in name: return 1, 1, 1, 1 + dst_suf, src_suf = _suffix(name) + return _REGS.get(dst_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1) +def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]: + name = name.upper() + if name in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[name] + if 'SAD' in name and ('U8' in name or 'U16' in name) and 'QSAD' not in name and 'MQSAD' not in name: return 'U32', 'U32', 'U32', 'U32' + if '_CMP_' in name or '_CMPX_' in name: + dst_suf, src_suf = _suffix(name) + return 'EXEC' if '_CMPX_' in name else 'VCC', src_suf, src_suf, None + dst_suf, src_suf = _suffix(name) + return dst_suf, src_suf, src_suf, src_suf +def spec_is_16bit(name: str) -> bool: + name = name.upper() + if 'SAD' in name or 'PACK' in name or '_PK_' in name or 'SAT_PK' in name or 'DOT2' in name: return False + if '_F32' in name or '_I32' in name or '_U32' in name or '_B32' in name: return False # mixed ops like V_DOT2ACC_F32_F16 + return bool(re.search(r'_[FIUB]16(?:_|$)', name)) +def spec_is_64bit(name: str) -> bool: return bool(re.search(r'_[FIUB]64(?:_|$)', name.upper())) +_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI', + 'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN', + 'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'} +_2SRC = {'FMAC'} # FMAC uses dst as implicit accumulator, so only 2 explicit sources +def spec_num_srcs(name: str) -> int: + name = name.upper() + if any(k in name for k in _2SRC): return 2 + return 3 if any(k in name for k in _3SRC) else 2 +def is_dtype_16(dt: str | None) -> bool: return dt is not None and '16' in dt +def is_dtype_64(dt: str | None) -> bool: return dt is not None and '64' in dt + # Bit field DSL class BitField: def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name, self._marker = hi, lo, name, None @@ -54,6 +129,10 @@ class BitField: val = unwrap(obj._values.get(self.name, 0)) # Convert to IntEnum if marker is an IntEnum subclass if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum): + # VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp + if self.marker is VOP3Op: + if val < 256: return VOPCOp(val) + if val in Inst._VOP3SD_OPS: return VOP3SDOp(val) try: return self.marker(val) except ValueError: pass return val @@ -201,10 +280,13 @@ class Inst: # Track literal value if needed if encoded == 255 and self._literal is None: import struct - is_64 = self._is_64bit_op() + # Check if THIS source uses 64-bit encoding (not just src0) + src_idx = {'src0': 0, 'src1': 1, 'src2': 2, 'ssrc0': 0, 'ssrc1': 1}.get(name, 0) + src_regs = self.src_regs(src_idx) + is_64 = src_regs == 2 if isinstance(val, SrcMod) and not isinstance(val, Reg): lit32 = val.val & MASK32 elif isinstance(val, int) and not isinstance(val, IntEnum): lit32 = val & MASK32 - elif isinstance(val, float): lit32 = _i32(val) + elif isinstance(val, float): lit32 = (_i64(val) >> 32) if is_64 else _i32(val) # f64: high 32 bits of f64 repr else: return self._literal = (lit32 << 32) if is_64 else lit32 @@ -235,11 +317,21 @@ class Inst: raise ValueError(f"SOP1 {orig_args['op'].name} expects {expected} register(s) for {fld}, got {orig_args[fld].count}") def __init__(self, *args, literal: int | None = None, **kwargs): - self._values, self._literal = dict(self._defaults), literal + self._values, self._literal = dict(self._defaults), None field_names = [n for n in self._fields if n != 'encoding'] orig_args = dict(zip(field_names, args)) | kwargs self._values.update(orig_args) self._validate(orig_args) + # Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user) + if literal is not None: + # Find which source uses the literal (255) and check its register count + for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]: + v = orig_args.get(n) + if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255): + self._literal = (literal << 32) if self.src_regs(idx) == 2 else literal + break + else: + self._literal = literal # fallback if no literal source found cls_name = self.__class__.__name__ # Format-specific setup @@ -297,11 +389,9 @@ class Inst: op_name = op.name if hasattr(op, 'name') else None # Look up op name from int if needed (happens in from_bytes path) if op_name is None and self.__class__.__name__ == 'VOP3': - from extra.assembly.amd.autogen.rdna3.ins import VOP3Op try: op_name = VOP3Op(op).name except ValueError: pass if op_name is None and self.__class__.__name__ == 'VOPC': - from extra.assembly.amd.autogen.rdna3.ins import VOPCOp try: op_name = VOPCOp(op).name except ValueError: pass if op_name is None: return False @@ -312,8 +402,16 @@ class Inst: result = self.to_int().to_bytes(self._size(), 'little') lit = self._get_literal() or getattr(self, '_literal', None) if lit is None: return result - # For 64-bit ops, literal is stored in high 32 bits internally, but encoded as 4 bytes - lit32 = (lit >> 32) if self._is_64bit_op() else lit + # For 64-bit sources, literal is stored in high 32 bits internally, but encoded as 4 bytes + # Find which source uses the literal (255) and check its register count + lit_src_is_64 = False + for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]: + if n not in self._values: continue + v = self._values[n] + if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255): + lit_src_is_64 = self.is_src_64(idx) + break + lit32 = (lit >> 32) if lit_src_is_64 else lit return result + (lit32 & MASK32).to_bytes(4, 'little') @classmethod @@ -343,9 +441,16 @@ class Inst: if has_literal: # For 64-bit ops, the literal is 32 bits placed in the HIGH 32 bits of the 64-bit value # (low 32 bits are zero). This is how AMD hardware interprets 32-bit literals for 64-bit ops. + # Check which source uses the literal and whether THAT source is 64-bit if len(data) >= cls._size() + 4: lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little') - inst._literal = (lit32 << 32) if inst._is_64bit_op() else lit32 + # Find which source has literal (255) and check its register count + lit_src_is_64 = False + for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]: + if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: + lit_src_is_64 = inst.src_regs(idx) == 2 + break + inst._literal = (lit32 << 32) if lit_src_is_64 else lit32 return inst def __repr__(self): @@ -360,7 +465,9 @@ class Inst: if name.startswith('_'): raise AttributeError(name) return unwrap(self._values.get(name, 0)) - def lit(self, v: int) -> str: return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v) + def lit(self, v: int, neg: bool = False) -> str: + s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v) + return f"-{s}" if neg else s def __eq__(self, other): if not isinstance(other, Inst): return NotImplemented @@ -372,5 +479,37 @@ class Inst: from extra.assembly.amd.asm import disasm return disasm(self) + _enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp, + 'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp, + 'SMEM': SMEMOp, 'DS': DSOp, 'FLAT': FLATOp, 'MUBUF': MUBUFOp, 'MTBUF': MTBUFOp, 'MIMG': MIMGOp, + 'VOPD': VOPDOp, 'VINTERP': VINTERPOp} + _VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770} + + @property + def op(self): + """Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges.""" + val = self._values.get('op') + if val is None: return None + if hasattr(val, 'name'): return val # already an enum + cls_name = self.__class__.__name__ + assert cls_name in self._enum_map, f"no enum map for {cls_name}" + return self._enum_map[cls_name](val) + + @property + def op_name(self) -> str: + op = self.op + return op.name if hasattr(op, 'name') else '' + + def dst_regs(self) -> int: return spec_regs(self.op_name)[0] + def src_regs(self, n: int) -> int: return spec_regs(self.op_name)[n + 1] + def num_srcs(self) -> int: return spec_num_srcs(self.op_name) + def dst_dtype(self) -> str | None: return spec_dtype(self.op_name)[0] + def src_dtype(self, n: int) -> str | None: return spec_dtype(self.op_name)[n + 1] + def is_src_16(self, n: int) -> bool: return self.src_regs(n) == 1 and is_dtype_16(self.src_dtype(n)) + def is_src_64(self, n: int) -> bool: return self.src_regs(n) == 2 + def is_16bit(self) -> bool: return spec_is_16bit(self.op_name) + def is_64bit(self) -> bool: return spec_is_64bit(self.op_name) + def is_dst_16(self) -> bool: return self.dst_regs() == 1 and is_dtype_16(self.dst_dtype()) + class Inst32(Inst): pass class Inst64(Inst): pass diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index b32940dbe4..b24e904264 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -12,34 +12,6 @@ Program = dict[int, Inst] WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256 VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC -# Op classification helpers - build sets from op name patterns -def _ops_matching(enum, *patterns, exclude=()): return {op for op in enum if any(p in op.name for p in patterns) and not any(e in op.name for e in exclude)} -def _ops_ending(enum, *suffixes): return {op for op in enum if op.name.endswith(suffixes)} - -# 64-bit ops (for literal handling) -_VOP3_64BIT_OPS = {op.value for op in _ops_ending(VOP3Op, '_F64', '_B64', '_I64', '_U64')} -_VOPC_64BIT_OPS = {op.value for op in _ops_ending(VOPCOp, '_F64', '_B64', '_I64', '_U64')} -_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value} # src1 is 32-bit exponent - -# 16-bit ops (SAD/MSAD excluded - they use 32-bit packed sources) -_VOP3_16BIT_OPS = _ops_matching(VOP3Op, '_F16', '_B16', '_I16', '_U16', exclude=('SAD',)) -_VOP1_16BIT_OPS = _ops_matching(VOP1Op, '_F16', '_B16', '_I16', '_U16') -_VOP2_16BIT_OPS = _ops_matching(VOP2Op, '_F16', '_B16', '_I16', '_U16') -_VOPC_16BIT_OPS = _ops_matching(VOPCOp, '_F16', '_B16', '_I16', '_U16') - -# CVT ops with 32/64-bit source (despite 16-bit in name) - must end with the type suffix -_CVT_32_64_SRC_OPS = {op for op in _ops_ending(VOP3Op, '_F32', '_I32', '_U32', '_F64', '_I64', '_U64') if op.name.startswith('V_CVT_')} | \ - {op for op in _ops_ending(VOP1Op, '_F32', '_I32', '_U32', '_F64', '_I64', '_U64') if op.name.startswith('V_CVT_')} -# CVT ops with 32-bit destination (FROM 16-bit TO 32-bit) - match patterns like F32_F16 in name -_CVT_32_DST_OPS = _ops_matching(VOP3Op, 'F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16') | \ - _ops_matching(VOP1Op, 'F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16') - -# 16-bit dst ops (PACK has 32-bit dst, CVT to 32-bit has 32-bit dst) -_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS -_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS -_VOP1_16BIT_SRC_OPS = _VOP1_16BIT_OPS - _CVT_32_64_SRC_OPS - -# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats. # Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats. _FLOAT_CONSTS = {v: k for k, v in FLOAT_ENC.items()} | {248: 0.15915494309189535} # INV_2PI def _build_inline_consts(mask, to_bits): @@ -134,7 +106,7 @@ class WaveState: def rsrc_f16(self, v: int, lane: int) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS_F16) def rsrc64(self, v: int, lane: int) -> int: if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128] - if v == 255: return self.literal + if v == 255: return self.literal # literal is already shifted in from_bytes for 64-bit ops return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32) def pend_sgpr_lane(self, reg: int, lane: int, val: int): @@ -155,20 +127,8 @@ def decode_program(data: bytes) -> Program: base_size = inst_class._size() # Pass enough data for potential 64-bit literal (base + 8 bytes max) inst = inst_class.from_bytes(data[i:i+base_size+8]) - for name, val in inst._values.items(): setattr(inst, name, unwrap(val)) - # from_bytes already handles literal reading - only need fallback for cases it doesn't handle - if inst._literal is None: - has_literal = any(getattr(inst, fld, None) == 255 for fld in ('src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'srcx0', 'srcy0')) or \ - (inst_class == VOP2 and inst.op in (44, 45, 55, 56)) or \ - (inst_class == VOPD and (inst.opx in (1, 2) or inst.opy in (1, 2))) or \ - (inst_class == SOP2 and inst.op in (69, 70)) - if has_literal: - # For 64-bit ops, the 32-bit literal is placed in HIGH 32 bits (low 32 bits = 0) - op_val = getattr(inst._values.get('op'), 'value', inst._values.get('op')) - is_64bit = ((inst_class is VOP3 and op_val in _VOP3_64BIT_OPS) or (inst_class is VOPC and op_val in _VOPC_64BIT_OPS)) and \ - not (op_val in _VOP3_64BIT_OPS_32BIT_SRC1 and getattr(inst, 'src1', None) == 255) - lit32 = int.from_bytes(data[i+base_size:i+base_size+4], 'little') - inst._literal = (lit32 << 32) if is_64bit else lit32 + for name, val in inst._values.items(): + if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access inst._words = inst.size() // 4 result[i // 4] = inst i += inst._words * 4 @@ -181,16 +141,14 @@ def decode_program(data: bytes) -> Program: def exec_scalar(st: WaveState, inst: Inst) -> int: """Execute scalar instruction. Returns PC delta or negative for special cases.""" compiled = _get_compiled() - inst_type = type(inst) # SOPP: special cases for control flow that has no pseudocode - if inst_type is SOPP: - op = inst.op - if op == SOPPOp.S_ENDPGM: return -1 - if op == SOPPOp.S_BARRIER: return -2 + if isinstance(inst, SOPP): + if inst.op == SOPPOp.S_ENDPGM: return -1 + if inst.op == SOPPOp.S_BARRIER: return -2 # SMEM: memory loads (not ALU) - if inst_type is SMEM: + if isinstance(inst, SMEM): addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21) if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0) if (cnt := SMEM_LOAD.get(inst.op)) is None: raise NotImplementedError(f"SMEM op {inst.op}") @@ -198,34 +156,30 @@ def exec_scalar(st: WaveState, inst: Inst) -> int: return 0 # Get op enum and lookup compiled function - if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst - elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst - elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None - elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst # sdst is both src and dst - elif inst_type is SOPP: op_cls, ssrc0, sdst = SOPPOp, None, None - else: raise NotImplementedError(f"Unknown scalar type {inst_type}") + if isinstance(inst, SOP1): ssrc0, sdst = inst.ssrc0, inst.sdst + elif isinstance(inst, SOP2): ssrc0, sdst = inst.ssrc0, inst.sdst + elif isinstance(inst, SOPC): ssrc0, sdst = inst.ssrc0, None + elif isinstance(inst, SOPK): ssrc0, sdst = inst.sdst, inst.sdst # sdst is both src and dst + elif isinstance(inst, SOPP): ssrc0, sdst = None, None + else: raise NotImplementedError(f"Unknown scalar type {type(inst)}") # SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops - try: op = op_cls(inst.op) + try: op = inst.op except ValueError: - if inst_type is SOPP: return 0 + if isinstance(inst, SOPP): return 0 raise - fn = compiled.get(op_cls, {}).get(op) + fn = compiled.get(type(op), {}).get(op) if fn is None: # SOPP instructions without pseudocode (waits, hints, nops) are no-ops - if inst_type is SOPP: return 0 + if isinstance(inst, SOPP): return 0 raise NotImplementedError(f"{op.name} not in pseudocode") - # Build context - handle 64-bit ops that need 64-bit source reads - # 64-bit source ops: name ends with _B64, _I64, _U64 or contains _U64, _I64 before last underscore - is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name - is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64) - s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type not in (SOPK, SOPP) else (st.rsgpr(inst.sdst) if inst_type is SOPK else 0)) - is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2 - s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0) - d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0) + # Build context - use inst methods to determine operand sizes + s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0)) + s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0) + d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0) exec_mask = st.exec_mask - literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal + literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal # Execute compiled function - pass PC in bytes for instructions that need it # For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant @@ -248,10 +202,10 @@ def exec_scalar(st: WaveState, inst: Inst) -> int: def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None: """Execute vector instruction for one lane.""" compiled = _get_compiled() - inst_type, V = type(inst), st.vgpr[lane] + V = st.vgpr[lane] # Memory ops (not ALU pseudocode) - if inst_type is FLAT: + if isinstance(inst, FLAT): op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr addr = V[addr_reg] | (V[addr_reg+1] << 32) addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64 @@ -272,7 +226,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No else: raise NotImplementedError(f"FLAT op {op}") return - if inst_type is DS: + if isinstance(inst, DS): op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst if op in DS_LOAD: cnt, sz, sign = DS_LOAD[op] @@ -302,7 +256,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No return # VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes) - if inst_type is VOPD: + if isinstance(inst, VOPD): vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1) inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx), (inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)] @@ -312,18 +266,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No return # VOP3SD: has extra scalar dest for carry output - if inst_type is VOP3SD: - op = VOP3SDOp(inst.op) - fn = compiled.get(VOP3SDOp, {}).get(op) - if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode") - # Source sizes vary: DIV_SCALE=all 64-bit, MAD64=32/32/64, others=32-bit - r64 = op == VOP3SDOp.V_DIV_SCALE_F64 - s0, s1 = (st.rsrc64 if r64 else st.rsrc)(inst.src0, lane), (st.rsrc64 if r64 else st.rsrc)(inst.src1, lane) - mad64 = op in (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32) - s2 = st.rsrc64(inst.src2, lane) if r64 else ((V[inst.src2-256]|(V[inst.src2-255]<<32)) if inst.src2>=256 else st.rsgpr64(inst.src2)) if mad64 else st.rsrc(inst.src2, lane) + if isinstance(inst, VOP3SD): + fn = compiled.get(VOP3SDOp, {}).get(inst.op) + if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode") + # Read sources based on register counts from inst properties + def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane) + s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2)) # Carry-in ops use src2 as carry bitmask instead of VCC - carry_ops = (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32) - result = fn(s0, s1, s2, V[inst.vdst], st.scc, st.rsgpr64(inst.src2) if op in carry_ops else st.vcc, lane, st.exec_mask, st.literal, None, {}) + vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc + result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {}) V[inst.vdst] = result['d0'] & MASK32 if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32 if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane']) @@ -332,35 +283,31 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No # Get op enum and sources (None means "no source" for that operand) # dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination dst_hi = False - if inst_type is VOP1: + if isinstance(inst, VOP1): if inst.op == VOP1Op.V_NOP: return - op_cls, op, src0, src1, src2 = VOP1Op, VOP1Op(inst.op), inst.src0, None, None - dst_hi, vdst = (inst.vdst & 0x80) != 0 and op in _VOP1_16BIT_DST_OPS, inst.vdst & 0x7f if op in _VOP1_16BIT_DST_OPS else inst.vdst - elif inst_type is VOP2: - op_cls, op, src0, src1, src2 = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None - dst_hi, vdst = (inst.vdst & 0x80) != 0 and op in _VOP2_16BIT_OPS, inst.vdst & 0x7f if op in _VOP2_16BIT_OPS else inst.vdst - elif inst_type is VOP3: - # VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode) - if inst.op < 256: - op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst - else: - op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst - elif inst_type is VOPC: - op = VOPCOp(inst.op) + src0, src1, src2 = inst.src0, None, None + dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16() + vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst + elif isinstance(inst, VOP2): + src0, src1, src2 = inst.src0, inst.vsrc1 + 256, None + dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16() + vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst + elif isinstance(inst, VOP3): + # VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 - inst.op returns VOPCOp for these + src0, src1, src2, vdst = inst.src0, inst.src1, (None if inst.op.value < 256 else inst.src2), inst.vdst + elif isinstance(inst, VOPC): # For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half # vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag - src1 = inst.vsrc1 + 256 # convert to standard VGPR encoding (256 + vgpr_idx) - op_cls, src0, src2, vdst = VOPCOp, inst.src0, None, VCC_LO - elif inst_type is VOP3P: + src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, None, VCC_LO + elif isinstance(inst, VOP3P): # VOP3P: Packed 16-bit operations using compiled functions - op = VOP3POp(inst.op) # WMMA: wave-level matrix multiply-accumulate (special handling - needs cross-lane access) - if op in (VOP3POp.V_WMMA_F32_16X16X16_F16, VOP3POp.V_WMMA_F32_16X16X16_BF16, VOP3POp.V_WMMA_F16_16X16X16_F16): + if 'WMMA' in inst.op_name: if lane == 0: # Only execute once per wave, write results for all lanes - exec_wmma(st, inst, op) + exec_wmma(st, inst, inst.op) return # V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half - if op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16): + if 'FMA_MIX' in inst.op_name: opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0) neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0] @@ -371,7 +318,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No if neg & (1< int: to_f, to_i = (_f64, _i64) if is64 else (_f32, _i32) if (abs_ >> idx) & 1: val = to_i(abs(to_f(val))) if (neg >> idx) & 1: val = to_i(-to_f(val)) return val - # Determine if sources are 64-bit based on instruction type - # For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift) - # For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit - is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64')) - # V_LDEXP_F64, V_TRIG_PREOP_F64, V_CMP_CLASS_F64, V_CMPX_CLASS_F64: src0 is 64-bit, src1 is 32-bit - is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64, VOP3Op.V_TRIG_PREOP_F64, VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64, - VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64) - is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64) - # 16-bit source ops: use precomputed sets instead of string checks - # Note: must check op_cls to avoid cross-enum value collisions - # VOP3-encoded VOPC 16-bit ops also use opsel (not VGPR bit 7 like non-VOP3 VOPC) - is_16bit_src = (op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS) or \ - (inst_type is VOP3 and op_cls is VOPCOp and op in _VOPC_16BIT_OPS) - # VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants) - is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS + # Use inst methods to determine operand sizes (inst.is_src_16, inst.is_src_64, etc.) + is_vop2_16bit = isinstance(inst, VOP2) and inst.is_16bit() - if is_shift_64: - s0, s1 = mod_src(st.rsrc(src0, lane), 0), st.rsrc64(src1, lane) if src1 else 0 - s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 - elif is_ldexp_64: - s0 = mod_src(st.rsrc64(src0, lane), 0, is64=True) - s1_raw = st.rsrc(src1, lane) if src1 is not None else 0 - is_class_op = op in (VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64, VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64) - s1, s2 = mod_src((s1_raw >> 32) if src1 == 255 and is_class_op else s1_raw, 1), mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 - elif is_64bit_op: - s0, s1 = mod_src(st.rsrc64(src0, lane), 0, is64=True), mod_src(st.rsrc64(src1, lane), 1, is64=True) if src1 is not None else 0 - s2 = mod_src(st.rsrc64(src2, lane), 2, is64=True) if src2 is not None else 0 - elif is_16bit_src: - # VOP3 16-bit ops: opsel bits select which half, abs/neg as f16 bit ops - def rsrc_16bit(src, idx): - if src is None: return 0 + # Read sources based on register counts and dtypes from inst properties + def read_src(src, idx, regs, is_src_16): + if src is None: return 0 + if regs == 2: return mod_src(st.rsrc64(src, lane), idx, is64=True) + if is_src_16 and isinstance(inst, VOP3): raw = st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane) val = _src16(raw, bool(opsel & (1 << idx))) if abs_ & (1 << idx): val &= 0x7fff if neg & (1 << idx): val ^= 0x8000 return val - s0, s1, s2 = rsrc_16bit(src0, 0), rsrc_16bit(src1, 1), rsrc_16bit(src2, 2) - elif is_vop2_16bit or (op_cls is VOP1Op and op in _VOP1_16BIT_SRC_OPS) or (op_cls is VOPCOp and op in _VOPC_16BIT_OPS): - # VOP1/VOP2/VOPC 16-bit ops: VGPRs use bit 7 for hi/lo, non-VGPRs use f16 inline consts - # Special case: VOPC V_CMP_CLASS uses full 32-bit mask for src1 when non-VGPR - def rsrc16_vgpr(src, idx, full32=False): - if src is None: return 0 + if is_src_16 and isinstance(inst, (VOP1, VOP2, VOPC)): if src >= 256: return _src16(mod_src(st.rsrc(_vgpr_masked(src), lane), idx), _vgpr_hi(src)) - return mod_src(st.rsrc(src, lane) if full32 else st.rsrc_f16(src, lane), idx) & (0xffffffff if full32 else 0xffff) - s0, s1 = rsrc16_vgpr(src0, 0), rsrc16_vgpr(src1, 1, full32=op_cls is VOPCOp) - s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 - else: - s0 = mod_src(st.rsrc(src0, lane), 0) - s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 - s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 + return mod_src(st.rsrc_f16(src, lane), idx) & 0xffff + return mod_src(st.rsrc(src, lane), idx) + + s0 = read_src(src0, 0, inst.src_regs(0), inst.is_src_16(0)) + s1 = read_src(src1, 1, inst.src_regs(1), inst.is_src_16(1)) if src1 is not None else 0 + s2 = read_src(src2, 2, inst.src_regs(2), inst.is_src_16(2)) if src2 is not None else 0 # Read destination (accumulator for VOP2 f16, 64-bit for 64-bit ops) - d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if is_64bit_op else V[vdst] + d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if inst.dst_regs() == 2 else V[vdst] # V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly # Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly - vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc + vcc_for_fn = st.rsgpr64(src2) if inst.op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and isinstance(inst, VOP3) and src2 is not None and src2 < 256 else st.vcc # Execute compiled function - pass src0_idx and vdst_idx for lane instructions # For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR) @@ -467,17 +386,16 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No st.vgpr[wr_lane][wr_idx] = wr_val if 'vcc_lane' in result: # VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst - st.pend_sgpr_lane(VCC_LO if op_cls is VOP2Op and 'CO_CI' in op.name else vdst, lane, result['vcc_lane']) + st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane']) if 'exec_lane' in result: # V_CMPX instructions write to EXEC per-lane st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane']) - if 'd0' in result and op_cls not in (VOPCOp,) and 'vgpr_write' not in result: - writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or (op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32)) - is_16bit_dst = (op_cls is VOP3Op and op in _VOP3_16BIT_DST_OPS) or (op_cls is VOP1Op and op in _VOP1_16BIT_DST_OPS) or is_vop2_16bit + if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result: + writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name d0_val = result['d0'] if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32) elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32 - elif is_16bit_dst: V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if inst_type is VOP3 else dst_hi) + elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi) else: V[vdst] = d0_val & MASK32 # ═══════════════════════════════════════════════════════════════════════════════ @@ -506,22 +424,19 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None: # MAIN EXECUTION LOOP # ═══════════════════════════════════════════════════════════════════════════════ -SCALAR_TYPES = {SOP1, SOP2, SOPC, SOPK, SOPP, SMEM} -VECTOR_TYPES = {VOP1, VOP2, VOP3, VOP3SD, VOPC, FLAT, DS, VOPD, VOP3P} - def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int: inst = program.get(st.pc) if inst is None: return 1 - inst_words, st.literal, inst_type = inst._words, getattr(inst, '_literal', None) or 0, type(inst) + inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0 - if inst_type in SCALAR_TYPES: + if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): delta = exec_scalar(st, inst) if delta == -1: return -1 # endpgm if delta == -2: st.pc += inst_words; return -2 # barrier st.pc += inst_words + delta else: # V_READFIRSTLANE/V_READLANE write to SGPR, execute once; others execute per-lane with exec_mask - is_readlane = inst_type in (VOP1, VOP3) and hasattr(inst.op, 'name') and 'READLANE' in inst.op.name + is_readlane = isinstance(inst, (VOP1, VOP3)) and ('READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name) exec_mask = 1 if is_readlane else st.exec_mask for lane in range(1 if is_readlane else n_lanes): if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds) diff --git a/extra/assembly/amd/test/test_emu.py b/extra/assembly/amd/test/test_emu.py index 9605d15b79..e099aca931 100644 --- a/extra/assembly/amd/test/test_emu.py +++ b/extra/assembly/amd/test/test_emu.py @@ -3982,6 +3982,179 @@ class TestVOP3VOPC16Bit(unittest.TestCase): self.assertEqual(st.sgpr[0] & 1, 1, "hi>hi should be true: 0x9999>0x1234") +class Test64BitLiteralSources(unittest.TestCase): + """Regression tests for 64-bit instruction literal source handling. + + For f64 operations, a 32-bit literal in the instruction stream represents the + HIGH 32 bits of the 64-bit value (low 32 bits are implicitly 0). + + Bug: rsrc64() was returning the 32-bit literal as-is instead of shifting it + left by 32 bits. This caused V_FMA_F64 and V_LDEXP_F64 to use wrong values + when their source is a literal, breaking the f64->i64 conversion sequence. + + The f64->i64 conversion sequence is: + v_trunc_f64 -> v_ldexp_f64 (by -32) -> v_floor_f64 -> v_fma_f64 (by -2^32) + -> v_cvt_u32_f64 (low bits) -> v_cvt_i32_f64 (high bits) + + The V_FMA_F64 uses literal 0xC1F00000 which is the high 32 bits of f64 -2^32. + """ + + def test_v_fma_f64_literal_neg_2pow32(self): + """V_FMA_F64 with literal encoding of -2^32. + + The f64 value -2^32 (-4294967296.0) has bits 0xC1F0000000000000. + The compiler encodes only the high 32 bits (0xC1F00000) as a literal. + The emulator must interpret this as 0xC1F00000_00000000. + """ + # v[0:1] = -41.0 (trunc), v[2:3] = -1.0 (floor of -41/2^32) + # FMA: result = (-2^32) * (-1.0) + (-41.0) = 4294967296 - 41 = 4294967255.0 + val_41 = f2i64(-41.0) + val_m1 = f2i64(-1.0) + # Literal 0xC1F00000 is high 32 bits of f64 -2^32 + lit = 0xC1F00000 + instructions = [ + s_mov_b32(s[0], val_41 & 0xffffffff), + s_mov_b32(s[1], (val_41 >> 32) & 0xffffffff), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + s_mov_b32(s[2], val_m1 & 0xffffffff), + s_mov_b32(s[3], (val_m1 >> 32) & 0xffffffff), + v_mov_b32_e32(v[2], s[2]), + v_mov_b32_e32(v[3], s[3]), + # V_FMA_F64 v[4:5], literal, v[2:3], v[0:1] + # = (-2^32) * (-1.0) + (-41.0) = 4294967255.0 + VOP3(VOP3Op.V_FMA_F64, vdst=v[4], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit), + ] + st = run_program(instructions, n_lanes=1) + result = i642f(st.vgpr[0][4] | (st.vgpr[0][5] << 32)) + expected = 4294967255.0 # 2^32 - 41 + self.assertAlmostEqual(result, expected, places=0, msg=f"Expected {expected}, got {result}") + + def test_v_ldexp_f64_literal_neg32(self): + """V_LDEXP_F64 with literal -32 for exponent. + + V_LDEXP_F64 computes src0 * 2^src1 where src1 is an integer exponent. + The literal 0xFFFFFFE0 represents -32 as a 32-bit signed integer. + For V_LDEXP_F64, src1 is 32-bit (not 64-bit), so this is correct as-is. + """ + val = f2i64(-41.0) + expected = -41.0 * (2.0 ** -32) # -9.5367431640625e-09 + instructions = [ + s_mov_b32(s[0], val & 0xffffffff), + s_mov_b32(s[1], (val >> 32) & 0xffffffff), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + # V_LDEXP_F64 v[2:3], v[0:1], -32 + v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0), + ] + st = run_program(instructions, n_lanes=1) + result = i642f(st.vgpr[0][2] | (st.vgpr[0][3] << 32)) + self.assertAlmostEqual(result, expected, places=15, msg=f"Expected {expected}, got {result}") + + def test_f64_to_i64_full_sequence(self): + """Full f64->i64 conversion sequence with negative value. + + This is the exact sequence generated by the compiler for (long)(-41.0): + v_trunc_f64 v[0:1], v[0:1] + v_ldexp_f64 v[2:3], v[0:1], -32 + v_floor_f64 v[2:3], v[2:3] + v_fma_f64 v[0:1], 0xc1f00000, v[2:3], v[0:1] # -2^32 + v_cvt_u32_f64 v0, v[0:1] + v_cvt_i32_f64 v1, v[2:3] + + Result: v1:v0 = 0xFFFFFFFF:0xFFFFFFD7 = -41 as i64 + """ + val = f2i64(-41.0) + lit = 0xC1F00000 # high 32 bits of f64 -2^32 + instructions = [ + s_mov_b32(s[0], val & 0xffffffff), + s_mov_b32(s[1], (val >> 32) & 0xffffffff), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_trunc_f64_e32(v[0:2], v[0:2]), + v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0), # -32 + v_floor_f64_e32(v[2:4], v[2:4]), + VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit), + v_cvt_u32_f64_e32(v[4], v[0:2]), + v_cvt_i32_f64_e32(v[5], v[2:4]), + ] + st = run_program(instructions, n_lanes=1) + lo = st.vgpr[0][4] + hi = st.vgpr[0][5] + result = struct.unpack('i64 conversion with larger negative value (-1000000). + + Tests that the conversion sequence works for values that span both + high and low 32-bit parts of the result. + """ + val = f2i64(-1000000.0) + lit = 0xC1F00000 + instructions = [ + s_mov_b32(s[0], val & 0xffffffff), + s_mov_b32(s[1], (val >> 32) & 0xffffffff), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_trunc_f64_e32(v[0:2], v[0:2]), + v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0), + v_floor_f64_e32(v[2:4], v[2:4]), + VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit), + v_cvt_u32_f64_e32(v[4], v[0:2]), + v_cvt_i32_f64_e32(v[5], v[2:4]), + ] + st = run_program(instructions, n_lanes=1) + lo = st.vgpr[0][4] + hi = st.vgpr[0][5] + result = struct.unpack('i64 conversion with positive value (1000000).""" + val = f2i64(1000000.0) + lit = 0xC1F00000 + instructions = [ + s_mov_b32(s[0], val & 0xffffffff), + s_mov_b32(s[1], (val >> 32) & 0xffffffff), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_trunc_f64_e32(v[0:2], v[0:2]), + v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0), + v_floor_f64_e32(v[2:4], v[2:4]), + VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit), + v_cvt_u32_f64_e32(v[4], v[0:2]), + v_cvt_i32_f64_e32(v[5], v[2:4]), + ] + st = run_program(instructions, n_lanes=1) + lo = st.vgpr[0][4] + hi = st.vgpr[0][5] + result = struct.unpack('i64 conversion with value > 2^32 (requires 64-bit result).""" + val = f2i64(5000000000.0) # 5 billion, > 2^32 + lit = 0xC1F00000 + instructions = [ + s_mov_b32(s[0], val & 0xffffffff), + s_mov_b32(s[1], (val >> 32) & 0xffffffff), + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_trunc_f64_e32(v[0:2], v[0:2]), + v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0), + v_floor_f64_e32(v[2:4], v[2:4]), + VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit), + v_cvt_u32_f64_e32(v[4], v[0:2]), + v_cvt_i32_f64_e32(v[5], v[2:4]), + ] + st = run_program(instructions, n_lanes=1) + lo = st.vgpr[0][4] + hi = st.vgpr[0][5] + result = struct.unpack('