mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
assembly/amd: fix saturation in python remu (#14557)
* PYTHONREMU: failing test for V_SUB_NC_U32_E64 clamp * fix saturation in PYTHON_REMU * simpler * more tests, less lines --------- Co-authored-by: Christopher Milan <chrismilan@ucla.edu>
This commit is contained in:
@@ -135,14 +135,6 @@ def _val_to_u32(val: UOp) -> UOp:
|
||||
if val.dtype in (dtypes.uint16, dtypes.int16): return val.cast(dtypes.uint32)
|
||||
return val.cast(dtypes.uint32)
|
||||
|
||||
def _apply_clamp(val: UOp, clmp: int | UOp) -> UOp:
|
||||
"""Apply VOP3 clamp modifier: clamp float results to [0.0, 1.0] range."""
|
||||
if isinstance(clmp, int) and clmp == 0: return val
|
||||
if val.dtype not in (dtypes.float32, dtypes.half, dtypes.float64): return val
|
||||
zero, one = UOp.const(val.dtype, 0.0), UOp.const(val.dtype, 1.0)
|
||||
clamped = val.maximum(zero).minimum(one)
|
||||
return clmp.ne(_c(0)).where(clamped, val) if isinstance(clmp, UOp) else clamped
|
||||
|
||||
_pcode_fixes = {
|
||||
'V_DIV_FMAS_F32': ('D0.f32 = 2.0F ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (exponent(S2.f32) > 127) ? (2.0F ** 64 * fma(S0.f32, S1.f32, S2.f32)) : (2.0F ** -64 * fma(S0.f32, S1.f32, S2.f32))'),
|
||||
@@ -445,7 +437,7 @@ class _Ctx:
|
||||
return UOp.sink(*stores, *self.inc_pc())
|
||||
|
||||
def compile_vop_pcode(self, op, srcs: dict[str, UOp], lane: UOp, vdst_reg: UOp, exec_mask: UOp,
|
||||
opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int | UOp = 0) -> UOp:
|
||||
opsel_dst_hi: bool | UOp = False, sdst_reg: int | None = None, clmp: int = 0) -> UOp:
|
||||
"""Compile VOP instruction. Returns sink with stores and inc_pc."""
|
||||
pcode = get_pcode(op)
|
||||
vcc_reg = sdst_reg if sdst_reg is not None else VCC_LO.offset
|
||||
@@ -454,6 +446,24 @@ class _Ctx:
|
||||
'ROUND_MODE': _c(0), 'ROUND_TOWARD_ZERO': _c(0)}) # rounding mode: 0=RNE, RTZ constant
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
|
||||
# For integer ops with clamp, compute overflow using wide arithmetic
|
||||
# NOTE: MUL_LO ops don't saturate - they always return the low bits
|
||||
int_saturate = None
|
||||
if clmp and any(p in op.name for p in ('_NC_U', '_MAD_U', '_NC_I', '_MAD_I')):
|
||||
is_signed, is_16bit = '_I' in op.name and '_U' not in op.name, '16' in op.name
|
||||
if not (is_16bit and is_signed): # Skip 16-bit signed ops due to codegen issues
|
||||
s0, s1, s2 = srcs.get('S0'), srcs.get('S1'), srcs.get('S2')
|
||||
if s0 is not None and s1 is not None:
|
||||
narrow_dt = dtypes.uint16 if is_16bit else (dtypes.int32 if is_signed else dtypes.uint32)
|
||||
wide_dt = dtypes.int32 if is_16bit else dtypes.int64
|
||||
narrow_max, narrow_min = (0xFFFF, 0) if is_16bit else ((0x7FFFFFFF, -0x80000000) if is_signed else (0xFFFFFFFF, 0))
|
||||
def to_wide(x): return (x.bitcast(narrow_dt) if x.dtype.itemsize == narrow_dt.itemsize else x.cast(narrow_dt)).cast(wide_dt)
|
||||
is_sub, is_mad = 'SUB' in op.name, 'MAD' in op.name
|
||||
full = (to_wide(s0) * to_wide(s1) + to_wide(s2)) if is_mad and s2 is not None else \
|
||||
(to_wide(s1) - to_wide(s0)) if is_sub and 'SUBREV' in op.name else \
|
||||
(to_wide(s0) - to_wide(s1)) if is_sub else (to_wide(s0) + to_wide(s1))
|
||||
int_saturate = full.clamp(narrow_min, narrow_max).cast(narrow_dt)
|
||||
|
||||
raw_stores: list = []
|
||||
vcc_val, exec_val = None, None
|
||||
for dest, val in assigns:
|
||||
@@ -468,7 +478,10 @@ class _Ctx:
|
||||
val.cast(dtypes.uint32) if val.dtype in (dtypes.uint16, dtypes.int16) else val.cast(dtypes.uint32) & UOp.const(dtypes.uint32, slice_mask)
|
||||
raw_stores.append(('vgpr_slice', (lo_bit, width, val_bits)))
|
||||
continue
|
||||
val = _apply_clamp(val, clmp)
|
||||
# For integer ops with clamp, use pre-computed saturated value; for floats, clamp to [0,1]
|
||||
if int_saturate is not None: val = int_saturate
|
||||
elif clmp and val.dtype in (dtypes.float32, dtypes.half, dtypes.float64):
|
||||
val = val.maximum(UOp.const(val.dtype, 0.0)).minimum(UOp.const(val.dtype, 1.0))
|
||||
if val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
lo, hi = _split64(val)
|
||||
raw_stores.extend([('vgpr', self.wvgpr_dyn(vdst_reg, lane, lo, exec_mask)), ('vgpr', self.wvgpr_dyn(vdst_reg + _c(1), lane, hi, exec_mask))])
|
||||
@@ -732,6 +745,7 @@ def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
|
||||
_, assigns = parse_pcode(pcode, srcs)
|
||||
|
||||
has_per_lane_vcc = any('[laneId]' in dest for dest, _ in assigns if dest.startswith('VCC') or dest.startswith('D0.u64'))
|
||||
clmp = getattr(inst, 'clmp', 0)
|
||||
if has_per_lane_vcc:
|
||||
# VCC computation: RANGE+REDUCE gets axis ID first (lower ID = runs first)
|
||||
# This ensures VCC reads source values BEFORE VGPR stores modify them
|
||||
@@ -743,11 +757,17 @@ def _compile_vop3sd(inst: ir3.VOP3SD | ir4.VOP3SD, ctx: _Ctx) -> UOp:
|
||||
final_vcc = ctx.unroll_lanes(get_vcc_bit, exec_mask)
|
||||
# VGPR stores: RANGE gets axis ID second (higher ID = runs after VCC loop)
|
||||
lane3 = ctx.range()
|
||||
d0_val = None
|
||||
d0_val, vcc_per_lane = None, None
|
||||
for dest, val in parse_pcode(pcode, load_srcs(lane3))[1]:
|
||||
if dest.startswith('D0') and '[laneId]' not in dest: d0_val = val
|
||||
if dest.startswith('VCC') or (dest.startswith('D0.u64') and '[laneId]' in dest): vcc_per_lane = val
|
||||
vgpr_stores = []
|
||||
if d0_val is not None:
|
||||
# Apply clamp using carry/borrow bit: ADD overflow->0xFFFFFFFF, SUB underflow->0
|
||||
if clmp and vcc_per_lane is not None:
|
||||
is_sub = 'SUB' in inst.op.name
|
||||
sat_val = _c(0) if is_sub else _c(0xFFFFFFFF)
|
||||
d0_val = vcc_per_lane.cast(dtypes.bool).where(sat_val, d0_val.cast(dtypes.uint32))
|
||||
if d0_val.dtype in (dtypes.uint64, dtypes.int64, dtypes.float64):
|
||||
lo, hi = _split64(d0_val)
|
||||
vgpr_stores.extend([ctx.wvgpr_dyn(vdst_reg, lane3, lo, exec_mask), ctx.wvgpr_dyn(vdst_reg + _c(1), lane3, hi, exec_mask)])
|
||||
|
||||
@@ -2949,6 +2949,306 @@ class TestVOP3Clamp(unittest.TestCase):
|
||||
self.assertAlmostEqual(i2f(st.vgpr[3][1]), 1.0, places=5, msg="lane 3: 2.5 should clamp to 1.0")
|
||||
|
||||
|
||||
class TestVOP3ClampUint32(unittest.TestCase):
|
||||
"""Tests for VOP3 clamp modifier on unsigned 32-bit integer operations."""
|
||||
|
||||
def test_v_sub_nc_u32_e64_clamp_underflow(self):
|
||||
"""V_SUB_NC_U32_E64 with clamp: 0 - 1 should saturate to 0."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
v_sub_nc_u32_e64(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0, f"expected 0, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_sub_nc_u32_e64_clamp_no_underflow(self):
|
||||
"""V_SUB_NC_U32_E64 with clamp: 100 - 50 = 50 (no saturation needed)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100),
|
||||
v_mov_b32_e32(v[1], 50),
|
||||
v_sub_nc_u32_e64(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 50, f"expected 50, got {st.vgpr[0][2]}")
|
||||
|
||||
def test_v_add_nc_u32_e64_clamp_overflow(self):
|
||||
"""V_ADD_NC_U32_E64 with clamp: 0xFFFFFFFF + 1 should saturate to 0xFFFFFFFF."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFFFFFF),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
v_add_nc_u32_e64(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0xFFFFFFFF, f"expected 0xFFFFFFFF, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_add_nc_u32_e64_clamp_no_overflow(self):
|
||||
"""V_ADD_NC_U32_E64 with clamp: 100 + 50 = 150 (no saturation needed)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100),
|
||||
v_mov_b32_e32(v[1], 50),
|
||||
v_add_nc_u32_e64(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 150, f"expected 150, got {st.vgpr[0][2]}")
|
||||
|
||||
|
||||
class TestVOP3ClampUint16(unittest.TestCase):
|
||||
"""Tests for VOP3 clamp modifier on unsigned 16-bit integer operations."""
|
||||
|
||||
def test_v_sub_nc_u16_clamp_underflow(self):
|
||||
"""V_SUB_NC_U16 with clamp: 0 - 1 should saturate to 0."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
v_sub_nc_u16(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2] & 0xFFFF, 0, f"expected 0, got 0x{st.vgpr[0][2] & 0xFFFF:04x}")
|
||||
|
||||
def test_v_sub_nc_u16_clamp_no_underflow(self):
|
||||
"""V_SUB_NC_U16 with clamp: 100 - 50 = 50 (no saturation needed)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100),
|
||||
v_mov_b32_e32(v[1], 50),
|
||||
v_sub_nc_u16(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2] & 0xFFFF, 50, f"expected 50, got {st.vgpr[0][2] & 0xFFFF}")
|
||||
|
||||
def test_v_add_nc_u16_clamp_overflow(self):
|
||||
"""V_ADD_NC_U16 with clamp: 0xFFFF + 1 should saturate to 0xFFFF."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFF),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
v_add_nc_u16(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2] & 0xFFFF, 0xFFFF, f"expected 0xFFFF, got 0x{st.vgpr[0][2] & 0xFFFF:04x}")
|
||||
|
||||
def test_v_add_nc_u16_clamp_no_overflow(self):
|
||||
"""V_ADD_NC_U16 with clamp: 100 + 50 = 150 (no saturation needed)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100),
|
||||
v_mov_b32_e32(v[1], 50),
|
||||
v_add_nc_u16(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2] & 0xFFFF, 150, f"expected 150, got {st.vgpr[0][2] & 0xFFFF}")
|
||||
|
||||
|
||||
class TestVOP3ClampInt32(unittest.TestCase):
|
||||
"""Tests for VOP3 clamp modifier on signed 32-bit integer operations."""
|
||||
|
||||
def test_v_add_nc_i32_clamp_overflow(self):
|
||||
"""V_ADD_NC_I32 with clamp: INT_MAX + 1 should saturate to INT_MAX."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0x7FFFFFFF), # S0 = INT_MAX
|
||||
v_mov_b32_e32(v[1], 1), # S1 = 1
|
||||
v_add_nc_i32(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0x7FFFFFFF, f"expected 0x7FFFFFFF, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_add_nc_i32_clamp_underflow(self):
|
||||
"""V_ADD_NC_I32 with clamp: INT_MIN + (-1) should saturate to INT_MIN."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0x80000000), # S0 = INT_MIN
|
||||
v_mov_b32_e32(v[1], 0xFFFFFFFF), # S1 = -1
|
||||
v_add_nc_i32(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0x80000000, f"expected 0x80000000, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_sub_nc_i32_clamp_underflow(self):
|
||||
"""V_SUB_NC_I32 with clamp: INT_MIN - 1 should saturate to INT_MIN."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0x80000000), # S0 = INT_MIN
|
||||
v_mov_b32_e32(v[1], 1), # S1 = 1
|
||||
v_sub_nc_i32(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0x80000000, f"expected 0x80000000, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_sub_nc_i32_clamp_overflow(self):
|
||||
"""V_SUB_NC_I32 with clamp: INT_MAX - (-1) should saturate to INT_MAX."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0x7FFFFFFF), # S0 = INT_MAX
|
||||
v_mov_b32_e32(v[1], 0xFFFFFFFF), # S1 = -1
|
||||
v_sub_nc_i32(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0x7FFFFFFF, f"expected 0x7FFFFFFF, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_add_nc_i32_no_saturation_positive(self):
|
||||
"""V_ADD_NC_I32 with clamp: 100 + 200 = 300 (no saturation needed)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100),
|
||||
v_mov_b32_e32(v[1], 200),
|
||||
v_add_nc_i32(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 300, f"expected 300, got {st.vgpr[0][2]}")
|
||||
|
||||
def test_v_add_nc_i32_no_saturation_negative(self):
|
||||
"""V_ADD_NC_I32 with clamp: -100 + -200 = -300 (no saturation needed)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFFFF9C), # -100
|
||||
v_mov_b32_e32(v[1], 0xFFFFFF38), # -200
|
||||
v_add_nc_i32(v[2], v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
expected = 0xFFFFFED4 # -300
|
||||
self.assertEqual(st.vgpr[0][2], expected, f"expected 0x{expected:08x}, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
|
||||
class TestVOP3ClampCarry(unittest.TestCase):
|
||||
"""Tests for VOP3 clamp modifier on carry operations (VOP3SD)."""
|
||||
|
||||
def test_v_add_co_u32_clamp_overflow(self):
|
||||
"""V_ADD_CO_U32 with clamp: 0xFFFFFFFF + 1 should saturate to 0xFFFFFFFF."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFFFFFF),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
v_add_co_u32(v[2], VCC, v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0xFFFFFFFF, f"expected 0xFFFFFFFF, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_add_co_u32_clamp_no_overflow(self):
|
||||
"""V_ADD_CO_U32 with clamp: 100 + 200 = 300 (no saturation)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100),
|
||||
v_mov_b32_e32(v[1], 200),
|
||||
v_add_co_u32(v[2], VCC, v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 300, f"expected 300, got {st.vgpr[0][2]}")
|
||||
|
||||
def test_v_sub_co_u32_clamp_underflow(self):
|
||||
"""V_SUB_CO_U32 with clamp: 0 - 1 should saturate to 0."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
v_sub_co_u32(v[2], VCC, v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0, f"expected 0, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_sub_co_u32_clamp_no_underflow(self):
|
||||
"""V_SUB_CO_U32 with clamp: 300 - 100 = 200 (no saturation)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 300),
|
||||
v_mov_b32_e32(v[1], 100),
|
||||
v_sub_co_u32(v[2], VCC, v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 200, f"expected 200, got {st.vgpr[0][2]}")
|
||||
|
||||
def test_v_subrev_co_u32_clamp_underflow(self):
|
||||
"""V_SUBREV_CO_U32 with clamp: 1 - 0 reversed = 0 - 1 should saturate to 0."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1), # This becomes the subtrahend
|
||||
v_mov_b32_e32(v[1], 0), # This becomes the minuend (0 - 1)
|
||||
v_subrev_co_u32(v[2], VCC, v[0], v[1], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0, f"expected 0, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_add_co_ci_u32_clamp_overflow(self):
|
||||
"""V_ADD_CO_CI_U32 with clamp: 0xFFFFFFFF + 1 + 0 should saturate to 0xFFFFFFFF."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFFFFFF),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
s_mov_b64(VCC, 0), # No carry in
|
||||
v_add_co_ci_u32(v[2], VCC, v[0], v[1], VCC, clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0xFFFFFFFF, f"expected 0xFFFFFFFF, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_add_co_ci_u32_clamp_overflow_with_carry(self):
|
||||
"""V_ADD_CO_CI_U32 with clamp: 0xFFFFFFFE + 1 + 1 should saturate to 0xFFFFFFFF."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFFFFFE),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
s_mov_b64(VCC, 1), # Carry in = 1
|
||||
v_add_co_ci_u32(v[2], VCC, v[0], v[1], VCC, clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0xFFFFFFFF, f"expected 0xFFFFFFFF, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_sub_co_ci_u32_clamp_underflow(self):
|
||||
"""V_SUB_CO_CI_U32 with clamp: 0 - 1 - 0 should saturate to 0."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0),
|
||||
v_mov_b32_e32(v[1], 1),
|
||||
s_mov_b64(VCC, 0), # No borrow in
|
||||
v_sub_co_ci_u32(v[2], VCC, v[0], v[1], VCC, clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0, f"expected 0, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
def test_v_subrev_co_ci_u32_clamp_underflow(self):
|
||||
"""V_SUBREV_CO_CI_U32 with clamp: reversed 1 - 0 - 0 = 0 - 1 should saturate to 0."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1),
|
||||
v_mov_b32_e32(v[1], 0),
|
||||
s_mov_b64(VCC, 0),
|
||||
v_subrev_co_ci_u32(v[2], VCC, v[0], v[1], VCC, clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][2], 0, f"expected 0, got 0x{st.vgpr[0][2]:08x}")
|
||||
|
||||
|
||||
class TestVOP3ClampMAD(unittest.TestCase):
|
||||
"""Tests for VOP3 clamp modifier on MAD (multiply-add) operations."""
|
||||
|
||||
def test_v_mad_u16_clamp_overflow(self):
|
||||
"""V_MAD_U16 with clamp: 0xFFFF * 2 + 0 should saturate to 0xFFFF."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFF),
|
||||
v_mov_b32_e32(v[1], 2),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_mad_u16(v[3], v[0], v[1], v[2], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][3] & 0xFFFF, 0xFFFF, f"expected 0xFFFF, got 0x{st.vgpr[0][3] & 0xFFFF:04x}")
|
||||
|
||||
def test_v_mad_u16_clamp_overflow_with_add(self):
|
||||
"""V_MAD_U16 with clamp: 0x8000 * 2 + 0x1000 should saturate to 0xFFFF."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0x8000), # 32768
|
||||
v_mov_b32_e32(v[1], 2), # * 2 = 65536
|
||||
v_mov_b32_e32(v[2], 0x1000), # + 4096 = 69632 > 0xFFFF
|
||||
v_mad_u16(v[3], v[0], v[1], v[2], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][3] & 0xFFFF, 0xFFFF, f"expected 0xFFFF, got 0x{st.vgpr[0][3] & 0xFFFF:04x}")
|
||||
|
||||
def test_v_mad_u16_no_overflow(self):
|
||||
"""V_MAD_U16 with clamp: 100 * 100 + 50 = 10050 (no saturation)."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 100),
|
||||
v_mov_b32_e32(v[1], 100),
|
||||
v_mov_b32_e32(v[2], 50),
|
||||
v_mad_u16(v[3], v[0], v[1], v[2], clmp=1),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vgpr[0][3] & 0xFFFF, 10050, f"expected 10050, got {st.vgpr[0][3] & 0xFFFF}")
|
||||
|
||||
def test_v_mad_u16_no_clamp(self):
|
||||
"""V_MAD_U16 without clamp: 0xFFFF * 2 + 0 should wrap to 0xFFFE."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0xFFFF),
|
||||
v_mov_b32_e32(v[1], 2),
|
||||
v_mov_b32_e32(v[2], 0),
|
||||
v_mad_u16(v[3], v[0], v[1], v[2], clmp=0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
# 0xFFFF * 2 = 0x1FFFE, low 16 bits = 0xFFFE
|
||||
self.assertEqual(st.vgpr[0][3] & 0xFFFF, 0xFFFE, f"expected 0xFFFE, got 0x{st.vgpr[0][3] & 0xFFFF:04x}")
|
||||
|
||||
|
||||
class TestCvtPkF16(unittest.TestCase):
|
||||
"""Tests for V_CVT_PK_RTZ_F16_F32 - pack two f32 to f16 with round toward zero."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user