PYTHONREMU: VOP3P integer operations with constants don't cast to fp16 (#14546)

* PYTHONREMU: VOP3P integer operations with constants don't cast to fp16

* put that back

* cleaner

* do that once
This commit is contained in:
Christopher Milan
2026-02-04 17:10:59 -08:00
committed by GitHub
parent 2966619834
commit 232848d086
2 changed files with 24 additions and 5 deletions

View File

@@ -354,7 +354,7 @@ class _Ctx:
offset = reg.cast(dtypes.int) * _c(32, dtypes.int) + lane.cast(dtypes.int)
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False) -> UOp:
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp:
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
If lane is None, only scalar access is supported (off must be < 256).
is_f64: True for F64 operations where 64-bit literals go in high 32 bits."""
@@ -385,7 +385,7 @@ class _Ctx:
else:
scalar_val = sgpr_lo
if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val)
if bits == 16: # Float constants: cast F32 to F16
if bits == 16 and do_cast: # Float constants: cast F32 to F16
scalar_val = is_float_const.where(scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32), scalar_val)
return is_vgpr.where(vgpr_val, scalar_val) if lane is not None else scalar_val
@@ -800,9 +800,10 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
lane = ctx.range()
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
vdst_reg = ctx.inst_field(type(inst).vdst)
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16)
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16)
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16)
do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, do_cast=do_cast)
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, do_cast=do_cast)
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, do_cast=do_cast)
opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0

View File

@@ -390,6 +390,24 @@ class TestVOP3P(unittest.TestCase):
self.assertAlmostEqual(lo, 6.0, places=1)
self.assertAlmostEqual(hi, 0.0, places=1)
def test_v_pk_add_u16_float_inline_const_opsel(self):
  """V_PK_ADD_U16 with float inline constant 2.0

  Regression test: a packed integer op must not run the f32->f16 conversion on a
  float inline constant; the expected halves below assume the constant's halves
  are added as 0x0000 (lo) and 0x4000 (hi).
  """
  # Program: load the packed pair into v0 via s0, then add the inline constant
  # (passed as src1) into v1.
  program = [
    s_mov_b32(s[0], 0x00030005),  # packed u16: hi=3, lo=5
    v_mov_b32_e32(v[0], s[0]),
    v_pk_add_u16(v[1], v[0], SrcEnum.POS_TWO, opsel_hi=3, opsel_hi2=1),
  ]
  state = run_program(program, n_lanes=1)

  # Read back the destination written by the packed add and split it into halves.
  packed = state.vgpr[0][1]
  hi = (packed >> 16) & 0xffff
  lo = packed & 0xffff

  # Expected: lo = 5 + 0x0000 = 0x0005, hi = 3 + 0x4000 = 0x4003
  self.assertEqual(lo, 0x0005, f"lo: expected 0x0005, got 0x{lo:04x}")
  self.assertEqual(hi, 0x4003, f"hi: expected 0x4003, got 0x{hi:04x}")
class TestWMMAF16(unittest.TestCase):
"""Tests for WMMA F16 output variant (V_WMMA_F16_16X16X16_F16).