From 232848d0863a80d5b25285cef1949f6ed49cba4e Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Wed, 4 Feb 2026 17:10:59 -0800 Subject: [PATCH] PYTHONREMU: VOP3P integer operations with constants don't cast to fp16 (#14546) * PYTHONREMU: VOP3P integer operations with constants don't cast to fp16 * put that back * cleaner * do that once --- extra/assembly/amd/emu.py | 11 ++++++----- extra/assembly/amd/test/hw/test_vop3p.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/extra/assembly/amd/emu.py b/extra/assembly/amd/emu.py index 03633d6def..b5abd9b0ec 100644 --- a/extra/assembly/amd/emu.py +++ b/extra/assembly/amd/emu.py @@ -354,7 +354,7 @@ class _Ctx: offset = reg.cast(dtypes.int) * _c(32, dtypes.int) + lane.cast(dtypes.int) return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32)) - def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False) -> UOp: + def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp: """Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256). If lane is None, only scalar access is supported (off must be < 256). is_f64: True for F64 operations where 64-bit literals go in high 32 bits.""" @@ -385,7 +385,7 @@ class _Ctx: else: scalar_val = sgpr_lo if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val) - if bits == 16: # Float constants: cast F32 to F16 + if bits == 16 and do_cast: # Float constants: cast F32 to F16 scalar_val = is_float_const.where(scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32), scalar_val) return is_vgpr.where(vgpr_val, scalar_val) if lane is not None else scalar_val @@ -800,9 +800,10 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp: lane = ctx.range() exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset)) vdst_reg = ctx.inst_field(type(inst).vdst) - src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16) - src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16) - src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16) + do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name + src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, do_cast=do_cast) + src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, do_cast=do_cast) + src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, do_cast=do_cast) opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3 opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1 neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0 diff --git a/extra/assembly/amd/test/hw/test_vop3p.py b/extra/assembly/amd/test/hw/test_vop3p.py index 185c61f7bb..a2995b0ba4 100644 --- a/extra/assembly/amd/test/hw/test_vop3p.py +++ b/extra/assembly/amd/test/hw/test_vop3p.py @@ -390,6 +390,24 @@ class TestVOP3P(unittest.TestCase): self.assertAlmostEqual(lo, 6.0, places=1) self.assertAlmostEqual(hi, 0.0, places=1) + def test_v_pk_add_u16_float_inline_const_opsel(self): + """V_PK_ADD_U16 with float inline constant 2.0 + Regression test: for integer packed ops, do not perform the f32->f16 conversion. + """ + # src1 = inline float constant 2.0 + instructions = [ + s_mov_b32(s[0], 0x00030005), # packed u16: hi=3, lo=5 + v_mov_b32_e32(v[0], s[0]), + v_pk_add_u16(v[1], v[0], SrcEnum.POS_TWO, opsel_hi=3, opsel_hi2=1), + ] + st = run_program(instructions, n_lanes=1) + result = st.vgpr[0][1] + lo = result & 0xffff + hi = (result >> 16) & 0xffff + # lo = 5 + 0x0000 = 0x0005, hi = 3 + 0x4000 = 0x4003 + self.assertEqual(lo, 0x0005, f"lo: expected 0x0005, got 0x{lo:04x}") + self.assertEqual(hi, 0x4003, f"hi: expected 0x4003, got 0x{hi:04x}") + class TestWMMAF16(unittest.TestCase): """Tests for WMMA F16 output variant (V_WMMA_F16_16X16X16_F16).