diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b77fd7068e..657a85a077 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -677,6 +677,8 @@ jobs: run: cloc --by-file extra/assembly/amd/*.py - name: Run RDNA3 emulator tests run: python -m pytest -n=auto extra/assembly/amd/ --durations 20 + - name: Run RDNA3 emulator tests (AMD_LLVM=1) + run: AMD_LLVM=1 python -m pytest -n=auto extra/assembly/amd/ --durations 20 - name: Install pdfplumber run: pip install pdfplumber - name: Verify AMD autogen is up to date diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index a8a5f55f54..3496795dc9 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -219,9 +219,12 @@ def disasm(inst: Inst) -> str: src2_str = fmt_sd_src(src2, neg & 4, is_mad64) dst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}" sdst_str = _fmt_sdst(sdst, 1) - # v_add_co_u32, v_sub_co_u32, v_subrev_co_u32, v_add_co_ci_u32, etc. only use 2 sources - if op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'): + # v_add_co_u32, v_sub_co_u32, v_subrev_co_u32 only use 2 sources + if op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'): return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}" + # v_add_co_ci_u32, v_sub_co_ci_u32, v_subrev_co_ci_u32 use 3 sources (src2 is carry-in) + if op_name in ('v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'): + return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" # v_div_scale uses 3 sources return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" + omod_str @@ -351,12 +354,17 @@ def disasm(inst: Inst) -> str: from extra.assembly.amd.autogen import rdna3 as autogen opx, opy, vdstx, vdsty_enc = [unwrap(inst._values.get(f, 0)) for f in ('opx', 'opy', 'vdstx', 'vdsty')] srcx0, vsrcx1, srcy0, vsrcy1 = [unwrap(inst._values.get(f, 0)) for f in ('srcx0', 'vsrcx1', 'srcy0', 'vsrcy1')] + literal = inst._literal if hasattr(inst, '_literal') and inst._literal else unwrap(inst._values.get('literal', None)) vdsty = (vdsty_enc << 1) | ((vdstx & 1) ^ 1) # Decode vdsty - def fmt_vopd(op, vdst, src0, vsrc1): + def fmt_vopd(op, vdst, src0, vsrc1, include_lit): try: name = autogen.VOPDOp(op).name.lower() except (ValueError, KeyError): name = f"op_{op}" - return f"{name} v{vdst}, {fmt_src(src0)}" if 'mov' in name else f"{name} v{vdst}, {fmt_src(src0)}, v{vsrc1}" - return f"{fmt_vopd(opx, vdstx, srcx0, vsrcx1)} :: {fmt_vopd(opy, vdsty, srcy0, vsrcy1)}" + lit_str = f", 0x{literal:x}" if include_lit and literal is not None and ('fmaak' in name or 'fmamk' in name) else "" + return f"{name} v{vdst}, {fmt_src(src0)}{lit_str}" if 'mov' in name else f"{name} v{vdst}, {fmt_src(src0)}, v{vsrc1}{lit_str}" + # fmaak/fmamk: both X and Y can use the shared literal + x_needs_lit = 'fmaak' in autogen.VOPDOp(opx).name.lower() or 'fmamk' in autogen.VOPDOp(opx).name.lower() + y_needs_lit = 'fmaak' in autogen.VOPDOp(opy).name.lower() or 'fmamk' in autogen.VOPDOp(opy).name.lower() + return f"{fmt_vopd(opx, vdstx, srcx0, vsrcx1, x_needs_lit)} :: {fmt_vopd(opy, vdsty, srcy0, vsrcy1, y_needs_lit)}" # VOP3P: packed vector ops if cls_name == 'VOP3P': @@ -721,6 +729,9 @@ def get_dsl(text: str) -> str: if mnemonic.replace('_e32', '') in vcc_ops and len(dsl_args) >= 5: mnemonic = mnemonic.replace('_e32', '') + '_e32' # Ensure _e32 suffix for VOP2 encoding dsl_args = [dsl_args[0], dsl_args[2], dsl_args[3]] + # Handle 
v_add_co_ci_u32_e64 etc - strip _e64 suffix (function name doesn't have it, returns VOP3SD) + if mnemonic.replace('_e64', '') in vcc_ops and mnemonic.endswith('_e64'): + mnemonic = mnemonic.replace('_e64', '') # v_cmp_*_e32: strip implicit vcc_lo dest if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(dsl_args) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): dsl_args = dsl_args[1:] diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index ae62c2fcee..615597e81b 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -315,6 +315,9 @@ class Inst: op_val = inst._values.get('op', 0) has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56) has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70)) + # VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2) + opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0) + has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2))) for n in SRC_FIELDS: if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255: has_literal = True if has_literal: diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index cce704c55f..dbbd33b820 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -24,12 +24,18 @@ _VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value} _VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16')) and 'SAD' not in op.name} _VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} _VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} +_VOPC_16BIT_OPS = {op for op in VOPCOp if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} # CVT ops with 32/64-bit source (despite 16-bit in name) _CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \ {op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} -# 16-bit dst ops (PACK has 32-bit dst despite F16 in name) -_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} -_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} +# CVT ops with 32-bit destination (convert FROM 16-bit TO 32-bit): V_CVT_F32_F16, V_CVT_I32_I16, V_CVT_U32_U16 +_CVT_32_DST_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and any(s in op.name for s in ('F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16'))} | \ + {op for op in VOP1Op if op.name.startswith('V_CVT_') and any(s in op.name for s in ('F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16'))} +# 16-bit dst ops (PACK has 32-bit dst despite F16 in name, CVT to 32-bit has 32-bit dst) +_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS +_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS +# VOP1 16-bit source ops (excluding CVT ops with 32/64-bit source) - for VOP1 e32, .h encoded in register index +_VOP1_16BIT_SRC_OPS = _VOP1_16BIT_OPS - _CVT_32_64_SRC_OPS # Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats. 
import struct as _struct @@ -371,11 +377,25 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No # Get op enum and sources (None means "no source" for that operand) + # vop1_dst_hi/vop2_dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination + vop1_dst_hi, vop2_dst_hi = False, False if inst_type is VOP1: if inst.op == VOP1Op.V_NOP: return - op_cls, op, src0, src1, src2, vdst = VOP1Op, VOP1Op(inst.op), inst.src0, None, None, inst.vdst + op_cls, op, src0, src1, src2 = VOP1Op, VOP1Op(inst.op), inst.src0, None, None + # For 16-bit dst ops, vdst encodes .h in bit 7 + if op in _VOP1_16BIT_DST_OPS: + vop1_dst_hi = (inst.vdst & 0x80) != 0 + vdst = inst.vdst & 0x7f + else: + vdst = inst.vdst elif inst_type is VOP2: - op_cls, op, src0, src1, src2, vdst = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None, inst.vdst + op_cls, op, src0, src1, src2 = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None + # For 16-bit dst ops, vdst encodes .h in bit 7 + if op in _VOP2_16BIT_OPS: + vop2_dst_hi = (inst.vdst & 0x80) != 0 + vdst = inst.vdst & 0x7f + else: + vdst = inst.vdst elif inst_type is VOP3: # VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode) if inst.op < 256: @@ -397,7 +417,11 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No V[vdst] = result & 0xffffffff return elif inst_type is VOPC: - op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.vsrc1 + 256, None, VCC_LO + op = VOPCOp(inst.op) + # For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half + # vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag + src1 = inst.vsrc1 + 256 # convert to standard VGPR encoding (256 + vgpr_idx) + op_cls, src0, src2, vdst = VOPCOp, inst.src0, None, VCC_LO elif inst_type is VOP3P: # VOP3P: Packed 16-bit operations using compiled functions op = VOP3POp(inst.op) @@ -406,26 +430,44 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No if lane == 0: # Only execute once per wave, write results for all lanes exec_wmma(st, inst, op) return - # V_FMA_MIX: Mixed precision FMA - inputs can be f16 or f32 controlled by opsel + # V_FMA_MIX: Mixed precision FMA - inputs can be f16 or f32 controlled by opsel_hi/opsel_hi2 + # opsel_hi[0]: src0 is f32 (0) or f16 from hi bits (1) + # opsel_hi[1]: src1 is f32 (0) or f16 from hi bits (1) + # opsel_hi2: src2 is f32 (0) or f16 from hi bits (1) + # opsel[i]: when source is f16, use lo (0) or hi (1) 16 bits - BUT for V_FMA_MIX, opsel selects lo/hi when opsel_hi=1 + # neg_hi[i]: abs modifier for source i (reuses neg_hi field for abs in V_FMA_MIX) if op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16): opsel = getattr(inst, 'opsel', 0) opsel_hi = getattr(inst, 'opsel_hi', 0) + opsel_hi2 = getattr(inst, 'opsel_hi2', 0) neg = getattr(inst, 'neg', 0) - neg_hi = getattr(inst, 'neg_hi', 0) + abs_ = getattr(inst, 'neg_hi', 0) # neg_hi field is reused as abs for V_FMA_MIX vdst = inst.vdst - # Read raw 32-bit values - for V_FMA_MIX, sources can be either f32 or f16 + # Read raw 32-bit values s0_raw = st.rsrc(inst.src0, lane) s1_raw = st.rsrc(inst.src1, lane) s2_raw = st.rsrc(inst.src2, lane) if inst.src2 is not None else 0 - # opsel[i]=0: use as f32, opsel[i]=1: use hi f16 as f32 - # For src0: opsel[0], for src1: opsel[1], for src2: opsel[2] - if opsel & 1: s0 = _f16((s0_raw >> 16) & 0xffff) # hi f16 -> f32 - else: s0 = _f32(s0_raw) # use as 
f32 - if opsel & 2: s1 = _f16((s1_raw >> 16) & 0xffff) - else: s1 = _f32(s1_raw) - if opsel & 4: s2 = _f16((s2_raw >> 16) & 0xffff) - else: s2 = _f32(s2_raw) - # Apply neg modifiers (for f32 values) + # Decode sources based on opsel_hi (controls f32 vs f16) and opsel (controls which half for f16) + # src0: opsel_hi[0]=1 means f16, opsel[0] selects hi(1) or lo(0) half + if opsel_hi & 1: + s0 = _f16((s0_raw >> 16) & 0xffff) if (opsel & 1) else _f16(s0_raw & 0xffff) + else: + s0 = _f32(s0_raw) + # src1: opsel_hi[1]=1 means f16, opsel[1] selects hi(1) or lo(0) half + if opsel_hi & 2: + s1 = _f16((s1_raw >> 16) & 0xffff) if (opsel & 2) else _f16(s1_raw & 0xffff) + else: + s1 = _f32(s1_raw) + # src2: opsel_hi2=1 means f16, opsel[2] selects hi(1) or lo(0) half + if opsel_hi2: + s2 = _f16((s2_raw >> 16) & 0xffff) if (opsel & 4) else _f16(s2_raw & 0xffff) + else: + s2 = _f32(s2_raw) + # Apply abs modifiers (abs_ field reuses neg_hi position) + if abs_ & 1: s0 = abs(s0) + if abs_ & 2: s1 = abs(s1) + if abs_ & 4: s2 = abs(s2) + # Apply neg modifiers if neg & 1: s0 = -s0 if neg & 2: s1 = -s1 if neg & 4: s2 = -s2 @@ -505,7 +547,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,) is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64) # 16-bit source ops: use precomputed sets instead of string checks - has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS or op in _VOP2_16BIT_OPS + # Note: must check op_cls to avoid cross-enum value collisions is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS # VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants) is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS @@ -525,27 +567,88 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No s2 = mod_src64(st.rsrc64(src2, lane), 2) if src2 is not None else 0 elif is_16bit_src: # For 16-bit source ops, opsel bits select which half to use - s0_raw = mod_src(st.rsrc(src0, lane), 0) - s1_raw = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 - s2_raw = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 + # Inline constants (128-254) must use f16 encoding, not f32 + def rsrc_16bit(src, lane): return st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane) + s0_raw = rsrc_16bit(src0, lane) + s1_raw = rsrc_16bit(src1, lane) if src1 is not None else 0 + s2_raw = rsrc_16bit(src2, lane) if src2 is not None else 0 # opsel[0] selects hi(1) or lo(0) for src0, opsel[1] for src1, opsel[2] for src2 s0 = ((s0_raw >> 16) & 0xffff) if (opsel & 1) else (s0_raw & 0xffff) s1 = ((s1_raw >> 16) & 0xffff) if (opsel & 2) else (s1_raw & 0xffff) s2 = ((s2_raw >> 16) & 0xffff) if (opsel & 4) else (s2_raw & 0xffff) + # Apply abs/neg modifiers as f16 operations (toggle sign bit 15) + if abs_ & 1: s0 &= 0x7fff + if abs_ & 2: s1 &= 0x7fff + if abs_ & 4: s2 &= 0x7fff + if neg & 1: s0 ^= 0x8000 + if neg & 2: s1 ^= 0x8000 + if neg & 4: s2 ^= 0x8000 elif is_vop2_16bit: - # VOP2 16-bit ops: src0 can use f16 inline constants, vsrc1 is always a VGPR (no inline constants) - s0 = mod_src(st.rsrc_f16(src0, lane), 0) - s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 + # VOP2 16-bit ops: src0 uses f16 inline constants, or VGPR where v128+ = hi half of v0-v127 + # RDNA3 encoding: for VGPRs, bit 7 of VGPR index (src0-256) selects hi(1) or lo(0) half + if src0 >= 256: # VGPR + src0_hi = (src0 - 
256) & 0x80 != 0 + src0_masked = ((src0 - 256) & 0x7f) + 256 # mask out hi bit to get actual VGPR + s0_raw = mod_src(st.rsrc(src0_masked, lane), 0) + s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff) + else: # SGPR or inline constant + s0_raw = mod_src(st.rsrc_f16(src0, lane), 0) + s0 = s0_raw & 0xffff + # vsrc1: .h suffix encoded in bit 7 of VGPR index (src1 = 256 + vgpr_idx + 0x80 if hi) + if src1 is not None: + src1_hi = (src1 - 256) & 0x80 != 0 + src1_masked = ((src1 - 256) & 0x7f) + 256 + s1_raw = mod_src(st.rsrc(src1_masked, lane), 1) + s1 = ((s1_raw >> 16) & 0xffff) if src1_hi else (s1_raw & 0xffff) + else: + s1 = 0 s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 + elif op_cls is VOP1Op and op in _VOP1_16BIT_SRC_OPS: + # VOP1 16-bit source ops: .h encoded in bit 7 of VGPR index (src0 >= 384 means hi half) + # For VGPRs: src0 = 256 + vgpr_idx + (0x80 if hi else 0), so bit 7 of (src0-256) is the hi flag + src0_hi = src0 >= 256 and ((src0 - 256) & 0x80) != 0 + src0_masked = ((src0 - 256) & 0x7f) + 256 if src0 >= 256 else src0 # mask out hi bit for VGPR + s0_raw = mod_src(st.rsrc(src0_masked, lane), 0) + s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff) + s1, s2 = 0, 0 + elif op_cls is VOPCOp and op in _VOPC_16BIT_OPS: + # VOPC 16-bit ops: src0 and vsrc1 use same encoding as VOP2 16-bit + # For VGPRs, bit 7 of VGPR index selects hi(1) or lo(0) half + if src0 >= 256: # VGPR + src0_hi = (src0 - 256) & 0x80 != 0 + src0_masked = ((src0 - 256) & 0x7f) + 256 + s0_raw = mod_src(st.rsrc(src0_masked, lane), 0) + s0 = ((s0_raw >> 16) & 0xffff) if src0_hi else (s0_raw & 0xffff) + else: # SGPR or inline constant + s0_raw = mod_src(st.rsrc_f16(src0, lane), 0) + s0 = s0_raw & 0xffff + # vsrc1: bit 7 of VGPR index selects hi(1) or lo(0) half + if src1 is not None: + if src1 >= 256: # VGPR - use hi/lo encoding + src1_hi = (src1 - 256) & 0x80 != 0 + src1_masked = ((src1 - 256) & 0x7f) + 256 + s1_raw = mod_src(st.rsrc(src1_masked, lane), 1) + s1 = ((s1_raw >> 16) & 0xffff) if src1_hi else (s1_raw & 0xffff) + else: # SGPR or inline constant - read as 32-bit, use low 16 bits + s1_raw = mod_src(st.rsrc(src1, lane), 1) + s1 = s1_raw & 0xffffffff # V_CMP_CLASS uses full 32-bit mask + else: + s1 = 0 + s2 = 0 else: s0 = mod_src(st.rsrc(src0, lane), 0) s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 - d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32)) + # For VOP2 16-bit ops (like V_FMAC_F16), the destination is used as an accumulator. + # The pseudocode reads D0.f16 from low 16 bits, so we need to shift hi->lo when vop2_dst_hi is True. 
+ if is_vop2_16bit: + d0 = ((V[vdst] >> 16) & 0xffff) if vop2_dst_hi else (V[vdst] & 0xffff) + else: + d0 = V[vdst] if not is_64bit_op else (V[vdst] | (V[vdst + 1] << 32)) - # V_CNDMASK_B32: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly + # V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly # Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly - vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32,) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc + vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc # Execute compiled function - pass src0_idx and vdst_idx for lane instructions # For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR) @@ -571,7 +674,8 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or \ (op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32)) # Check for 16-bit destination ops (opsel[3] controls hi/lo write) - is_16bit_dst = op in _VOP3_16BIT_DST_OPS or op in _VOP1_16BIT_DST_OPS + # Must check op_cls to avoid cross-enum value collisions (e.g., VOP1Op.V_MOV_B32=1 vs VOP3Op.V_CMP_LT_F16=1) + is_16bit_dst = (op_cls is VOP3Op and op in _VOP3_16BIT_DST_OPS) or (op_cls is VOP1Op and op in _VOP1_16BIT_DST_OPS) if writes_to_sgpr: st.wsgpr(vdst, result['d0'] & 0xffffffff) elif result.get('d0_64') or is_64bit_op: @@ -583,6 +687,18 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16) else: # opsel[3] = 0: write to low 16 bits V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff) + elif is_16bit_dst and inst_type is VOP1: + # VOP1 16-bit ops: .h suffix encoded in bit 7 of vdst (extracted as vop1_dst_hi) + if vop1_dst_hi: # .h: write to high 16 bits + V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16) + else: # .l: write to low 16 bits + V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff) + elif is_vop2_16bit: + # VOP2 16-bit ops: .h suffix encoded in bit 7 of vdst (extracted as vop2_dst_hi) + if vop2_dst_hi: # .h: write to high 16 bits + V[vdst] = (V[vdst] & 0x0000ffff) | ((result['d0'] & 0xffff) << 16) + else: # .l: write to low 16 bits + V[vdst] = (V[vdst] & 0xffff0000) | (result['d0'] & 0xffff) else: V[vdst] = result['d0'] & 0xffffffff diff --git a/extra/assembly/amd/pcode.py b/extra/assembly/amd/pcode.py index dcf96a5a99..05b23d7528 100644 --- a/extra/assembly/amd/pcode.py +++ b/extra/assembly/amd/pcode.py @@ -35,12 +35,18 @@ def _isnan(x): try: return math.isnan(float(x)) except (TypeError, ValueError): return False def _isquietnan(x): - """Check if x is a quiet NaN. For f32: exponent=255, bit22=1, mantissa!=0""" + """Check if x is a quiet NaN. 
+ f16: exponent=31, bit9=1, mantissa!=0 + f32: exponent=255, bit22=1, mantissa!=0 + f64: exponent=2047, bit51=1, mantissa!=0 + """ try: if not math.isnan(float(x)): return False # Get raw bits from TypedView or similar object with _reg attribute if hasattr(x, '_reg') and hasattr(x, '_bits'): bits = x._reg._val & ((1 << x._bits) - 1) + if x._bits == 16: + return ((bits >> 10) & 0x1f) == 31 and ((bits >> 9) & 1) == 1 and (bits & 0x3ff) != 0 if x._bits == 32: return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 1 and (bits & 0x7fffff) != 0 if x._bits == 64: @@ -48,12 +54,18 @@ def _isquietnan(x): return True # Default to quiet NaN if we can't determine bit pattern except (TypeError, ValueError): return False def _issignalnan(x): - """Check if x is a signaling NaN. For f32: exponent=255, bit22=0, mantissa!=0""" + """Check if x is a signaling NaN. + f16: exponent=31, bit9=0, mantissa!=0 + f32: exponent=255, bit22=0, mantissa!=0 + f64: exponent=2047, bit51=0, mantissa!=0 + """ try: if not math.isnan(float(x)): return False # Get raw bits from TypedView or similar object with _reg attribute if hasattr(x, '_reg') and hasattr(x, '_bits'): bits = x._reg._val & ((1 << x._bits) - 1) + if x._bits == 16: + return ((bits >> 10) & 0x1f) == 31 and ((bits >> 9) & 1) == 0 and (bits & 0x3ff) != 0 if x._bits == 32: return ((bits >> 23) & 0xff) == 255 and ((bits >> 22) & 1) == 0 and (bits & 0x7fffff) != 0 if x._bits == 64: @@ -73,7 +85,11 @@ def floor(x): def ceil(x): x = float(x) return x if math.isnan(x) or math.isinf(x) else float(math.ceil(x)) -def sqrt(x): return math.sqrt(x) if x >= 0 else float("nan") +class _SafeFloat(float): + """Float subclass that uses _div for division to handle 0/inf correctly.""" + def __truediv__(self, o): return _div(float(self), float(o)) + def __rtruediv__(self, o): return _div(float(o), float(self)) +def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan")) def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan")) i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float def f32_to_i32(f): @@ -107,7 +123,10 @@ def u4_to_u32(v): return int(v) & 0xf def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0 def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff [remainder of this hunk lost in extraction] diff --git a/extra/assembly/amd/test/test_emu.py b/extra/assembly/amd/test/test_emu.py [hunk header and leading lines lost in extraction] + hi = (... >> 16) & 0xffff + self.assertAlmostEqual(lo, 7.0, places=1, msg=f"lo: 2*3+1=7, got {lo}") + self.assertEqual(hi, 0xdead, f"hi should be preserved, got 0x{hi:04x}") + + class TestF64Conversions(unittest.TestCase): """Tests for 64-bit float operations and conversions.""" @@ -2598,5 +2696,933 @@ class TestQuadmaskWqm(unittest.TestCase): self.assertEqual(st.scc, 0, "SCC should be 0 (result == 0)") +class TestVOP2_16bit_HiHalf(unittest.TestCase): + """Regression tests for VOP2 16-bit ops reading from high half of VGPR (v128+ encoding). + + Bug: VOP2 16-bit ops like v_add_f16 with src0 as v128+ should read the HIGH 16 bits + of the corresponding VGPR (v128 = v0.hi, v129 = v1.hi, etc). The emulator was + incorrectly reading from VGPR v128+ instead of the high half of v0+. + + Example: v_add_f16 v0, v128, v0 means v0.lo = v0.hi + v0.lo (fold packed result) + """ + + def test_v_add_f16_src0_hi_fold(self): + """v_add_f16 with src0=v128 (v0.hi) - fold packed f16 values. + + This pattern is generated by LLVM for summing packed f16 results: + v_pk_mul_f16 produces [hi, lo] in v0, then v_add_f16 v0, v128, v0 sums them. 
+ """ + instructions = [ + # v0 = packed f16: high=2.0 (0x4000), low=1.0 (0x3c00) + s_mov_b32(s[0], 0x40003c00), + v_mov_b32_e32(v[0], s[0]), + # v_add_f16 v1, v128, v0 means: v1.lo = v0.hi + v0.lo = 2.0 + 1.0 = 3.0 + # v128 in src0 means "read high 16 bits of v0" + v_add_f16_e32(v[1], v[0].h, v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xffff + self.assertEqual(result, 0x4200, f"Expected 3.0 (0x4200), got 0x{result:04x}") + + def test_v_add_f16_src0_hi_different_reg(self): + """v_add_f16 with src0=v129 (v1.hi) reads high half of v1.""" + instructions = [ + s_mov_b32(s[0], 0x44004200), # v1: high=4.0, low=3.0 + v_mov_b32_e32(v[1], s[0]), + s_mov_b32(s[1], 0x3c00), # v0: low=1.0 + v_mov_b32_e32(v[0], s[1]), + # v_add_f16 v2, v129, v0 means: v2.lo = v1.hi + v0.lo = 4.0 + 1.0 = 5.0 + v_add_f16_e32(v[2], v[1].h, v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xffff + self.assertEqual(result, 0x4500, f"Expected 5.0 (0x4500), got 0x{result:04x}") + + def test_v_mul_f16_src0_hi(self): + """v_mul_f16 with src0 from high half.""" + instructions = [ + s_mov_b32(s[0], 0x40003c00), # v0: high=2.0, low=1.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x4200), # v1: low=3.0 + v_mov_b32_e32(v[1], s[1]), + # v_mul_f16 v2, v128, v1 means: v2.lo = v0.hi * v1.lo = 2.0 * 3.0 = 6.0 + v_mul_f16_e32(v[2], v[0].h, v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xffff + self.assertEqual(result, 0x4600, f"Expected 6.0 (0x4600), got 0x{result:04x}") + + def test_v_add_f16_multilane(self): + """v_add_f16 with src0=v128 across multiple lanes.""" + instructions = [ + # Set up different packed values per lane using v_mov with lane-dependent values + # Lane 0: v0 = 0x40003c00 (hi=2.0, lo=1.0) -> sum = 3.0 + # Lane 1: v0 = 0x44004200 (hi=4.0, lo=3.0) -> sum = 7.0 + v_mov_b32_e32(v[0], 0x40003c00), # default for all lanes + # Use v_cmp to select lane 1 (v255 = lane_id from prologue) + v_cmp_eq_u32_e32(1, v[255]), # vcc = (lane == 1) + v_cndmask_b32_e64(v[0], v[0], 0x44004200, SrcEnum.VCC_LO), + # Now fold: v1.lo = v0.hi + v0.lo + v_add_f16_e32(v[1], v[0].h, v[0]), + ] + st = run_program(instructions, n_lanes=2) + # Lane 0: 2.0 + 1.0 = 3.0 (0x4200) + self.assertEqual(st.vgpr[0][1] & 0xffff, 0x4200, "Lane 0: expected 3.0") + # Lane 1: 4.0 + 3.0 = 7.0 (0x4700) + self.assertEqual(st.vgpr[1][1] & 0xffff, 0x4700, "Lane 1: expected 7.0") + + +class TestVOPC_16bit_HiHalf(unittest.TestCase): + """Regression tests for VOPC 16-bit ops reading from high half of VGPR (v128+ encoding). + + Bug: VOPC 16-bit ops like v_cmp_lt_f16 with vsrc1 as v128+ should read the HIGH 16 bits + of the corresponding VGPR. The emulator was incorrectly reading from VGPR v128+. 
+ + Example: v_cmp_nge_f16 vcc, v0, v128 compares v0.lo with v0.hi + """ + + def test_v_cmp_lt_f16_vsrc1_hi(self): + """v_cmp_lt_f16 comparing low half with high half of same register.""" + instructions = [ + # v0: high=2.0 (0x4000), low=1.0 (0x3c00) + s_mov_b32(s[0], 0x40003c00), + v_mov_b32_e32(v[0], s[0]), + # v_cmp_lt_f16 vcc, v0, v128 means: vcc = (v0.lo < v0.hi) = (1.0 < 2.0) = true + v_cmp_lt_f16_e32(v[0], v[0].h), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "Expected vcc=1 (1.0 < 2.0)") + + def test_v_cmp_gt_f16_vsrc1_hi(self): + """v_cmp_gt_f16 with vsrc1 from high half.""" + instructions = [ + # v0: high=1.0 (0x3c00), low=2.0 (0x4000) + s_mov_b32(s[0], 0x3c004000), + v_mov_b32_e32(v[0], s[0]), + # v_cmp_gt_f16 vcc, v0, v128 means: vcc = (v0.lo > v0.hi) = (2.0 > 1.0) = true + v_cmp_gt_f16_e32(v[0], v[0].h), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "Expected vcc=1 (2.0 > 1.0)") + + def test_v_cmp_eq_f16_vsrc1_hi_equal(self): + """v_cmp_eq_f16 with equal low and high halves.""" + instructions = [ + # v0: high=3.0 (0x4200), low=3.0 (0x4200) + s_mov_b32(s[0], 0x42004200), + v_mov_b32_e32(v[0], s[0]), + # v_cmp_eq_f16 vcc, v0, v128 means: vcc = (v0.lo == v0.hi) = (3.0 == 3.0) = true + v_cmp_eq_f16_e32(v[0], v[0].h), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "Expected vcc=1 (3.0 == 3.0)") + + def test_v_cmp_neq_f16_vsrc1_hi(self): + """v_cmp_neq_f16 with different low and high halves.""" + instructions = [ + # v0: high=2.0 (0x4000), low=1.0 (0x3c00) + s_mov_b32(s[0], 0x40003c00), + v_mov_b32_e32(v[0], s[0]), + # v_cmp_neq_f16 vcc, v0, v128 means: vcc = (v0.lo != v0.hi) = (1.0 != 2.0) = true + v_cmp_lg_f16_e32(v[0], v[0].h), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "Expected vcc=1 (1.0 != 2.0)") + + def test_v_cmp_nge_f16_inf_self(self): + """v_cmp_nge_f16 comparing -inf with itself (unordered less than). + + Regression test: -inf < -inf should be false (IEEE 754). + The bug was VOPC 16-bit not handling v128+ encoding for vsrc1. + """ + instructions = [ + # v0: both halves = -inf (0xFC00) + s_mov_b32(s[0], 0xFC00FC00), + v_mov_b32_e32(v[0], s[0]), + # v_cmp_nge_f16 is "not greater or equal" which is equivalent to "unordered less than" + # -inf nge -inf should be false (since -inf >= -inf is true) + v_cmp_nge_f16_e32(v[0], v[0].h), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 0, "Expected vcc=0 (-inf >= -inf)") + + def test_v_cmp_f16_multilane(self): + """v_cmp_lt_f16 with vsrc1=v128 across multiple lanes.""" + instructions = [ + # Lane 0: v0 = 0x40003c00 (hi=2.0, lo=1.0) -> 1.0 < 2.0 = true + # Lane 1: v0 = 0x3c004000 (hi=1.0, lo=2.0) -> 2.0 < 1.0 = false + v_mov_b32_e32(v[0], 0x40003c00), # default + # Use v_cmp to select lane 1 (v255 = lane_id from prologue) + v_cmp_eq_u32_e32(1, v[255]), # vcc = (lane == 1) + v_cndmask_b32_e64(v[0], v[0], 0x3c004000, SrcEnum.VCC_LO), + v_cmp_lt_f16_e32(v[0], v[0].h), + ] + st = run_program(instructions, n_lanes=2) + self.assertEqual(st.vcc & 1, 1, "Lane 0: expected vcc=1 (1.0 < 2.0)") + self.assertEqual((st.vcc >> 1) & 1, 0, "Lane 1: expected vcc=0 (2.0 < 1.0)") + + +class TestF16SinKernelOps(unittest.TestCase): + """Tests for F16 instructions used in the sin kernel. 
Run with USE_HW=1 to compare emulator vs hardware.""" + + def test_v_cvt_i16_f16_zero(self): + """v_cvt_i16_f16: Convert f16 0.0 to i16 0.""" + instructions = [ + s_mov_b32(s[0], 0x00000000), # f16 0.0 in low bits + v_mov_b32_e32(v[0], s[0]), + v_cvt_i16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xFFFF + self.assertEqual(result, 0, f"Expected 0, got {result}") + + def test_v_cvt_i16_f16_one(self): + """v_cvt_i16_f16: Convert f16 1.0 (0x3c00) to i16 1.""" + instructions = [ + s_mov_b32(s[0], 0x00003c00), # f16 1.0 in low bits + v_mov_b32_e32(v[0], s[0]), + v_cvt_i16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xFFFF + self.assertEqual(result, 1, f"Expected 1, got {result}") + + def test_v_cvt_i16_f16_negative(self): + """v_cvt_i16_f16: Convert f16 -2.0 (0xc000) to i16 -2.""" + instructions = [ + s_mov_b32(s[0], 0x0000c000), # f16 -2.0 in low bits + v_mov_b32_e32(v[0], s[0]), + v_cvt_i16_f16_e32(v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xFFFF + # -2 as signed 16-bit = 0xFFFE + self.assertEqual(result, 0xFFFE, f"Expected 0xFFFE (-2), got 0x{result:04x}") + + def test_v_cvt_i16_f16_from_hi(self): + """v_cvt_i16_f16: Convert f16 from high half of register.""" + instructions = [ + s_mov_b32(s[0], 0x3c000000), # f16 1.0 in HIGH bits, 0.0 in low + v_mov_b32_e32(v[0], s[0]), + v_cvt_i16_f16_e32(v[1], v[0].h), # Read from high half + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xFFFF + self.assertEqual(result, 1, f"Expected 1, got {result}") + + def test_v_bfe_i32_sign_extend(self): + """v_bfe_i32: Extract 16 bits with sign extension.""" + instructions = [ + s_mov_b32(s[0], 0x80000001), # low 16 bits = 0x0001 + v_mov_b32_e32(v[0], s[0]), + v_bfe_i32(v[1], v[0], 0, 16), # Extract bits 0-15 with sign extend + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] + self.assertEqual(result, 1, f"Expected 1, got {result}") + + def test_v_bfe_i32_sign_extend_negative(self): + """v_bfe_i32: Extract 16 bits with sign extension (negative value).""" + instructions = [ + s_mov_b32(s[0], 0x0000FFFE), # low 16 bits = 0xFFFE = -2 as i16 + v_mov_b32_e32(v[0], s[0]), + v_bfe_i32(v[1], v[0], 0, 16), # Extract bits 0-15 with sign extend + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] + # -2 sign-extended to 32 bits = 0xFFFFFFFE + self.assertEqual(result, 0xFFFFFFFE, f"Expected 0xFFFFFFFE (-2), got 0x{result:08x}") + + def test_v_cndmask_b16_select_src0(self): + """v_cndmask_b16: Select src0 when vcc=0.""" + instructions = [ + s_mov_b32(s[0], 0x3c003800), # src0.h=1.0, src0.l=0.5 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x4000c000), # src1.h=2.0, src1.l=-2.0 + v_mov_b32_e32(v[1], s[1]), + s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # vcc = 0 + v_cndmask_b16(v[2], v[0], v[1], SrcEnum.VCC_LO), # Should select v0.l = 0.5 + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result, 0x3800, f"Expected 0x3800 (0.5), got 0x{result:04x}") + + def test_v_cndmask_b16_select_src1(self): + """v_cndmask_b16: Select src1 when vcc=1.""" + instructions = [ + s_mov_b32(s[0], 0x3c003800), # src0.h=1.0, src0.l=0.5 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x4000c000), # src1.h=2.0, src1.l=-2.0 + v_mov_b32_e32(v[1], s[1]), + s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # vcc = 1 for lane 0 + v_cndmask_b16(v[2], v[0], v[1], SrcEnum.VCC_LO), # Should select v1.l = -2.0 + ] 
+ st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result, 0xc000, f"Expected 0xc000 (-2.0), got 0x{result:04x}") + + def test_v_cndmask_b16_write_hi(self): + """v_cndmask_b16: Write to high half with opsel.""" + instructions = [ + s_mov_b32(s[0], 0x3c003800), # src0: hi=1.0, lo=0.5 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x4000c000), # src1: hi=2.0, lo=-2.0 + v_mov_b32_e32(v[1], s[1]), + s_mov_b32(s[2], 0xDEAD0000), # v2 initial: hi=0xDEAD, lo=0 + v_mov_b32_e32(v[2], s[2]), + s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # vcc = 0 + # opsel=8 means write to high half (bit 3 = dst hi) + # opsel=1 means read src0 from hi, opsel=2 means read src1 from hi + # v_cndmask_b16 v2.h, v0.h, v1.h, vcc -> select v0.h = 1.0 + VOP3(VOP3Op.V_CNDMASK_B16, vdst=v[2], src0=v[0], src1=v[1], src2=SrcEnum.VCC_LO, opsel=0b1011), + ] + st = run_program(instructions, n_lanes=1) + result_hi = (st.vgpr[0][2] >> 16) & 0xFFFF + result_lo = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result_hi, 0x3c00, f"Expected hi=0x3c00 (1.0), got 0x{result_hi:04x}") + self.assertEqual(result_lo, 0x0000, f"Expected lo preserved as 0, got 0x{result_lo:04x}") + + def test_v_mul_f16_basic(self): + """v_mul_f16: 2.0 * 3.0 = 6.0.""" + instructions = [ + s_mov_b32(s[0], 0x00004000), # f16 2.0 in low bits + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x00004200), # f16 3.0 in low bits + v_mov_b32_e32(v[1], s[1]), + v_mul_f16_e32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result, 0x4600, f"Expected 0x4600 (6.0), got 0x{result:04x}") + + def test_v_mul_f16_by_zero(self): + """v_mul_f16: x * 0.0 = 0.0.""" + instructions = [ + s_mov_b32(s[0], 0x00003c00), # f16 1.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x00000000), # f16 0.0 + v_mov_b32_e32(v[1], s[1]), + v_mul_f16_e32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result, 0x0000, f"Expected 0x0000 (0.0), got 0x{result:04x}") + + def test_v_mul_f16_hi_half(self): + """v_mul_f16: Multiply using high halves.""" + instructions = [ + s_mov_b32(s[0], 0x40000000), # hi=2.0, lo=0.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x42000000), # hi=3.0, lo=0.0 + v_mov_b32_e32(v[1], s[1]), + v_mul_f16_e32(v[2].h, v[0].h, v[1].h), # 2.0 * 3.0 = 6.0 in hi + ] + st = run_program(instructions, n_lanes=1) + result_hi = (st.vgpr[0][2] >> 16) & 0xFFFF + self.assertEqual(result_hi, 0x4600, f"Expected hi=0x4600 (6.0), got 0x{result_hi:04x}") + + def test_v_fmac_f16_basic(self): + """v_fmac_f16: dst = src0 * src1 + dst = 2.0 * 3.0 + 1.0 = 7.0.""" + instructions = [ + s_mov_b32(s[0], 0x00004000), # f16 2.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x00004200), # f16 3.0 + v_mov_b32_e32(v[1], s[1]), + s_mov_b32(s[2], 0x00003c00), # f16 1.0 (accumulator) + v_mov_b32_e32(v[2], s[2]), + v_fmac_f16_e32(v[2], v[0], v[1]), # v2 = v0 * v1 + v2 + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result, 0x4700, f"Expected 0x4700 (7.0), got 0x{result:04x}") + + def test_v_fmac_f16_hi_dest(self): + """v_fmac_f16 with .h destination: dst.h = src0 * src1 + dst.h. + + This tests the case from AMD_LLVM sin(0) where V_FMAC_F16 writes to v0.h. + The accumulator D should be read from v0.h, not v0.l. 
+ """ + from extra.assembly.amd.pcode import f32_to_f16, _f16 + # Set up: v0 = {hi=0.5, lo=1.0}, src0 = 0.0 (literal), src1 = v1.l (any value) + # Expected: v0.h = 0.0 * v1.l + 0.5 = 0.5 (unchanged) + instructions = [ + s_mov_b32(s[0], 0x38003c00), # v0 = {hi=0.5, lo=1.0} + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x38000000), # v1 = {hi=0.5, lo=0.0} + v_mov_b32_e32(v[1], s[1]), + # v_fmac_f16 v0.h, literal(0.318...), v1.l (vdst=128 for .h) + # D = D + S0 * S1 = v0.h + 0.318 * 0.0 = 0.5 + 0 = 0.5 + VOP2(VOP2Op.V_FMAC_F16, vdst=RawImm(128), src0=RawImm(255), vsrc1=RawImm(1), literal=0x3518), # 0.318... * 0.0 + 0.5 + ] + st = run_program(instructions, n_lanes=1) + v0 = st.vgpr[0][0] + result_hi = _f16((v0 >> 16) & 0xffff) + result_lo = _f16(v0 & 0xffff) + self.assertAlmostEqual(result_hi, 0.5, delta=0.01, msg=f"Expected v0.h=0.5, got {result_hi}") + self.assertAlmostEqual(result_lo, 1.0, delta=0.01, msg=f"Expected v0.l=1.0, got {result_lo}") + + def test_v_add_f16_basic(self): + """v_add_f16: 1.0 + 2.0 = 3.0.""" + instructions = [ + s_mov_b32(s[0], 0x00003c00), # f16 1.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x00004000), # f16 2.0 + v_mov_b32_e32(v[1], s[1]), + v_add_f16_e32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result, 0x4200, f"Expected 0x4200 (3.0), got 0x{result:04x}") + + def test_v_add_f16_negative(self): + """v_add_f16: 1.0 + (-1.5703125) = -0.5703125.""" + # 0xbe48 is approximately -1.5703125 in f16 + instructions = [ + s_mov_b32(s[0], 0x00003c00), # f16 1.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x0000be48), # f16 -1.5703125 + v_mov_b32_e32(v[1], s[1]), + v_add_f16_e32(v[2], v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + # 1.0 + (-1.5703125) = -0.5703125 which is approximately 0xb890 + # Allow some tolerance - just check it's negative and close + from extra.assembly.amd.pcode import _f16 + result_f = _f16(result) + expected = 1.0 - 1.5703125 + self.assertAlmostEqual(result_f, expected, places=2, msg=f"Expected ~{expected}, got {result_f}") + + def test_v_fmaak_f16_basic(self): + """v_fmaak_f16: dst = src0 * vsrc1 + K.""" + # v_fmaak_f16 computes: D = S0 * S1 + K + # 2.0 * 3.0 + 1.0 = 7.0 + instructions = [ + s_mov_b32(s[0], 0x00004000), # f16 2.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x00004200), # f16 3.0 + v_mov_b32_e32(v[1], s[1]), + v_fmaak_f16_e32(v[2], v[0], v[1], 0x3c00), # v2 = v0 * v1 + 1.0 + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][2] & 0xFFFF + self.assertEqual(result, 0x4700, f"Expected 0x4700 (7.0), got 0x{result:04x}") + + def test_v_fmamk_f32_basic(self): + """v_fmamk_f32: dst = src0 * K + vsrc1.""" + # v_fmamk_f32 computes: D = S0 * K + S1 + # 2.0 * 3.0 + 1.0 = 7.0 + instructions = [ + s_mov_b32(s[0], f2i(2.0)), + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], f2i(1.0)), # accumulator + v_mov_b32_e32(v[1], s[1]), + v_fmamk_f32_e32(v[2], v[0], f2i(3.0), v[1]), # v2 = v0 * 3.0 + v1 + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][2]) + self.assertAlmostEqual(result, 7.0, places=5, msg=f"Expected 7.0, got {result}") + + def test_v_fmamk_f32_small_constant(self): + """v_fmamk_f32: Test with small constant like in sin kernel.""" + # This mimics part of the sin kernel: 1.0 * (-1.13e-4) + (-3.1414795) ≈ -3.1415926 + k_val = 0xb8ed5000 # approximately -0.0001131594 as f32 + s1_val = f2i(-3.1414794921875) + instructions = [ + s_mov_b32(s[0], f2i(1.0)), + 
v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], s1_val), + v_mov_b32_e32(v[1], s[1]), + v_fmamk_f32_e32(v[2], v[0], k_val, v[1]), # v2 = 1.0 * K + v1 + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][2]) + k_f32 = i2f(k_val) + expected = 1.0 * k_f32 + (-3.1414794921875) + self.assertAlmostEqual(result, expected, places=5, msg=f"Expected {expected}, got {result}") + + def test_v_mov_b16_to_hi(self): + """v_mov_b16: Move immediate to high half, preserving low.""" + instructions = [ + s_mov_b32(s[0], 0x0000DEAD), # initial: lo=0xDEAD, hi=0 + v_mov_b32_e32(v[0], s[0]), + v_mov_b16_e32(v[0].h, 0x3800), # Move 0.5 to high half + ] + st = run_program(instructions, n_lanes=1) + result_hi = (st.vgpr[0][0] >> 16) & 0xFFFF + result_lo = st.vgpr[0][0] & 0xFFFF + self.assertEqual(result_hi, 0x3800, f"Expected hi=0x3800, got 0x{result_hi:04x}") + self.assertEqual(result_lo, 0xDEAD, f"Expected lo=0xDEAD (preserved), got 0x{result_lo:04x}") + + def test_v_mov_b16_to_lo(self): + """v_mov_b16: Move immediate to low half, preserving high.""" + instructions = [ + s_mov_b32(s[0], 0xBEEF0000), # initial: hi=0xBEEF, lo=0 + v_mov_b32_e32(v[0], s[0]), + v_mov_b16_e32(v[0], 0x3c00), # Move 1.0 to low half + ] + st = run_program(instructions, n_lanes=1) + result_hi = (st.vgpr[0][0] >> 16) & 0xFFFF + result_lo = st.vgpr[0][0] & 0xFFFF + self.assertEqual(result_lo, 0x3c00, f"Expected lo=0x3c00, got 0x{result_lo:04x}") + self.assertEqual(result_hi, 0xBEEF, f"Expected hi=0xBEEF (preserved), got 0x{result_hi:04x}") + + def test_v_xor_b32_sign_flip(self): + """v_xor_b32: XOR with 0x8000 flips sign of f16 in low bits.""" + # 0x4246 is approximately 3.13671875 in f16 + # XOR with 0x8000 gives 0xC246 which is -3.13671875 + instructions = [ + s_mov_b32(s[0], 0x00004246), # f16 3.13671875 + v_mov_b32_e32(v[0], s[0]), + v_xor_b32_e32(v[1], 0x8000, v[0]), # Flip sign bit of low half + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] & 0xFFFF + self.assertEqual(result, 0xC246, f"Expected 0xC246 (-3.137), got 0x{result:04x}") + + def test_v_fma_mix_f32_all_f32_sources(self): + """v_fma_mix_f32: All sources as f32 (opsel_hi=0).""" + instructions = [ + s_mov_b32(s[0], f2i(2.0)), + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], f2i(3.0)), + v_mov_b32_e32(v[1], s[1]), + s_mov_b32(s[2], f2i(1.0)), + v_mov_b32_e32(v[2], s[2]), + # opsel_hi=0,0,0 means all sources are f32 + VOP3P(VOP3POp.V_FMA_MIX_F32, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0), + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][3]) + self.assertAlmostEqual(result, 7.0, places=5, msg=f"2*3+1=7, got {result}") + + def test_v_fma_mixlo_f16_all_f32_sources(self): + """v_fma_mixlo_f16: All sources as f32, result to low f16.""" + instructions = [ + s_mov_b32(s[0], f2i(1.0)), + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], f2i(-1.22e-10)), # Very small + v_mov_b32_e32(v[1], s[1]), + s_mov_b32(s[2], f2i(-3.1415927)), # -pi + v_mov_b32_e32(v[2], s[2]), + s_mov_b32(s[3], 0xDEAD0000), # Garbage in hi + v_mov_b32_e32(v[3], s[3]), + # 1.0 * (-1.22e-10) + (-3.1415927) ≈ -3.1415927 + VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[0], src1=v[1], src2=v[2], opsel=0, opsel_hi=0, opsel_hi2=0), + ] + st = run_program(instructions, n_lanes=1) + from extra.assembly.amd.pcode import _f16 + result_lo = _f16(st.vgpr[0][3] & 0xFFFF) + result_hi = (st.vgpr[0][3] >> 16) & 0xFFFF + # Result should be approximately -pi + self.assertAlmostEqual(result_lo, -3.14, delta=0.01, msg=f"Expected ~-3.14, got 
{result_lo}") + self.assertEqual(result_hi, 0xDEAD, f"Expected hi preserved as 0xDEAD, got 0x{result_hi:04x}") + + +class TestVCmpClassF16(unittest.TestCase): + """Tests for V_CMP_CLASS_F16 - critical for f16 sin/cos classification. + + Class bit mapping: + bit 0 = signaling NaN + bit 1 = quiet NaN + bit 2 = -infinity + bit 3 = -normal + bit 4 = -denormal + bit 5 = -zero + bit 6 = +zero + bit 7 = +denormal + bit 8 = +normal + bit 9 = +infinity + + This is crucial for the f16 sin kernel which uses v_cmp_class_f16 to detect + special values like +-0, +-inf, NaN and select appropriate outputs. + """ + + def test_cmp_class_f16_positive_zero(self): + """V_CMP_CLASS_F16: +zero should match bit 6.""" + # f16 +0.0 = 0x0000 + instructions = [ + v_mov_b32_e32(v[0], 0), # f16 +0.0 in low 16 bits + v_mov_b32_e32(v[1], 0x40), # bit 6 only (+zero) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +zero with mask 0x40") + + def test_cmp_class_f16_negative_zero(self): + """V_CMP_CLASS_F16: -zero should match bit 5.""" + # f16 -0.0 = 0x8000 + instructions = [ + s_mov_b32(s[0], 0x8000), # f16 -0.0 + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], 0x20), # bit 5 only (-zero) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for -zero with mask 0x20") + + def test_cmp_class_f16_positive_normal(self): + """V_CMP_CLASS_F16: +1.0 (normal) should match bit 8.""" + # f16 1.0 = 0x3c00 + instructions = [ + s_mov_b32(s[0], 0x3c00), # f16 +1.0 + s_mov_b32(s[1], 0x100), # bit 8 (+normal) + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +1.0 with mask 0x100 (+normal)") + + def test_cmp_class_f16_negative_normal(self): + """V_CMP_CLASS_F16: -1.0 (normal) should match bit 3.""" + # f16 -1.0 = 0xbc00 + instructions = [ + s_mov_b32(s[0], 0xbc00), # f16 -1.0 + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], 0x08), # bit 3 (-normal) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for -1.0 with mask 0x08 (-normal)") + + def test_cmp_class_f16_positive_infinity(self): + """V_CMP_CLASS_F16: +inf should match bit 9.""" + # f16 +inf = 0x7c00 + instructions = [ + s_mov_b32(s[0], 0x7c00), # f16 +inf + s_mov_b32(s[1], 0x200), # bit 9 (+inf) + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], s[1]), + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +inf with mask 0x200") + + def test_cmp_class_f16_negative_infinity(self): + """V_CMP_CLASS_F16: -inf should match bit 2.""" + # f16 -inf = 0xfc00 + instructions = [ + s_mov_b32(s[0], 0xfc00), # f16 -inf + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], 0x04), # bit 2 (-inf) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for -inf with mask 0x04") + + def test_cmp_class_f16_quiet_nan(self): + """V_CMP_CLASS_F16: quiet NaN should match bit 1.""" + # f16 quiet NaN = 0x7e00 (exponent all 1s, mantissa MSB set) + instructions = [ + s_mov_b32(s[0], 0x7e00), # f16 quiet NaN + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], 0x02), # bit 1 (quiet NaN) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, 
n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for quiet NaN with mask 0x02") + + def test_cmp_class_f16_signaling_nan(self): + """V_CMP_CLASS_F16: signaling NaN should match bit 0.""" + # f16 signaling NaN = 0x7c01 (exponent all 1s, mantissa MSB clear, other mantissa bits set) + instructions = [ + s_mov_b32(s[0], 0x7c01), # f16 signaling NaN + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], 0x01), # bit 0 (signaling NaN) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for signaling NaN with mask 0x01") + + def test_cmp_class_f16_positive_denormal(self): + """V_CMP_CLASS_F16: positive denormal should match bit 7.""" + # f16 smallest positive denormal = 0x0001 + instructions = [ + v_mov_b32_e32(v[0], 1), # f16 +denormal (0x0001) + v_mov_b32_e32(v[1], 0x80), # bit 7 (+denormal) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +denormal with mask 0x80") + + def test_cmp_class_f16_negative_denormal(self): + """V_CMP_CLASS_F16: negative denormal should match bit 4.""" + # f16 smallest negative denormal = 0x8001 + instructions = [ + s_mov_b32(s[0], 0x8001), # f16 -denormal + v_mov_b32_e32(v[0], s[0]), + v_mov_b32_e32(v[1], 0x10), # bit 4 (-denormal) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for -denormal with mask 0x10") + + def test_cmp_class_f16_combined_mask_zeros(self): + """V_CMP_CLASS_F16: mask 0x60 covers both +zero and -zero.""" + # Test with +0.0 + instructions = [ + v_mov_b32_e32(v[0], 0), # f16 +0.0 + v_mov_b32_e32(v[1], 0x60), # bits 5 and 6 (+-zero) + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +zero with mask 0x60") + + def test_cmp_class_f16_combined_mask_1f8(self): + """V_CMP_CLASS_F16: mask 0x1f8 covers -normal,-denorm,-zero,+zero,+denorm,+normal. + + This is the exact mask used in the f16 sin kernel at PC=46: + v_cmp_class_f16_e64 vcc_lo, v1, 0x1f8 + + The kernel uses this to detect if the input is a "normal" finite value + (not NaN, not infinity). If the check fails (vcc=0), it selects NaN output. + """ + # Test with +0.0 - should match via bit 6 + instructions = [ + v_mov_b32_e32(v[0], 0), # f16 +0.0 + s_mov_b32(s[0], 0x1f8), + v_mov_b32_e32(v[1], s[0]), # mask 0x1f8 + v_cmp_class_f16_e32(v[0], v[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +zero with mask 0x1f8") + + def test_cmp_class_f16_vop3_encoding(self): + """V_CMP_CLASS_F16 in VOP3 encoding (v_cmp_class_f16_e64). + + This tests the exact instruction encoding used in the f16 sin kernel. + VOP3 encoding allows the result to go to any SGPR pair, not just VCC. 
+ """ + # v_cmp_class_f16_e64 vcc_lo, v0, 0x1f8 + # Use SGPR to hold the mask since literals require special handling + instructions = [ + v_mov_b32_e32(v[0], 0), # f16 +0.0 + s_mov_b32(s[0], 0x1f8), # class mask + VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[0]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +zero with VOP3 encoding") + + def test_cmp_class_f16_vop3_normal_positive(self): + """V_CMP_CLASS_F16 VOP3 encoding with +1.0 (normal).""" + # f16 1.0 = 0x3c00, should match bit 8 (+normal) in mask 0x1f8 + instructions = [ + s_mov_b32(s[0], 0x3c00), # f16 +1.0 + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x1f8), # class mask + VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for +1.0 (normal) with mask 0x1f8") + + def test_cmp_class_f16_vop3_nan_fails_mask(self): + """V_CMP_CLASS_F16 VOP3: NaN should NOT match mask 0x1f8 (no NaN bits set).""" + # f16 quiet NaN = 0x7e00, should NOT match mask 0x1f8 (bits 3-8 only) + instructions = [ + s_mov_b32(s[0], 0x7e00), # f16 quiet NaN + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x1f8), # class mask + VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 0, "VCC should be 0 for NaN with mask 0x1f8 (no NaN bits)") + + def test_cmp_class_f16_vop3_inf_fails_mask(self): + """V_CMP_CLASS_F16 VOP3: +inf should NOT match mask 0x1f8 (no inf bits set).""" + # f16 +inf = 0x7c00, should NOT match mask 0x1f8 (bits 3-8 only) + instructions = [ + s_mov_b32(s[0], 0x7c00), # f16 +inf + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], 0x1f8), # class mask + VOP3(VOP3Op.V_CMP_CLASS_F16, vdst=RawImm(VCC), src0=v[0], src1=s[1]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 0, "VCC should be 0 for +inf with mask 0x1f8 (no inf bits)") + + +class TestVOP3F16Modifiers(unittest.TestCase): + """Tests for VOP3 16-bit ops with abs/neg modifiers and inline constants. + + VOP3 16-bit ops must: + 1. Use f16 inline constants (not f32) + 2. Apply abs/neg modifiers as f16 operations (toggle bit 15) + + This is critical for sin/cos kernels that use v_cvt_f32_f16 with |abs| + and v_fma_f16 with inline constants. 
+ """ + + def test_v_cvt_f32_f16_abs_negative(self): + """V_CVT_F32_F16 with |abs| on negative value.""" + from extra.assembly.amd.pcode import f32_to_f16 + f16_neg1 = f32_to_f16(-1.0) # 0xbc00 + instructions = [ + s_mov_b32(s[0], f16_neg1), + v_mov_b32_e32(v[1], s[0]), + v_cvt_f32_f16_e64(v[0], abs(v[1])), # |(-1.0)| = 1.0 + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][0]) + self.assertAlmostEqual(result, 1.0, places=5, msg=f"Expected 1.0, got {result}") + + def test_v_cvt_f32_f16_abs_positive(self): + """V_CVT_F32_F16 with |abs| on positive value (should stay positive).""" + from extra.assembly.amd.pcode import f32_to_f16 + f16_2 = f32_to_f16(2.0) # 0x4000 + instructions = [ + s_mov_b32(s[0], f16_2), + v_mov_b32_e32(v[1], s[0]), + v_cvt_f32_f16_e64(v[0], abs(v[1])), # |2.0| = 2.0 + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][0]) + self.assertAlmostEqual(result, 2.0, places=5, msg=f"Expected 2.0, got {result}") + + def test_v_cvt_f32_f16_neg_positive(self): + """V_CVT_F32_F16 with neg on positive value.""" + from extra.assembly.amd.pcode import f32_to_f16 + f16_2 = f32_to_f16(2.0) # 0x4000 + instructions = [ + s_mov_b32(s[0], f16_2), + v_mov_b32_e32(v[1], s[0]), + v_cvt_f32_f16_e64(v[0], -v[1]), # -(2.0) = -2.0 + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][0]) + self.assertAlmostEqual(result, -2.0, places=5, msg=f"Expected -2.0, got {result}") + + def test_v_cvt_f32_f16_neg_negative(self): + """V_CVT_F32_F16 with neg on negative value (double negative).""" + from extra.assembly.amd.pcode import f32_to_f16 + f16_neg2 = f32_to_f16(-2.0) # 0xc000 + instructions = [ + s_mov_b32(s[0], f16_neg2), + v_mov_b32_e32(v[1], s[0]), + v_cvt_f32_f16_e64(v[0], -v[1]), # -(-2.0) = 2.0 + ] + st = run_program(instructions, n_lanes=1) + result = i2f(st.vgpr[0][0]) + self.assertAlmostEqual(result, 2.0, places=5, msg=f"Expected 2.0, got {result}") + + def test_v_fma_f16_inline_const_1_0(self): + """V_FMA_F16: a*b + 1.0 should use f16 inline constant.""" + from extra.assembly.amd.pcode import f32_to_f16, _f16 + # v4 = 0.3259 (f16), v6 = -0.4866 (f16), src2 = 1.0 inline + # Result: 0.3259 * (-0.4866) + 1.0 = 0.8413... 
+ f16_a = f32_to_f16(0.325928) # 0x3537 + f16_b = f32_to_f16(-0.486572) # 0xb7c9 + instructions = [ + s_mov_b32(s[0], f16_a), + v_mov_b32_e32(v[4], s[0]), + s_mov_b32(s[1], f16_b), + v_mov_b32_e32(v[6], s[1]), + v_fma_f16(v[4], v[4], v[6], 1.0), # 1.0 is inline constant + ] + st = run_program(instructions, n_lanes=1) + result = _f16(st.vgpr[0][4] & 0xffff) + expected = 0.325928 * (-0.486572) + 1.0 + self.assertAlmostEqual(result, expected, delta=0.01, msg=f"Expected ~{expected:.4f}, got {result}") + + def test_v_fma_f16_inline_const_0_5(self): + """V_FMA_F16: a*b + 0.5 should use f16 inline constant.""" + from extra.assembly.amd.pcode import f32_to_f16, _f16 + f16_a = f32_to_f16(2.0) + f16_b = f32_to_f16(3.0) + instructions = [ + s_mov_b32(s[0], f16_a), + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], f16_b), + v_mov_b32_e32(v[1], s[1]), + v_fma_f16(v[2], v[0], v[1], 0.5), # 0.5 is inline constant + ] + st = run_program(instructions, n_lanes=1) + result = _f16(st.vgpr[0][2] & 0xffff) + expected = 2.0 * 3.0 + 0.5 + self.assertAlmostEqual(result, expected, delta=0.01, msg=f"Expected {expected}, got {result}") + + def test_v_fma_f16_inline_const_neg_1_0(self): + """V_FMA_F16: a*b + (-1.0) should use f16 inline constant.""" + from extra.assembly.amd.pcode import f32_to_f16, _f16 + f16_a = f32_to_f16(2.0) + f16_b = f32_to_f16(3.0) + instructions = [ + s_mov_b32(s[0], f16_a), + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], f16_b), + v_mov_b32_e32(v[1], s[1]), + v_fma_f16(v[2], v[0], v[1], -1.0), # -1.0 is inline constant + ] + st = run_program(instructions, n_lanes=1) + result = _f16(st.vgpr[0][2] & 0xffff) + expected = 2.0 * 3.0 + (-1.0) + self.assertAlmostEqual(result, expected, delta=0.01, msg=f"Expected {expected}, got {result}") + + def test_v_add_f16_abs_both(self): + """V_ADD_F16 with abs on both operands.""" + from extra.assembly.amd.pcode import f32_to_f16, _f16 + f16_neg2 = f32_to_f16(-2.0) + f16_neg3 = f32_to_f16(-3.0) + instructions = [ + s_mov_b32(s[0], f16_neg2), + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], f16_neg3), + v_mov_b32_e32(v[1], s[1]), + v_add_f16_e64(v[2], abs(v[0]), abs(v[1])), # |-2| + |-3| = 5 + ] + st = run_program(instructions, n_lanes=1) + result = _f16(st.vgpr[0][2] & 0xffff) + self.assertAlmostEqual(result, 5.0, delta=0.01, msg=f"Expected 5.0, got {result}") + + def test_v_mul_f16_neg_abs(self): + """V_MUL_F16 with neg on one operand and abs on another.""" + from extra.assembly.amd.pcode import f32_to_f16, _f16 + f16_2 = f32_to_f16(2.0) + f16_neg3 = f32_to_f16(-3.0) + instructions = [ + s_mov_b32(s[0], f16_2), + v_mov_b32_e32(v[0], s[0]), + s_mov_b32(s[1], f16_neg3), + v_mov_b32_e32(v[1], s[1]), + v_mul_f16_e64(v[2], -v[0], abs(v[1])), # -(2) * |-3| = -6 + ] + st = run_program(instructions, n_lanes=1) + result = _f16(st.vgpr[0][2] & 0xffff) + self.assertAlmostEqual(result, -6.0, delta=0.01, msg=f"Expected -6.0, got {result}") + + if __name__ == '__main__': unittest.main() + + +class TestVFmaMixSinCase(unittest.TestCase): + """Tests for the specific V_FMA_MIXLO_F16 case that fails in AMD_LLVM sin(0) kernel.""" + + def test_v_fma_mixlo_f16_sin_case(self): + """V_FMA_MIXLO_F16 case from sin kernel at pc=0x14e. 
+ + This tests the specific operands that produce the wrong result: + - src0 = v3 = 0x3f800000 (f32 1.0) + - src1 = s6 = 0xaf05a309 (f32 tiny negative) + - src2 = v5 = 0xc0490fdb (f32 -π) + - Result should be approximately -π (tiny * 1.0 + -π ≈ -π) + """ + from extra.assembly.amd.pcode import _f16 + instructions = [ + # Set up operands as in the sin kernel + s_mov_b32(s[0], 0x3f800000), # f32 1.0 + v_mov_b32_e32(v[3], s[0]), + s_mov_b32(s[1], 0xaf05a309), # f32 tiny negative + s_mov_b32(s[6], s[1]), + s_mov_b32(s[2], 0xc0490fdb), # f32 -π + v_mov_b32_e32(v[5], s[2]), + # Pre-fill v3 with expected hi bits + s_mov_b32(s[3], 0x3f800000), # hi = f32 1.0 encoding (will be overwritten by opsel behavior) + v_mov_b32_e32(v[3], s[3]), + # V_FMA_MIXLO_F16: src0=v3 (259), src1=s6, src2=v5 (261), opsel=0, opsel_hi=0, opsel_hi2=0 + VOP3P(VOP3POp.V_FMA_MIXLO_F16, vdst=v[3], src0=v[3], src1=s[6], src2=v[5], opsel=0, opsel_hi=0, opsel_hi2=0), + ] + st = run_program(instructions, n_lanes=1) + lo = _f16(st.vgpr[0][3] & 0xffff) + # Result should be approximately -π = -3.14... + # f16 -π ≈ 0xc248 = -3.140625 + self.assertAlmostEqual(lo, -3.14159, delta=0.01, msg=f"Expected ~-π, got {lo}") diff --git a/extra/assembly/amd/test/test_roundtrip.py b/extra/assembly/amd/test/test_roundtrip.py index 7fe12d9855..bf9b68d869 100644 --- a/extra/assembly/amd/test/test_roundtrip.py +++ b/extra/assembly/amd/test/test_roundtrip.py @@ -31,7 +31,12 @@ def detect_format(data: bytes) -> type[Inst] | None: # Check 64-bit formats if len(data) >= 8: - if enc_8bit in (0xD4, 0xD5, 0xD7): return VOP3 + if enc_8bit in (0xD4, 0xD5, 0xD7): + # VOP3 and VOP3SD share encoding - check opcode to determine which + # VOP3SD opcodes: 288-290 (v_*_co_ci_*), 764-770 (v_div_scale_*, v_mad_*, v_*_co_u32) + op = (int.from_bytes(data[:8], 'little') >> 16) & 0x3FF + if op in {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}: return VOP3SD + return VOP3 if enc_8bit == 0xD6: return VOP3SD if enc_8bit == 0xCC: return VOP3P if enc_8bit == 0xCD: return VINTERP
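
Reviewer's note, not part of the patch: the dsl.py change above keys the VOPD shared-literal rule off opx/opy values 1 and 2 (fmaak/fmamk, per the comment in the diff). A minimal sketch of that rule, with a hypothetical helper name:

def vopd_has_literal(opx: int, opy: int) -> bool:
  # VOPD dual-issue: at most one trailing 32-bit literal, shared by both halves.
  # v_dual_fmaak (K as addend) and v_dual_fmamk (K as multiplicand) always read it.
  FMAAK, FMAMK = 1, 2  # VOPDOp values taken from the dsl.py comment
  return opx in (FMAAK, FMAMK) or opy in (FMAAK, FMAMK)

assert vopd_has_literal(1, 0) and vopd_has_literal(0, 2) and not vopd_has_literal(0, 0)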
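
For the V_FMA_MIX decode in emu.py, a self-contained sketch of the source selection as the diff's comments describe it (opsel_hi chooses f32 vs f16; opsel chooses the half when the source is f16); helper names are hypothetical:

import struct

def _bits_to_f16(h: int) -> float: return struct.unpack('<e', struct.pack('<H', h & 0xffff))[0]
def _bits_to_f32(w: int) -> float: return struct.unpack('<f', struct.pack('<I', w & 0xffffffff))[0]

def fma_mix_src(raw: int, sel_f16: bool, sel_hi: bool) -> float:
  # sel_f16 = this source's opsel_hi bit, sel_hi = this source's opsel bit
  if not sel_f16: return _bits_to_f32(raw)
  return _bits_to_f16((raw >> 16) & 0xffff if sel_hi else raw & 0xffff)

# 0x40003c00 packs f16 2.0 (hi) and f16 1.0 (lo)
assert fma_mix_src(0x40003c00, sel_f16=True, sel_hi=True) == 2.0
assert fma_mix_src(0x40003c00, sel_f16=True, sel_hi=False) == 1.0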
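
Several emu.py hunks repeat one decode for RDNA3 16-bit operands: a VGPR source is encoded as 256 + idx, and for 16-bit ops bit 7 of idx is repurposed as the half selector, so only v0-v127 are addressable and "v128" really means v0.h. A standalone sketch (hypothetical helpers):

def decode_vgpr16(src: int) -> tuple[int, bool]:
  # Returns (vgpr_index, is_hi_half) for a 16-bit VGPR operand (src >= 256).
  idx = src - 256
  return idx & 0x7f, (idx & 0x80) != 0

def read_half(vgpr_value: int, is_hi: bool) -> int:
  return (vgpr_value >> 16) & 0xffff if is_hi else vgpr_value & 0xffff

assert decode_vgpr16(256 + 128) == (0, True)   # "v128" = v0.h
assert decode_vgpr16(256 + 129) == (1, True)   # "v129" = v1.h
assert read_half(0x40003c00, True) == 0x4000   # f16 2.0 from the high half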
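
The new f16 branches in pcode.py's NaN predicates reduce to the bit tests below (exponent field [14:10] all ones, mantissa [9:0] nonzero, bit 9 distinguishing quiet from signaling), shown here as a sanity check rather than the patch's code:

def is_f16_nan(h: int) -> bool: return (h >> 10) & 0x1f == 0x1f and (h & 0x3ff) != 0
def is_f16_quiet_nan(h: int) -> bool: return is_f16_nan(h) and (h >> 9) & 1 == 1
def is_f16_signaling_nan(h: int) -> bool: return is_f16_nan(h) and (h >> 9) & 1 == 0

assert is_f16_quiet_nan(0x7e00) and not is_f16_signaling_nan(0x7e00)
assert is_f16_signaling_nan(0x7c01) and not is_f16_quiet_nan(0x7c01)
assert not is_f16_nan(0x7c00)  # +inf: exponent all ones but mantissa == 0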
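
The TestVCmpClassF16 docstring lists the ten class bits; the comparison itself is just "classify the input, then test that bit of the mask". A sketch consistent with that mapping (function names hypothetical):

def f16_class_bit(h: int) -> int:
  sign, exp, man = (h >> 15) & 1, (h >> 10) & 0x1f, h & 0x3ff
  if exp == 0x1f:
    if man == 0: return 2 if sign else 9   # -inf / +inf
    return 1 if man & 0x200 else 0         # quiet / signaling NaN
  if exp == 0:
    if man == 0: return 5 if sign else 6   # -zero / +zero
    return 4 if sign else 7                # -denormal / +denormal
  return 3 if sign else 8                  # -normal / +normal

def cmp_class_f16(h: int, mask: int) -> bool: return bool((mask >> f16_class_bit(h)) & 1)

assert cmp_class_f16(0x0000, 0x1f8)      # +0 passes the sin kernel's finite mask
assert not cmp_class_f16(0x7e00, 0x1f8)  # quiet NaN fails it
assert cmp_class_f16(0xfc00, 0x04)       # -inf matches bit 2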
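
Finally, for the test_roundtrip.py change: VOP3 and VOP3SD share the 8-bit encoding values checked above, so the 10-bit opcode at bits [25:16] of the first dword decides the class. A sketch mirroring the patch's extraction (opcode set copied from the diff):

VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}

def pick_vop3_class(data: bytes) -> str:
  op = (int.from_bytes(data[:8], 'little') >> 16) & 0x3ff
  return 'VOP3SD' if op in VOP3SD_OPS else 'VOP3'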