mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
PYTHONREMU: VOP3P integer operations with constants don't cast to fp16 (#14546)
* PYTHONREMU: VOP3P integer operations with constants don't cast to fp16
* put that back
* cleaner
* do that once
This commit is contained in:
committed by
GitHub
parent
2966619834
commit
232848d086
@@ -354,7 +354,7 @@ class _Ctx:
|
||||
offset = reg.cast(dtypes.int) * _c(32, dtypes.int) + lane.cast(dtypes.int)
|
||||
return buf.index(offset, _lane_active(exec_mask, lane)).store(val.cast(dtypes.uint32))
|
||||
|
||||
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False) -> UOp:
|
||||
def rsrc_dyn(self, off: UOp, lane: UOp | None, bits: int = 32, literal: UOp | None = None, is_f64: bool = False, do_cast: bool = True) -> UOp:
|
||||
"""Read source operand with dynamic offset. Handles SGPR/inline constants (<256), VGPR (>=256).
|
||||
If lane is None, only scalar access is supported (off must be < 256).
|
||||
is_f64: True for F64 operations where 64-bit literals go in high 32 bits."""
|
||||
@@ -385,7 +385,7 @@ class _Ctx:
|
||||
else:
|
||||
scalar_val = sgpr_lo
|
||||
if literal is not None: scalar_val = off.eq(_c(255)).where(literal, scalar_val)
|
||||
if bits == 16: # Float constants: cast F32 to F16
|
||||
if bits == 16 and do_cast: # Float constants: cast F32 to F16
|
||||
scalar_val = is_float_const.where(scalar_val.bitcast(dtypes.float32).cast(dtypes.half).bitcast(dtypes.uint16).cast(dtypes.uint32), scalar_val)
|
||||
|
||||
return is_vgpr.where(vgpr_val, scalar_val) if lane is not None else scalar_val
|
||||
@@ -800,9 +800,10 @@ def _compile_vop3p(inst: ir3.VOP3P | ir4.VOP3P, ctx: _Ctx) -> UOp:
|
||||
lane = ctx.range()
|
||||
exec_mask = ctx.rsgpr_dyn(_c(EXEC_LO.offset))
|
||||
vdst_reg = ctx.inst_field(type(inst).vdst)
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16)
|
||||
do_cast = any(x in op_name for x in ('F16', 'F32', 'BF16')) and 'IU' not in op_name
|
||||
src0 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src0), lane, 16, do_cast=do_cast)
|
||||
src1 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src1), lane, 16, do_cast=do_cast)
|
||||
src2 = ctx.rsrc_dyn(ctx.inst_field(type(inst).src2), lane, 16, do_cast=do_cast)
|
||||
opsel, opsel_hi = getattr(inst, 'opsel', 0) or 0, getattr(inst, 'opsel_hi', 3) if getattr(inst, 'opsel_hi', 3) is not None else 3
|
||||
opsel_hi2 = getattr(inst, 'opsel_hi2', 1) if getattr(inst, 'opsel_hi2', 1) is not None else 1
|
||||
neg, neg_hi = getattr(inst, 'neg', 0) or 0, getattr(inst, 'neg_hi', 0) or 0
|
||||
|
||||
@@ -390,6 +390,24 @@ class TestVOP3P(unittest.TestCase):
|
||||
self.assertAlmostEqual(lo, 6.0, places=1)
|
||||
self.assertAlmostEqual(hi, 0.0, places=1)
|
||||
|
||||
def test_v_pk_add_u16_float_inline_const_opsel(self):
  """V_PK_ADD_U16 fed the float inline constant 2.0 as src1.

  Regression test: packed integer ops must use the inline constant's raw bits
  and skip the f32->f16 conversion that float packed ops apply.
  """
  # v0 holds the packed u16 pair (hi=3, lo=5); src1 is the inline float constant 2.0
  program = [
    s_mov_b32(s[0], 0x00030005),
    v_mov_b32_e32(v[0], s[0]),
    v_pk_add_u16(v[1], v[0], SrcEnum.POS_TWO, opsel_hi=3, opsel_hi2=1),
  ]
  state = run_program(program, n_lanes=1)

  # split the 32-bit destination (v1, lane 0) into its two 16-bit halves
  packed = state.vgpr[0][1]
  lo = packed & 0xffff
  hi = (packed >> 16) & 0xffff

  # unconverted, the constant contributes 0x0000 to the low half and 0x4000 to the high half:
  #   lo = 5 + 0x0000 = 0x0005, hi = 3 + 0x4000 = 0x4003
  self.assertEqual(lo, 0x0005, f"lo: expected 0x0005, got 0x{lo:04x}")
  self.assertEqual(hi, 0x4003, f"hi: expected 0x4003, got 0x{hi:04x}")
|
||||
|
||||
|
||||
class TestWMMAF16(unittest.TestCase):
|
||||
"""Tests for WMMA F16 output variant (V_WMMA_F16_16X16X16_F16).
|
||||
|
||||
Reference in New Issue
Block a user