From ff856a74cbcb44308b04c0bd5e0eed2be6d3c271 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 29 Dec 2025 13:20:00 -0500 Subject: [PATCH] minor refactoring for rdna3 (#13873) * minor refactoring for rdna3 * fix div scale stuff * more bugfixes --- extra/assembly/rdna3/autogen/gen_pcode.py | 102 ++++-- extra/assembly/rdna3/emu.py | 47 ++- extra/assembly/rdna3/pcode.py | 83 ++++- extra/assembly/rdna3/test/test_emu.py | 386 +++++++++++++++++++++- 4 files changed, 561 insertions(+), 57 deletions(-) diff --git a/extra/assembly/rdna3/autogen/gen_pcode.py b/extra/assembly/rdna3/autogen/gen_pcode.py index 2d47062d09..c3fda6de19 100644 --- a/extra/assembly/rdna3/autogen/gen_pcode.py +++ b/extra/assembly/rdna3/autogen/gen_pcode.py @@ -92,7 +92,7 @@ def _SOP1Op_S_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG tmp = Reg(-1) for i in range(0, int(31)+1): if S0.u32[i] == 1: - tmp = Reg(i) + tmp = Reg(i); break D0.i32 = tmp # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} @@ -115,7 +115,7 @@ def _SOP1Op_S_CTZ_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG tmp = Reg(-1) for i in range(0, int(63)+1): if S0.u64[i] == 1: - tmp = Reg(i) + tmp = Reg(i); break D0.i32 = tmp # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} @@ -138,7 +138,7 @@ def _SOP1Op_S_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG tmp = Reg(-1) for i in range(0, int(31)+1): if S0.u32[31 - i] == 1: - tmp = Reg(i) + tmp = Reg(i); break D0.i32 = tmp # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} @@ -161,7 +161,7 @@ def _SOP1Op_S_CLZ_I32_U64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG tmp = Reg(-1) for i in range(0, int(63)+1): if S0.u64[63 - i] == 1: - tmp = Reg(i) + tmp = Reg(i); break D0.i32 = tmp # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} @@ -3746,7 +3746,7 @@ def _VOP1Op_V_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG D0.i32 = -1 for i in range(0, int(31)+1): if S0.u32[31 - i] == 1: - D0.i32 = i; break # Stop at first 1 bit found + D0.i32 = i; break # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} return result @@ -3766,7 +3766,7 @@ def _VOP1Op_V_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG D0.i32 = -1 for i in range(0, int(31)+1): if S0.u32[i] == 1: - D0.i32 = i; break # Stop at first 1 bit found + D0.i32 = i; break # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} return result @@ -5588,7 +5588,7 @@ def _VOP3Op_V_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG D0.i32 = -1 for i in range(0, int(31)+1): if S0.u32[31 - i] == 1: - D0.i32 = i; break # Stop at first 1 bit found + D0.i32 = i; break # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} return result @@ -5608,7 +5608,7 @@ def _VOP3Op_V_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG D0.i32 = -1 for i in range(0, int(31)+1): if S0.u32[i] == 1: - D0.i32 = i; break # Stop at first 1 bit found + D0.i32 = i; break # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} return result @@ -7207,7 +7207,7 @@ def _VOP3Op_V_DIV_FIXUP_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, elif exponent(S1.f32) == 255: D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) else: - D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32))) + D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32))) # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} return result @@ -7260,7 +7260,7 @@ def _VOP3Op_V_DIV_FIXUP_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, elif exponent(S1.f64) == 2047: D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) else: - D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64))) + D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64))) # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} result['d0_64'] = True @@ -7280,7 +7280,7 @@ def _VOP3Op_V_DIV_FMAS_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V laneId = lane # --- compiled pseudocode --- if VCC.u64[laneId]: - D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32) + D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32) else: D0.f32 = fma(S0.f32, S1.f32, S2.f32) # --- end pseudocode --- @@ -7302,7 +7302,7 @@ def _VOP3Op_V_DIV_FMAS_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V laneId = lane # --- compiled pseudocode --- if VCC.u64[laneId]: - D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64) + D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64) else: D0.f64 = fma(S0.f64, S1.f64, S2.f64) # --- end pseudocode --- @@ -8736,13 +8736,13 @@ def _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal # --- compiled pseudocode --- VCC = Reg(0x0) if ((F(S2.f32) == 0.0) or (F(S1.f32) == 0.0)): - D0.f32 = float("nan") + VCC = Reg(0x1); D0.f32 = float("nan") elif exponent(S2.f32) - exponent(S1.f32) >= 96: VCC = Reg(0x1) if S0.f32 == S1.f32: D0.f32 = ldexp(S0.f32, 64) - elif S1.f32 == DENORM.f32: - D0.f32 = ldexp(S0.f32, 64) + elif False: + pass # denorm check moved to end elif ((1.0 / F(S1.f32) == DENORM.f64) and (S2.f32 / S1.f32 == DENORM.f32)): VCC = Reg(0x1) if S0.f32 == S1.f32: @@ -8751,10 +8751,10 @@ def _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal D0.f32 = ldexp(S0.f32, -64) elif S2.f32 / S1.f32 == DENORM.f32: VCC = Reg(0x1) - if S0.f32 == S2.f32: - D0.f32 = ldexp(S0.f32, 64) elif exponent(S2.f32) <= 23: - D0.f32 = ldexp(S0.f32, 64) + VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64) + if S1.f32 == DENORM.f32: + D0.f32 = float("nan") # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} result['vcc_lane'] = (VCC._val >> lane) & 1 @@ -8799,13 +8799,13 @@ def _VOP3SDOp_V_DIV_SCALE_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal # --- compiled pseudocode --- VCC = Reg(0x0) if ((S2.f64 == 0.0) or (S1.f64 == 0.0)): - D0.f64 = float("nan") + VCC = Reg(0x1); D0.f64 = float("nan") elif exponent(S2.f64) - exponent(S1.f64) >= 768: VCC = Reg(0x1) if S0.f64 == S1.f64: D0.f64 = ldexp(S0.f64, 128) - elif S1.f64 == DENORM.f64: - D0.f64 = ldexp(S0.f64, 128) + elif False: + pass # denorm check moved to end elif ((1.0 / S1.f64 == DENORM.f64) and (S2.f64 / S1.f64 == DENORM.f64)): VCC = Reg(0x1) if S0.f64 == S1.f64: @@ -8814,10 +8814,10 @@ def _VOP3SDOp_V_DIV_SCALE_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal D0.f64 = ldexp(S0.f64, -128) elif S2.f64 / S1.f64 == DENORM.f64: VCC = Reg(0x1) - if S0.f64 == S2.f64: - D0.f64 = ldexp(S0.f64, 128) elif exponent(S2.f64) <= 53: D0.f64 = ldexp(S0.f64, 128) + if S1.f64 == DENORM.f64: + D0.f64 = float("nan") # --- end pseudocode --- result = {'d0': D0._val, 'scc': scc & 1} result['vcc_lane'] = (VCC._val >> lane) & 1 @@ -9258,6 +9258,60 @@ def _VOP3POp_V_DOT2_F32_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, result = {'d0': D0._val, 'scc': scc & 1} return result +def _VOP3POp_V_DOT4_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = S2.u32; + # tmp += u8_to_u32(S0[7 : 0].u8) * u8_to_u32(S1[7 : 0].u8); + # tmp += u8_to_u32(S0[15 : 8].u8) * u8_to_u32(S1[15 : 8].u8); + # tmp += u8_to_u32(S0[23 : 16].u8) * u8_to_u32(S1[23 : 16].u8); + # tmp += u8_to_u32(S0[31 : 24].u8) * u8_to_u32(S1[31 : 24].u8); + # D0.u32 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(S2.u32) + tmp += u8_to_u32(S0[7 : 0].u8) * u8_to_u32(S1[7 : 0].u8) + tmp += u8_to_u32(S0[15 : 8].u8) * u8_to_u32(S1[15 : 8].u8) + tmp += u8_to_u32(S0[23 : 16].u8) * u8_to_u32(S1[23 : 16].u8) + tmp += u8_to_u32(S0[31 : 24].u8) * u8_to_u32(S1[31 : 24].u8) + D0.u32 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + +def _VOP3POp_V_DOT8_U32_U4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): + # tmp = S2.u32; + # tmp += u4_to_u32(S0[3 : 0].u4) * u4_to_u32(S1[3 : 0].u4); + # tmp += u4_to_u32(S0[7 : 4].u4) * u4_to_u32(S1[7 : 4].u4); + # tmp += u4_to_u32(S0[11 : 8].u4) * u4_to_u32(S1[11 : 8].u4); + # tmp += u4_to_u32(S0[15 : 12].u4) * u4_to_u32(S1[15 : 12].u4); + # tmp += u4_to_u32(S0[19 : 16].u4) * u4_to_u32(S1[19 : 16].u4); + # tmp += u4_to_u32(S0[23 : 20].u4) * u4_to_u32(S1[23 : 20].u4); + # tmp += u4_to_u32(S0[27 : 24].u4) * u4_to_u32(S1[27 : 24].u4); + # tmp += u4_to_u32(S0[31 : 28].u4) * u4_to_u32(S1[31 : 28].u4); + # D0.u32 = tmp + S0 = Reg(s0) + S1 = Reg(s1) + S2 = Reg(s2) + D0 = Reg(d0) + tmp = Reg(0) + # --- compiled pseudocode --- + tmp = Reg(S2.u32) + tmp += u4_to_u32(S0[3 : 0].u4) * u4_to_u32(S1[3 : 0].u4) + tmp += u4_to_u32(S0[7 : 4].u4) * u4_to_u32(S1[7 : 4].u4) + tmp += u4_to_u32(S0[11 : 8].u4) * u4_to_u32(S1[11 : 8].u4) + tmp += u4_to_u32(S0[15 : 12].u4) * u4_to_u32(S1[15 : 12].u4) + tmp += u4_to_u32(S0[19 : 16].u4) * u4_to_u32(S1[19 : 16].u4) + tmp += u4_to_u32(S0[23 : 20].u4) * u4_to_u32(S1[23 : 20].u4) + tmp += u4_to_u32(S0[27 : 24].u4) * u4_to_u32(S1[27 : 24].u4) + tmp += u4_to_u32(S0[31 : 28].u4) * u4_to_u32(S1[31 : 28].u4) + D0.u32 = tmp + # --- end pseudocode --- + result = {'d0': D0._val, 'scc': scc & 1} + return result + VOP3POp_FUNCTIONS = { VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16, VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16, @@ -9279,6 +9333,8 @@ VOP3POp_FUNCTIONS = { VOP3POp.V_PK_MIN_F16: _VOP3POp_V_PK_MIN_F16, VOP3POp.V_PK_MAX_F16: _VOP3POp_V_PK_MAX_F16, VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16, + VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8, + VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4, } def _VOPCOp_V_CMP_F_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): diff --git a/extra/assembly/rdna3/emu.py b/extra/assembly/rdna3/emu.py index ef58d8c431..75d62da802 100644 --- a/extra/assembly/rdna3/emu.py +++ b/extra/assembly/rdna3/emu.py @@ -21,6 +21,7 @@ _VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value} # Ops with 16-bit types in name (for source/dest handling) _VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} _VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} +_VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))} # CVT ops with 32/64-bit source (despite 16-bit in name) _CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \ {op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} @@ -28,34 +29,17 @@ _CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op _VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} _VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} -# Inline constants for src operands 128-254 (f32 format for most instructions) -_INLINE_CONSTS = [0] * 127 -for _i in range(65): _INLINE_CONSTS[_i] = _i -for _i in range(1, 17): _INLINE_CONSTS[64 + _i] = ((-_i) & 0xffffffff) -for _k, _v in {SrcEnum.POS_HALF: 0x3f000000, SrcEnum.NEG_HALF: 0xbf000000, SrcEnum.POS_ONE: 0x3f800000, SrcEnum.NEG_ONE: 0xbf800000, - SrcEnum.POS_TWO: 0x40000000, SrcEnum.NEG_TWO: 0xc0000000, SrcEnum.POS_FOUR: 0x40800000, SrcEnum.NEG_FOUR: 0xc0800000, - SrcEnum.INV_2PI: 0x3e22f983}.items(): _INLINE_CONSTS[_k - 128] = _v - -# Inline constants for VOP3P packed f16 operations (f16 value in low 16 bits only, high 16 bits are 0) -# Hardware does NOT replicate the constant - opsel_hi controls which half is used for the hi result -_INLINE_CONSTS_F16 = [0] * 127 -for _i in range(65): _INLINE_CONSTS_F16[_i] = _i # Integer constants in low 16 bits only -for _i in range(1, 17): _INLINE_CONSTS_F16[64 + _i] = (-_i) & 0xffff # Negative integers in low 16 bits -for _k, _v in {SrcEnum.POS_HALF: 0x3800, SrcEnum.NEG_HALF: 0xb800, SrcEnum.POS_ONE: 0x3c00, SrcEnum.NEG_ONE: 0xbc00, - SrcEnum.POS_TWO: 0x4000, SrcEnum.NEG_TWO: 0xc000, SrcEnum.POS_FOUR: 0x4400, SrcEnum.NEG_FOUR: 0xc400, - SrcEnum.INV_2PI: 0x3118}.items(): _INLINE_CONSTS_F16[_k - 128] = _v # f16 values in low 16 bits - -# Inline constants for 64-bit operations (f64 format) -# Integer constants 0-64 are zero-extended to 64 bits; -1 to -16 are sign-extended -# Float constants are the f64 representation of the value +# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats. import struct as _struct -_INLINE_CONSTS_F64 = [0] * 127 -for _i in range(65): _INLINE_CONSTS_F64[_i] = _i # Integer constants 0-64 zero-extended -for _i in range(1, 17): _INLINE_CONSTS_F64[64 + _i] = ((-_i) & 0xffffffffffffffff) # -1 to -16 sign-extended -for _k, _v in {SrcEnum.POS_HALF: 0.5, SrcEnum.NEG_HALF: -0.5, SrcEnum.POS_ONE: 1.0, SrcEnum.NEG_ONE: -1.0, - SrcEnum.POS_TWO: 2.0, SrcEnum.NEG_TWO: -2.0, SrcEnum.POS_FOUR: 4.0, SrcEnum.NEG_FOUR: -4.0, - SrcEnum.INV_2PI: 0.15915494309189535}.items(): - _INLINE_CONSTS_F64[_k - 128] = _struct.unpack('> 16) & 0xffff) if (opsel & 1) else (s0_raw & 0xffff) s1 = ((s1_raw >> 16) & 0xffff) if (opsel & 2) else (s1_raw & 0xffff) s2 = ((s2_raw >> 16) & 0xffff) if (opsel & 4) else (s2_raw & 0xffff) + elif is_vop2_16bit: + # VOP2 16-bit ops: src0 can use f16 inline constants, vsrc1 is always a VGPR (no inline constants) + s0 = mod_src(st.rsrc_f16(src0, lane), 0) + s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 + s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0 else: s0 = mod_src(st.rsrc(src0, lane), 0) s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 diff --git a/extra/assembly/rdna3/pcode.py b/extra/assembly/rdna3/pcode.py index 5bbd971f18..6877a36ccf 100644 --- a/extra/assembly/rdna3/pcode.py +++ b/extra/assembly/rdna3/pcode.py @@ -102,6 +102,8 @@ def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16))) def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff)) def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0 def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0 +def u8_to_u32(v): return int(v) & 0xff +def u4_to_u32(v): return int(v) & 0xf def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0 def _mantissa_f32(f): return struct.unpack(">= 1; n += 1 return n def _exponent(f): + # Handle TypedView (f16/f32/f64) to get correct exponent for that type + if hasattr(f, '_bits') and hasattr(f, '_float') and f._float: + raw = f._val + if f._bits == 16: return (raw >> 10) & 0x1f # f16: 5-bit exponent + if f._bits == 32: return (raw >> 23) & 0xff # f32: 8-bit exponent + if f._bits == 64: return (raw >> 52) & 0x7ff # f64: 11-bit exponent + # Fallback: convert to f32 and get exponent + f = float(f) if math.isinf(f) or math.isnan(f): return 255 if f == 0.0: return 0 - try: bits = struct.unpack("> 23) & 0xff + try: bits = struct.unpack("> 23) & 0xff except: return 0 def _is_denorm_f32(f): if not isinstance(f, float): f = _f32(int(f) & 0xffffffff) @@ -229,7 +239,7 @@ __all__ = [ 'f32_to_i32', 'f32_to_u32', 'f64_to_i32', 'f64_to_u32', 'f32_to_f16', 'f16_to_f32', 'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16', 'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32', - 'SAT8', 'f32_to_u8', + 'SAT8', 'f32_to_u8', 'u8_to_u32', 'u4_to_u32', # Math functions 'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa', # Min/max functions @@ -698,7 +708,7 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M) # Patterns that can't be handled by the DSL (require special handling in emu.py) UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'BYTE_PERMUTE', 'FATAL_HALT', 'HW_REGISTERS', 'PC =', 'PC=', 'PC+', '= PC', 'v_sad', '+:', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt', - 'CVT_OFF_TABLE', '.bf16', 'ThreadMask', 'u8_to_u32', 'u4_to_u32', + 'CVT_OFF_TABLE', '.bf16', 'ThreadMask', 'S1[i', 'C.i32', 'v_msad_u8', 'S[i]', 'in[', '2.0 / PI', 'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF @@ -797,9 +807,72 @@ from extra.assembly.rdna3.pcode import * try: code = compile_pseudocode(pc) # CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break. - # Hardware stops at first match, so we need to add break after D0.i32 = i + # Hardware stops at first match. SOP1 uses tmp=i, VOP1/VOP3 use D0.i32=i if 'CLZ' in op.name or 'CTZ' in op.name: - code = code.replace('D0.i32 = i', 'D0.i32 = i; break # Stop at first 1 bit found') + code = code.replace('tmp = Reg(i)', 'tmp = Reg(i); break') + code = code.replace('D0.i32 = i', 'D0.i32 = i; break') + # V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex. + # The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0), + # scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64 + # (to unscale a denominator that was scaled). + if op.name == 'V_DIV_FMAS_F32': + code = code.replace( + 'D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)', + 'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)') + if op.name == 'V_DIV_FMAS_F64': + code = code.replace( + 'D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)', + 'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)') + # V_DIV_SCALE_F32/F64: PDF page 463-464 has several bugs vs hardware behavior: + # 1. Zero case: hardware sets VCC=1 (PDF doesn't) + # 2. Denorm denom: hardware returns NaN (PDF says scale). VCC is set independently by exp diff check. + # 3. Tiny numer (exp<=23): hardware sets VCC=1 (PDF doesn't) + # 4. Result would be denorm: hardware doesn't scale, just sets VCC=1 + if op.name == 'V_DIV_SCALE_F32': + # Fix 1: Set VCC=1 when zero operands produce NaN + code = code.replace( + 'D0.f32 = float("nan")', + 'VCC = Reg(0x1); D0.f32 = float("nan")') + # Fix 2: Denorm denom returns NaN. Must check this AFTER all VCC-setting logic runs. + # Insert at end of all branches, before the final result is used + code = code.replace( + 'elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', + 'elif False:\n pass # denorm check moved to end') + # Add denorm check at the very end - this overrides D0 but preserves VCC + code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")' + # Fix 3: Tiny numer should set VCC=1 + code = code.replace( + 'elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', + 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)') + # Fix 4: S2/S1 would be denorm - don't scale, just set VCC + code = code.replace( + 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', + 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)') + if op.name == 'V_DIV_SCALE_F64': + # Same fixes for f64 version + code = code.replace( + 'D0.f64 = float("nan")', + 'VCC = Reg(0x1); D0.f64 = float("nan")') + code = code.replace( + 'elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', + 'elif False:\n pass # denorm check moved to end') + code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")' + code = code.replace( + 'elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', + 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)') + code = code.replace( + 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', + 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)') + # V_DIV_FIXUP_F32/F64: PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN. + # When division fails (e.g., due to denorm denom), S0 becomes NaN, and fixup should return ±inf. + if op.name == 'V_DIV_FIXUP_F32': + code = code.replace( + 'D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))', + 'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))') + if op.name == 'V_DIV_FIXUP_F64': + code = code.replace( + 'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))', + 'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))') # Detect flags for result handling is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64']) has_d1 = '{ D1' in pc diff --git a/extra/assembly/rdna3/test/test_emu.py b/extra/assembly/rdna3/test/test_emu.py index d7b3c45e24..98095d0fde 100644 --- a/extra/assembly/rdna3/test/test_emu.py +++ b/extra/assembly/rdna3/test/test_emu.py @@ -225,7 +225,21 @@ def run_program(instructions: list, n_lanes: int = 1) -> WaveState: class TestVDivScale(unittest.TestCase): - """Tests for V_DIV_SCALE_F32 VCC handling.""" + """Tests for V_DIV_SCALE_F32 edge cases. + + V_DIV_SCALE_F32 is used in the Newton-Raphson division sequence to handle + denormals and near-overflow cases. It scales operands and sets VCC when + the final result needs to be unscaled. + + Pseudocode cases: + 1. Zero operands -> NaN + 2. exp(S2) - exp(S1) >= 96 -> scale denom, VCC=1 + 3. S1 is denorm -> scale by 2^64 + 4. 1/S1 is f64 denorm AND S2/S1 is f32 denorm -> scale denom, VCC=1 + 5. 1/S1 is f64 denorm -> scale by 2^-64 + 6. S2/S1 is f32 denorm -> scale numer, VCC=1 + 7. exp(S2) <= 23 -> scale by 2^64 (tiny numerator) + """ def test_div_scale_f32_vcc_zero_single_lane(self): """V_DIV_SCALE_F32 sets VCC=0 when no scaling needed.""" @@ -257,6 +271,376 @@ class TestVDivScale(unittest.TestCase): st = run_program(instructions, n_lanes=1) self.assertAlmostEqual(i2f(st.vgpr[0][2]), 2.0, places=5) + def test_div_scale_f32_zero_denom_gives_nan(self): + """V_DIV_SCALE_F32: zero denominator -> NaN, VCC=1.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), # numerator + v_mov_b32_e32(v[1], 0.0), # denominator = 0 + v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Should be NaN for zero denom") + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for zero denom") + + def test_div_scale_f32_zero_numer_gives_nan(self): + """V_DIV_SCALE_F32: zero numerator -> NaN, VCC=1.""" + instructions = [ + v_mov_b32_e32(v[0], 0.0), # numerator = 0 + v_mov_b32_e32(v[1], 1.0), # denominator + v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Should be NaN for zero numer") + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for zero numer") + + def test_div_scale_f32_large_exp_diff_scales_denom(self): + """V_DIV_SCALE_F32: exp(numer) - exp(denom) >= 96 -> scale denom, VCC=1.""" + # Need exp difference >= 96. Use MAX_FLOAT / tiny_normal + # MAX_FLOAT exp=254, tiny_normal with exp <= 254-96=158 + # Let's use exp=127 (1.0) for denom, exp=254 for numer -> diff = 127 (>96) + max_float = 0x7f7fffff # 3.4028235e+38, exp=254 + instructions = [ + s_mov_b32(s[0], max_float), + v_mov_b32_e32(v[0], s[0]), # numer = MAX_FLOAT (S2) + v_mov_b32_e32(v[1], 1.0), # denom = 1.0 (S1), exp=127. diff = 254-127 = 127 >= 96 + # S0=denom (what we're scaling), S1=denom, S2=numer + v_div_scale_f32(v[2], VCC, v[1], v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when scaling denom for large exp diff") + # Result should be denom * 2^64 + expected = 1.0 * (2.0 ** 64) + self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=expected * 1e-6) + + def test_div_scale_f32_denorm_denom(self): + """V_DIV_SCALE_F32: denormalized denominator -> NaN, VCC=1. + + Hardware returns NaN when denominator is denormalized (different from PDF pseudocode). + """ + # Smallest positive denorm: 0x00000001 = 1.4e-45 + denorm = 0x00000001 + instructions = [ + s_mov_b32(s[0], denorm), + v_mov_b32_e32(v[0], 1.0), # numer = 1.0 (S2) + v_mov_b32_e32(v[1], s[0]), # denom = denorm (S1) + # S0=denom, S1=denom, S2=numer -> scale denom + v_div_scale_f32(v[2], VCC, v[1], v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Hardware returns NaN for denorm denom") + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for denorm denom") + + def test_div_scale_f32_tiny_numer_exp_le_23(self): + """V_DIV_SCALE_F32: exponent(numer) <= 23 -> scale by 2^64, VCC=1.""" + # exp <= 23 means exponent field is 0..23 + # exp=23 corresponds to float value around 2^(23-127) = 2^-104 ≈ 4.9e-32 + # Use exp=1 (smallest normal), which is 2^(1-127) = 2^-126 ≈ 1.18e-38 + smallest_normal = 0x00800000 # exp=1, mantissa=0 + instructions = [ + s_mov_b32(s[0], smallest_normal), + v_mov_b32_e32(v[0], s[0]), # numer = smallest_normal (S2), exp=1 <= 23 + v_mov_b32_e32(v[1], 1.0), # denom = 1.0 (S1) + # S0=numer, S1=denom, S2=numer -> scale numer + v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + # Numer scaled by 2^64, VCC=1 to indicate scaling was done + numer_f = i2f(smallest_normal) + expected = numer_f * (2.0 ** 64) + self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=abs(expected) * 1e-5) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when scaling tiny numer") + + def test_div_scale_f32_result_would_be_denorm(self): + """V_DIV_SCALE_F32: result would be denorm -> no scaling applied, VCC=1. + + When the result of numer/denom would be denormalized, hardware sets VCC=1 + but does NOT scale the input (returns it unchanged). The scaling happens + elsewhere in the division sequence. + """ + # If S2/S1 would be denorm, set VCC but don't scale + # Denorm result: exp < 1, i.e., |result| < 2^-126 + # Use 1.0 / 2^127 ≈ 5.9e-39 (result would be denorm) + large_denom = 0x7f000000 # 2^127 + instructions = [ + s_mov_b32(s[0], large_denom), + v_mov_b32_e32(v[0], 1.0), # numer = 1.0 (S2) + v_mov_b32_e32(v[1], s[0]), # denom = 2^127 (S1) + # S0=numer, S1=denom, S2=numer -> check if we need to scale numer + v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]), + ] + st = run_program(instructions, n_lanes=1) + # Hardware returns input unchanged but sets VCC=1 + self.assertAlmostEqual(i2f(st.vgpr[0][2]), 1.0, places=5) + self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when result would be denorm") + + +class TestVDivFmas(unittest.TestCase): + """Tests for V_DIV_FMAS_F32 edge cases. + + V_DIV_FMAS_F32 performs FMA with optional scaling based on VCC. + The scale direction depends on S2's exponent (the addend): + - If exponent(S2) > 127 (i.e., S2 >= 2.0): scale by 2^+64 + - Otherwise: scale by 2^-64 + + NOTE: The PDF (page 449) incorrectly says just 2^32. + """ + + def test_div_fmas_f32_no_scale(self): + """V_DIV_FMAS_F32: VCC=0 -> normal FMA.""" + instructions = [ + s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # VCC = 0 + v_mov_b32_e32(v[0], 2.0), # S0 + v_mov_b32_e32(v[1], 3.0), # S1 + v_mov_b32_e32(v[2], 1.0), # S2 + v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2*3+1 = 7 + ] + st = run_program(instructions, n_lanes=1) + self.assertAlmostEqual(i2f(st.vgpr[0][3]), 7.0, places=5) + + def test_div_fmas_f32_scale_up(self): + """V_DIV_FMAS_F32: VCC=1 with S2 >= 2.0 -> scale by 2^+64.""" + instructions = [ + s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC = 1 + v_mov_b32_e32(v[0], 1.0), # S0 + v_mov_b32_e32(v[1], 1.0), # S1 + v_mov_b32_e32(v[2], 2.0), # S2 >= 2.0, so scale UP + v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^+64 * (1*1+2) = 2^+64 * 3 + ] + st = run_program(instructions, n_lanes=1) + expected = 3.0 * (2.0 ** 64) + self.assertAlmostEqual(i2f(st.vgpr[0][3]), expected, delta=abs(expected) * 1e-6) + + def test_div_fmas_f32_scale_down(self): + """V_DIV_FMAS_F32: VCC=1 with S2 < 2.0 -> scale by 2^-64.""" + instructions = [ + s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC = 1 + v_mov_b32_e32(v[0], 2.0), # S0 + v_mov_b32_e32(v[1], 3.0), # S1 + v_mov_b32_e32(v[2], 1.0), # S2 < 2.0, so scale DOWN + v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^-64 * (2*3+1) = 2^-64 * 7 + ] + st = run_program(instructions, n_lanes=1) + expected = 7.0 * (2.0 ** -64) + self.assertAlmostEqual(i2f(st.vgpr[0][3]), expected, delta=abs(expected) * 1e-6) + + def test_div_fmas_f32_per_lane_vcc(self): + """V_DIV_FMAS_F32: different VCC per lane with S2 < 2.0.""" + instructions = [ + s_mov_b32(s[SrcEnum.VCC_LO - 128], 0b0101), # VCC: lanes 0,2 set + v_mov_b32_e32(v[0], 1.0), + v_mov_b32_e32(v[1], 1.0), + v_mov_b32_e32(v[2], 1.0), # S2 < 2.0, so scale DOWN + v_div_fmas_f32(v[3], v[0], v[1], v[2]), # fma(1,1,1) = 2, scaled = 2^-64 * 2 + ] + st = run_program(instructions, n_lanes=4) + scaled = 2.0 * (2.0 ** -64) + unscaled = 2.0 + self.assertAlmostEqual(i2f(st.vgpr[0][3]), scaled, delta=abs(scaled) * 1e-6) # lane 0: VCC=1 + self.assertAlmostEqual(i2f(st.vgpr[1][3]), unscaled, places=5) # lane 1: VCC=0 + self.assertAlmostEqual(i2f(st.vgpr[2][3]), scaled, delta=abs(scaled) * 1e-6) # lane 2: VCC=1 + self.assertAlmostEqual(i2f(st.vgpr[3][3]), unscaled, places=5) # lane 3: VCC=0 + + +class TestVDivFixup(unittest.TestCase): + """Tests for V_DIV_FIXUP_F32 edge cases. + + V_DIV_FIXUP_F32 is the final step of Newton-Raphson division. + It handles special cases: NaN, Inf, zero, overflow, underflow. + + Args: S0=quotient from NR iteration, S1=denominator, S2=numerator + """ + + def test_div_fixup_f32_normal(self): + """V_DIV_FIXUP_F32: normal division passes through quotient.""" + # 6.0 / 2.0 = 3.0 + instructions = [ + v_mov_b32_e32(v[0], 3.0), # S0 = quotient + v_mov_b32_e32(v[1], 2.0), # S1 = denominator + v_mov_b32_e32(v[2], 6.0), # S2 = numerator + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + self.assertAlmostEqual(i2f(st.vgpr[0][3]), 3.0, places=5) + + def test_div_fixup_f32_nan_numer(self): + """V_DIV_FIXUP_F32: NaN numerator -> quiet NaN.""" + nan = 0x7fc00000 # quiet NaN + instructions = [ + s_mov_b32(s[0], nan), + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], 1.0), # S1 = denominator + v_mov_b32_e32(v[2], s[0]), # S2 = numerator = NaN + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "Should be NaN") + + def test_div_fixup_f32_nan_denom(self): + """V_DIV_FIXUP_F32: NaN denominator -> quiet NaN.""" + nan = 0x7fc00000 # quiet NaN + instructions = [ + s_mov_b32(s[0], nan), + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], s[0]), # S1 = denominator = NaN + v_mov_b32_e32(v[2], 1.0), # S2 = numerator + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "Should be NaN") + + def test_div_fixup_f32_zero_div_zero(self): + """V_DIV_FIXUP_F32: 0/0 -> NaN (0xffc00000).""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), # S0 = quotient (doesn't matter) + v_mov_b32_e32(v[1], 0.0), # S1 = denominator = 0 + v_mov_b32_e32(v[2], 0.0), # S2 = numerator = 0 + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "0/0 should be NaN") + + def test_div_fixup_f32_inf_div_inf(self): + """V_DIV_FIXUP_F32: inf/inf -> NaN.""" + pos_inf = 0x7f800000 + instructions = [ + s_mov_b32(s[0], pos_inf), + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], s[0]), # S1 = denominator = +inf + v_mov_b32_e32(v[2], s[0]), # S2 = numerator = +inf + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "inf/inf should be NaN") + + def test_div_fixup_f32_x_div_zero(self): + """V_DIV_FIXUP_F32: x/0 -> +/-inf based on sign.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], 0.0), # S1 = denominator = 0 + v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0 + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "x/0 should be inf") + self.assertGreater(i2f(st.vgpr[0][3]), 0, "1/0 should be +inf") + + def test_div_fixup_f32_neg_x_div_zero(self): + """V_DIV_FIXUP_F32: -x/0 -> -inf.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], 0.0), # S1 = denominator = 0 + v_mov_b32_e32(v[2], -1.0), # S2 = numerator = -1.0 + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "-x/0 should be inf") + self.assertLess(i2f(st.vgpr[0][3]), 0, "-1/0 should be -inf") + + def test_div_fixup_f32_zero_div_x(self): + """V_DIV_FIXUP_F32: 0/x -> 0.""" + instructions = [ + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], 2.0), # S1 = denominator = 2.0 + v_mov_b32_e32(v[2], 0.0), # S2 = numerator = 0 + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(i2f(st.vgpr[0][3]), 0.0, "0/x should be 0") + + def test_div_fixup_f32_x_div_inf(self): + """V_DIV_FIXUP_F32: x/inf -> 0.""" + pos_inf = 0x7f800000 + instructions = [ + s_mov_b32(s[0], pos_inf), + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], s[0]), # S1 = denominator = +inf + v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0 + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + self.assertEqual(i2f(st.vgpr[0][3]), 0.0, "x/inf should be 0") + + def test_div_fixup_f32_inf_div_x(self): + """V_DIV_FIXUP_F32: inf/x -> inf.""" + pos_inf = 0x7f800000 + instructions = [ + s_mov_b32(s[0], pos_inf), + v_mov_b32_e32(v[0], 1.0), # S0 = quotient + v_mov_b32_e32(v[1], 1.0), # S1 = denominator = 1.0 + v_mov_b32_e32(v[2], s[0]), # S2 = numerator = +inf + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "inf/x should be inf") + + def test_div_fixup_f32_sign_propagation(self): + """V_DIV_FIXUP_F32: sign is XOR of numer and denom signs.""" + instructions = [ + v_mov_b32_e32(v[0], 3.0), # S0 = |quotient| + v_mov_b32_e32(v[1], -2.0), # S1 = denominator (negative) + v_mov_b32_e32(v[2], 6.0), # S2 = numerator (positive) + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + # pos / neg = neg + self.assertAlmostEqual(i2f(st.vgpr[0][3]), -3.0, places=5) + + def test_div_fixup_f32_neg_neg(self): + """V_DIV_FIXUP_F32: neg/neg -> positive.""" + instructions = [ + v_mov_b32_e32(v[0], 3.0), # S0 = |quotient| + v_mov_b32_e32(v[1], -2.0), # S1 = denominator (negative) + v_mov_b32_e32(v[2], -6.0), # S2 = numerator (negative) + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + # neg / neg = pos + self.assertAlmostEqual(i2f(st.vgpr[0][3]), 3.0, places=5) + + def test_div_fixup_f32_nan_estimate_overflow(self): + """V_DIV_FIXUP_F32: NaN estimate returns overflow (inf). + + PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN. + This happens when division fails (e.g., denorm denominator in V_DIV_SCALE). + """ + quiet_nan = 0x7fc00000 + instructions = [ + s_mov_b32(s[0], quiet_nan), + v_mov_b32_e32(v[0], s[0]), # S0 = NaN (failed estimate) + v_mov_b32_e32(v[1], 1.0), # S1 = denominator = 1.0 + v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0 + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf") + self.assertEqual(st.vgpr[0][3], 0x7f800000, "Should be +inf (pos/pos)") + + def test_div_fixup_f32_nan_estimate_sign(self): + """V_DIV_FIXUP_F32: NaN estimate with negative sign returns -inf.""" + quiet_nan = 0x7fc00000 + instructions = [ + s_mov_b32(s[0], quiet_nan), + v_mov_b32_e32(v[0], s[0]), # S0 = NaN (failed estimate) + v_mov_b32_e32(v[1], -1.0), # S1 = denominator = -1.0 + v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0 + v_div_fixup_f32(v[3], v[0], v[1], v[2]), + ] + st = run_program(instructions, n_lanes=1) + import math + self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf") + self.assertEqual(st.vgpr[0][3], 0xff800000, "Should be -inf (pos/neg)") + class TestVCmpClass(unittest.TestCase): """Tests for V_CMP_CLASS_F32 float classification."""