minor refactoring for rdna3 (#13873)

* minor refactoring for rdna3

* fix div scale stuff

* more bugfixes
This commit is contained in:
George Hotz
2025-12-29 13:20:00 -05:00
committed by GitHub
parent 39923203ba
commit ff856a74cb
4 changed files with 561 additions and 57 deletions

View File

@@ -92,7 +92,7 @@ def _SOP1Op_S_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(31)+1):
if S0.u32[i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -115,7 +115,7 @@ def _SOP1Op_S_CTZ_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(63)+1):
if S0.u64[i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -138,7 +138,7 @@ def _SOP1Op_S_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(31)+1):
if S0.u32[31 - i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -161,7 +161,7 @@ def _SOP1Op_S_CLZ_I32_U64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
tmp = Reg(-1)
for i in range(0, int(63)+1):
if S0.u64[63 - i] == 1:
tmp = Reg(i)
tmp = Reg(i); break
D0.i32 = tmp
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
@@ -3746,7 +3746,7 @@ def _VOP1Op_V_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[31 - i] == 1:
D0.i32 = i; break # Stop at first 1 bit found
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -3766,7 +3766,7 @@ def _VOP1Op_V_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[i] == 1:
D0.i32 = i; break # Stop at first 1 bit found
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -5588,7 +5588,7 @@ def _VOP3Op_V_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[31 - i] == 1:
D0.i32 = i; break # Stop at first 1 bit found
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -5608,7 +5608,7 @@ def _VOP3Op_V_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
D0.i32 = -1
for i in range(0, int(31)+1):
if S0.u32[i] == 1:
D0.i32 = i; break # Stop at first 1 bit found
D0.i32 = i; break
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -7207,7 +7207,7 @@ def _VOP3Op_V_DIV_FIXUP_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
elif exponent(S1.f32) == 255:
D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32))
else:
D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))
D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
return result
@@ -7260,7 +7260,7 @@ def _VOP3Op_V_DIV_FIXUP_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
elif exponent(S1.f64) == 2047:
D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64))
else:
D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))
D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
result['d0_64'] = True
@@ -7280,7 +7280,7 @@ def _VOP3Op_V_DIV_FMAS_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
laneId = lane
# --- compiled pseudocode ---
if VCC.u64[laneId]:
D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)
D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)
else:
D0.f32 = fma(S0.f32, S1.f32, S2.f32)
# --- end pseudocode ---
@@ -7302,7 +7302,7 @@ def _VOP3Op_V_DIV_FMAS_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
laneId = lane
# --- compiled pseudocode ---
if VCC.u64[laneId]:
D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)
D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)
else:
D0.f64 = fma(S0.f64, S1.f64, S2.f64)
# --- end pseudocode ---
@@ -8736,13 +8736,13 @@ def _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
# --- compiled pseudocode ---
VCC = Reg(0x0)
if ((F(S2.f32) == 0.0) or (F(S1.f32) == 0.0)):
D0.f32 = float("nan")
VCC = Reg(0x1); D0.f32 = float("nan")
elif exponent(S2.f32) - exponent(S1.f32) >= 96:
VCC = Reg(0x1)
if S0.f32 == S1.f32:
D0.f32 = ldexp(S0.f32, 64)
elif S1.f32 == DENORM.f32:
D0.f32 = ldexp(S0.f32, 64)
elif False:
pass # denorm check moved to end
elif ((1.0 / F(S1.f32) == DENORM.f64) and (S2.f32 / S1.f32 == DENORM.f32)):
VCC = Reg(0x1)
if S0.f32 == S1.f32:
@@ -8751,10 +8751,10 @@ def _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
D0.f32 = ldexp(S0.f32, -64)
elif S2.f32 / S1.f32 == DENORM.f32:
VCC = Reg(0x1)
if S0.f32 == S2.f32:
D0.f32 = ldexp(S0.f32, 64)
elif exponent(S2.f32) <= 23:
D0.f32 = ldexp(S0.f32, 64)
VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)
if S1.f32 == DENORM.f32:
D0.f32 = float("nan")
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
result['vcc_lane'] = (VCC._val >> lane) & 1
@@ -8799,13 +8799,13 @@ def _VOP3SDOp_V_DIV_SCALE_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
# --- compiled pseudocode ---
VCC = Reg(0x0)
if ((S2.f64 == 0.0) or (S1.f64 == 0.0)):
D0.f64 = float("nan")
VCC = Reg(0x1); D0.f64 = float("nan")
elif exponent(S2.f64) - exponent(S1.f64) >= 768:
VCC = Reg(0x1)
if S0.f64 == S1.f64:
D0.f64 = ldexp(S0.f64, 128)
elif S1.f64 == DENORM.f64:
D0.f64 = ldexp(S0.f64, 128)
elif False:
pass # denorm check moved to end
elif ((1.0 / S1.f64 == DENORM.f64) and (S2.f64 / S1.f64 == DENORM.f64)):
VCC = Reg(0x1)
if S0.f64 == S1.f64:
@@ -8814,10 +8814,10 @@ def _VOP3SDOp_V_DIV_SCALE_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
D0.f64 = ldexp(S0.f64, -128)
elif S2.f64 / S1.f64 == DENORM.f64:
VCC = Reg(0x1)
if S0.f64 == S2.f64:
D0.f64 = ldexp(S0.f64, 128)
elif exponent(S2.f64) <= 53:
D0.f64 = ldexp(S0.f64, 128)
if S1.f64 == DENORM.f64:
D0.f64 = float("nan")
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
result['vcc_lane'] = (VCC._val >> lane) & 1
@@ -9258,6 +9258,60 @@ def _VOP3POp_V_DOT2_F32_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
result = {'d0': D0._val, 'scc': scc & 1}
return result
def _VOP3POp_V_DOT4_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
    """V_DOT4_U32_U8: D0.u32 = S2.u32 + sum of the four unsigned byte-wise products of S0 and S1."""
    # tmp = S2.u32;
    # tmp += u8_to_u32(S0[7 : 0].u8) * u8_to_u32(S1[7 : 0].u8);
    # ... one product per byte lane, up to S0[31 : 24] * S1[31 : 24];
    # D0.u32 = tmp
    S0 = Reg(s0)
    S1 = Reg(s1)
    S2 = Reg(s2)
    D0 = Reg(d0)
    # --- compiled pseudocode ---
    acc = Reg(S2.u32)
    for lo in (0, 8, 16, 24):  # low bit of each byte lane
        acc += u8_to_u32(S0[lo + 7 : lo].u8) * u8_to_u32(S1[lo + 7 : lo].u8)
    D0.u32 = acc
    # --- end pseudocode ---
    return {'d0': D0._val, 'scc': scc & 1}
def _VOP3POp_V_DOT8_U32_U4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
    """V_DOT8_U32_U4: D0.u32 = S2.u32 + sum of the eight unsigned nibble-wise products of S0 and S1."""
    # tmp = S2.u32;
    # tmp += u4_to_u32(S0[3 : 0].u4) * u4_to_u32(S1[3 : 0].u4);
    # ... one product per 4-bit lane, up to S0[31 : 28] * S1[31 : 28];
    # D0.u32 = tmp
    S0 = Reg(s0)
    S1 = Reg(s1)
    S2 = Reg(s2)
    D0 = Reg(d0)
    # --- compiled pseudocode ---
    acc = Reg(S2.u32)
    for lo in range(0, 32, 4):  # low bit of each nibble lane
        acc += u4_to_u32(S0[lo + 3 : lo].u4) * u4_to_u32(S1[lo + 3 : lo].u4)
    D0.u32 = acc
    # --- end pseudocode ---
    return {'d0': D0._val, 'scc': scc & 1}
VOP3POp_FUNCTIONS = {
VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16,
VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16,
@@ -9279,6 +9333,8 @@ VOP3POp_FUNCTIONS = {
VOP3POp.V_PK_MIN_F16: _VOP3POp_V_PK_MIN_F16,
VOP3POp.V_PK_MAX_F16: _VOP3POp_V_PK_MAX_F16,
VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16,
VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8,
VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4,
}
def _VOPCOp_V_CMP_F_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):

View File

@@ -21,6 +21,7 @@ _VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
# Ops with 16-bit types in name (for source/dest handling)
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
_VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
# CVT ops with 32/64-bit source (despite 16-bit in name)
_CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \
{op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))}
@@ -28,34 +29,17 @@ _CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name}
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name}
# Inline constants for src operands 128-254 (f32 format for most instructions)
_INLINE_CONSTS = [0] * 127
for _i in range(65): _INLINE_CONSTS[_i] = _i
for _i in range(1, 17): _INLINE_CONSTS[64 + _i] = ((-_i) & 0xffffffff)
for _k, _v in {SrcEnum.POS_HALF: 0x3f000000, SrcEnum.NEG_HALF: 0xbf000000, SrcEnum.POS_ONE: 0x3f800000, SrcEnum.NEG_ONE: 0xbf800000,
SrcEnum.POS_TWO: 0x40000000, SrcEnum.NEG_TWO: 0xc0000000, SrcEnum.POS_FOUR: 0x40800000, SrcEnum.NEG_FOUR: 0xc0800000,
SrcEnum.INV_2PI: 0x3e22f983}.items(): _INLINE_CONSTS[_k - 128] = _v
# Inline constants for VOP3P packed f16 operations (f16 value in low 16 bits only, high 16 bits are 0)
# Hardware does NOT replicate the constant - opsel_hi controls which half is used for the hi result
_INLINE_CONSTS_F16 = [0] * 127
for _i in range(65): _INLINE_CONSTS_F16[_i] = _i # Integer constants in low 16 bits only
for _i in range(1, 17): _INLINE_CONSTS_F16[64 + _i] = (-_i) & 0xffff # Negative integers in low 16 bits
for _k, _v in {SrcEnum.POS_HALF: 0x3800, SrcEnum.NEG_HALF: 0xb800, SrcEnum.POS_ONE: 0x3c00, SrcEnum.NEG_ONE: 0xbc00,
SrcEnum.POS_TWO: 0x4000, SrcEnum.NEG_TWO: 0xc000, SrcEnum.POS_FOUR: 0x4400, SrcEnum.NEG_FOUR: 0xc400,
SrcEnum.INV_2PI: 0x3118}.items(): _INLINE_CONSTS_F16[_k - 128] = _v # f16 values in low 16 bits
# Inline constants for 64-bit operations (f64 format)
# Integer constants 0-64 are zero-extended to 64 bits; -1 to -16 are sign-extended
# Float constants are the f64 representation of the value
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
import struct as _struct
_INLINE_CONSTS_F64 = [0] * 127
for _i in range(65): _INLINE_CONSTS_F64[_i] = _i # Integer constants 0-64 zero-extended
for _i in range(1, 17): _INLINE_CONSTS_F64[64 + _i] = ((-_i) & 0xffffffffffffffff) # -1 to -16 sign-extended
for _k, _v in {SrcEnum.POS_HALF: 0.5, SrcEnum.NEG_HALF: -0.5, SrcEnum.POS_ONE: 1.0, SrcEnum.NEG_ONE: -1.0,
SrcEnum.POS_TWO: 2.0, SrcEnum.NEG_TWO: -2.0, SrcEnum.POS_FOUR: 4.0, SrcEnum.NEG_FOUR: -4.0,
SrcEnum.INV_2PI: 0.15915494309189535}.items():
_INLINE_CONSTS_F64[_k - 128] = _struct.unpack('<Q', _struct.pack('<d', _v))[0]
# Float inline-constant values shared by the f32/f16/f64 operand encodings.
_FLOAT_CONSTS = {SrcEnum.POS_HALF: 0.5, SrcEnum.NEG_HALF: -0.5, SrcEnum.POS_ONE: 1.0, SrcEnum.NEG_ONE: -1.0,
                 SrcEnum.POS_TWO: 2.0, SrcEnum.NEG_TWO: -2.0, SrcEnum.POS_FOUR: 4.0, SrcEnum.NEG_FOUR: -4.0,
                 SrcEnum.INV_2PI: 0.15915494309189535}
def _build_inline_consts(neg_mask, float_to_bits):
    """Build the 127-entry inline-constant table for src encodings 128-254.

    Slots 0-64 hold the integers 0..64, slots 65-80 hold -1..-16 masked to the
    operand width (neg_mask), and each _FLOAT_CONSTS entry is encoded via
    float_to_bits at its SrcEnum-relative slot (key - 128).
    """
    tbl = [0] * 127
    for i in range(65): tbl[i] = i                          # integer constants 0..64
    for i in range(1, 17): tbl[64 + i] = (-i) & neg_mask    # -1..-16, width-masked
    for k, v in _FLOAT_CONSTS.items(): tbl[k - 128] = float_to_bits(v)
    return tbl
_INLINE_CONSTS = _build_inline_consts(0xffffffff, lambda f: _struct.unpack('<I', _struct.pack('<f', f))[0])
_INLINE_CONSTS_F16 = _build_inline_consts(0xffff, lambda f: _struct.unpack('<H', _struct.pack('<e', f))[0])
_INLINE_CONSTS_F64 = _build_inline_consts(0xffffffffffffffff, lambda f: _struct.unpack('<Q', _struct.pack('<d', f))[0])
# Memory access
_valid_mem_ranges: list[tuple[int, int]] = []
@@ -519,8 +503,10 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
# 16-bit source ops: use precomputed sets instead of string checks
has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS
has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS or op in _VOP2_16BIT_OPS
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
if is_shift_64:
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
@@ -544,6 +530,11 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
s0 = ((s0_raw >> 16) & 0xffff) if (opsel & 1) else (s0_raw & 0xffff)
s1 = ((s1_raw >> 16) & 0xffff) if (opsel & 2) else (s1_raw & 0xffff)
s2 = ((s2_raw >> 16) & 0xffff) if (opsel & 4) else (s2_raw & 0xffff)
elif is_vop2_16bit:
# VOP2 16-bit ops: src0 can use f16 inline constants, vsrc1 is always a VGPR (no inline constants)
s0 = mod_src(st.rsrc_f16(src0, lane), 0)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
else:
s0 = mod_src(st.rsrc(src0, lane), 0)
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0

View File

@@ -102,6 +102,8 @@ def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def u8_to_u32(v):
    """Zero-extend an unsigned 8-bit value to 32 bits (keep only the low byte)."""
    return 0xff & int(v)
def u4_to_u32(v):
    """Zero-extend an unsigned 4-bit value to 32 bits (keep only the low nibble)."""
    return 0xf & int(v)
def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff if not (math.isinf(f) or math.isnan(f)) else 0
def _ldexp(m, e): return math.ldexp(m, e)
@@ -141,9 +143,17 @@ def _ctz64(v):
while (v & 1) == 0: v >>= 1; n += 1
return n
def _exponent(f):
# Handle TypedView (f16/f32/f64) to get correct exponent for that type
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
raw = f._val
if f._bits == 16: return (raw >> 10) & 0x1f # f16: 5-bit exponent
if f._bits == 32: return (raw >> 23) & 0xff # f32: 8-bit exponent
if f._bits == 64: return (raw >> 52) & 0x7ff # f64: 11-bit exponent
# Fallback: convert to f32 and get exponent
f = float(f)
if math.isinf(f) or math.isnan(f): return 255
if f == 0.0: return 0
try: bits = struct.unpack("<I", struct.pack("<f", float(f)))[0]; return (bits >> 23) & 0xff
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
except: return 0
def _is_denorm_f32(f):
if not isinstance(f, float): f = _f32(int(f) & 0xffffffff)
@@ -229,7 +239,7 @@ __all__ = [
'f32_to_i32', 'f32_to_u32', 'f64_to_i32', 'f64_to_u32', 'f32_to_f16', 'f16_to_f32',
'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16',
'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32',
'SAT8', 'f32_to_u8',
'SAT8', 'f32_to_u8', 'u8_to_u32', 'u4_to_u32',
# Math functions
'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa',
# Min/max functions
@@ -698,7 +708,7 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'BYTE_PERMUTE', 'FATAL_HALT', 'HW_REGISTERS',
'PC =', 'PC=', 'PC+', '= PC', 'v_sad', '+:', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', '.bf16', 'ThreadMask', 'u8_to_u32', 'u4_to_u32',
'CVT_OFF_TABLE', '.bf16', 'ThreadMask',
'S1[i', 'C.i32', 'v_msad_u8', 'S[i]', 'in[', '2.0 / PI',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
@@ -797,9 +807,72 @@ from extra.assembly.rdna3.pcode import *
try:
code = compile_pseudocode(pc)
# CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break.
# Hardware stops at first match, so we need to add break after D0.i32 = i
# Hardware stops at first match. SOP1 uses tmp=i, VOP1/VOP3 use D0.i32=i
if 'CLZ' in op.name or 'CTZ' in op.name:
code = code.replace('D0.i32 = i', 'D0.i32 = i; break # Stop at first 1 bit found')
code = code.replace('tmp = Reg(i)', 'tmp = Reg(i); break')
code = code.replace('D0.i32 = i', 'D0.i32 = i; break')
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64
# (to unscale a denominator that was scaled).
if op.name == 'V_DIV_FMAS_F32':
code = code.replace(
'D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
if op.name == 'V_DIV_FMAS_F64':
code = code.replace(
'D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
# V_DIV_SCALE_F32/F64: PDF page 463-464 has several bugs vs hardware behavior:
# 1. Zero case: hardware sets VCC=1 (PDF doesn't)
# 2. Denorm denom: hardware returns NaN (PDF says scale). VCC is set independently by exp diff check.
# 3. Tiny numer (exp<=23): hardware sets VCC=1 (PDF doesn't)
# 4. Result would be denorm: hardware doesn't scale, just sets VCC=1
if op.name == 'V_DIV_SCALE_F32':
# Fix 1: Set VCC=1 when zero operands produce NaN
code = code.replace(
'D0.f32 = float("nan")',
'VCC = Reg(0x1); D0.f32 = float("nan")')
# Fix 2: Denorm denom returns NaN. Must check this AFTER all VCC-setting logic runs.
# Insert at end of all branches, before the final result is used
code = code.replace(
'elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)',
'elif False:\n pass # denorm check moved to end')
# Add denorm check at the very end - this overrides D0 but preserves VCC
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
# Fix 3: Tiny numer should set VCC=1
code = code.replace(
'elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
# Fix 4: S2/S1 would be denorm - don't scale, just set VCC
code = code.replace(
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
if op.name == 'V_DIV_SCALE_F64':
# Same fixes for f64 version
code = code.replace(
'D0.f64 = float("nan")',
'VCC = Reg(0x1); D0.f64 = float("nan")')
code = code.replace(
'elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif False:\n pass # denorm check moved to end')
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
# Fix 3: tiny numerator also sets VCC=1. NOTE: the f64 pseudocode uses
# 'exponent(S2.f64) <= 53' (see the compiled V_DIV_SCALE_F64 above), not 52
# (the naive f32 analog of 23) -- a pattern with 52 never matches, which made
# this fix a silent no-op and left the f64 branch without the VCC prefix.
# re.sub with a \s+ capture keeps the generated code's own indentation intact.
code = re.sub(
    r'elif exponent\(S2\.f64\) <= 53:(\s+)D0\.f64 = ldexp\(S0\.f64, 128\)',
    r'elif exponent(S2.f64) <= 53:\g<1>VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)',
    code)
code = code.replace(
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
# V_DIV_FIXUP_F32/F64: PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
# When division fails (e.g., due to denorm denom), S0 becomes NaN, and fixup should return ±inf.
if op.name == 'V_DIV_FIXUP_F32':
code = code.replace(
'D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
if op.name == 'V_DIV_FIXUP_F64':
code = code.replace(
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
# Detect flags for result handling
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
has_d1 = '{ D1' in pc

View File

@@ -225,7 +225,21 @@ def run_program(instructions: list, n_lanes: int = 1) -> WaveState:
class TestVDivScale(unittest.TestCase):
"""Tests for V_DIV_SCALE_F32 VCC handling."""
"""Tests for V_DIV_SCALE_F32 edge cases.
V_DIV_SCALE_F32 is used in the Newton-Raphson division sequence to handle
denormals and near-overflow cases. It scales operands and sets VCC when
the final result needs to be unscaled.
Pseudocode cases:
1. Zero operands -> NaN
2. exp(S2) - exp(S1) >= 96 -> scale denom, VCC=1
3. S1 is denorm -> scale by 2^64
4. 1/S1 is f64 denorm AND S2/S1 is f32 denorm -> scale denom, VCC=1
5. 1/S1 is f64 denorm -> scale by 2^-64
6. S2/S1 is f32 denorm -> scale numer, VCC=1
7. exp(S2) <= 23 -> scale by 2^64 (tiny numerator)
"""
def test_div_scale_f32_vcc_zero_single_lane(self):
"""V_DIV_SCALE_F32 sets VCC=0 when no scaling needed."""
@@ -257,6 +271,376 @@ class TestVDivScale(unittest.TestCase):
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 2.0, places=5)
def test_div_scale_f32_zero_denom_gives_nan(self):
"""V_DIV_SCALE_F32: zero denominator -> NaN, VCC=1."""
instructions = [
v_mov_b32_e32(v[0], 1.0), # numerator
v_mov_b32_e32(v[1], 0.0), # denominator = 0
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Should be NaN for zero denom")
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for zero denom")
def test_div_scale_f32_zero_numer_gives_nan(self):
"""V_DIV_SCALE_F32: zero numerator -> NaN, VCC=1."""
instructions = [
v_mov_b32_e32(v[0], 0.0), # numerator = 0
v_mov_b32_e32(v[1], 1.0), # denominator
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Should be NaN for zero numer")
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for zero numer")
def test_div_scale_f32_large_exp_diff_scales_denom(self):
"""V_DIV_SCALE_F32: exp(numer) - exp(denom) >= 96 -> scale denom, VCC=1."""
# Need exp difference >= 96. Use MAX_FLOAT / tiny_normal
# MAX_FLOAT exp=254, tiny_normal with exp <= 254-96=158
# Let's use exp=127 (1.0) for denom, exp=254 for numer -> diff = 127 (>96)
max_float = 0x7f7fffff # 3.4028235e+38, exp=254
instructions = [
s_mov_b32(s[0], max_float),
v_mov_b32_e32(v[0], s[0]), # numer = MAX_FLOAT (S2)
v_mov_b32_e32(v[1], 1.0), # denom = 1.0 (S1), exp=127. diff = 254-127 = 127 >= 96
# S0=denom (what we're scaling), S1=denom, S2=numer
v_div_scale_f32(v[2], VCC, v[1], v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when scaling denom for large exp diff")
# Result should be denom * 2^64
expected = 1.0 * (2.0 ** 64)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=expected * 1e-6)
def test_div_scale_f32_denorm_denom(self):
"""V_DIV_SCALE_F32: denormalized denominator -> NaN, VCC=1.
Hardware returns NaN when denominator is denormalized (different from PDF pseudocode).
"""
# Smallest positive denorm: 0x00000001 = 1.4e-45
denorm = 0x00000001
instructions = [
s_mov_b32(s[0], denorm),
v_mov_b32_e32(v[0], 1.0), # numer = 1.0 (S2)
v_mov_b32_e32(v[1], s[0]), # denom = denorm (S1)
# S0=denom, S1=denom, S2=numer -> scale denom
v_div_scale_f32(v[2], VCC, v[1], v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Hardware returns NaN for denorm denom")
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for denorm denom")
def test_div_scale_f32_tiny_numer_exp_le_23(self):
"""V_DIV_SCALE_F32: exponent(numer) <= 23 -> scale by 2^64, VCC=1."""
# exp <= 23 means exponent field is 0..23
# exp=23 corresponds to float value around 2^(23-127) = 2^-104 ≈ 4.9e-32
# Use exp=1 (smallest normal), which is 2^(1-127) = 2^-126 ≈ 1.18e-38
smallest_normal = 0x00800000 # exp=1, mantissa=0
instructions = [
s_mov_b32(s[0], smallest_normal),
v_mov_b32_e32(v[0], s[0]), # numer = smallest_normal (S2), exp=1 <= 23
v_mov_b32_e32(v[1], 1.0), # denom = 1.0 (S1)
# S0=numer, S1=denom, S2=numer -> scale numer
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
# Numer scaled by 2^64, VCC=1 to indicate scaling was done
numer_f = i2f(smallest_normal)
expected = numer_f * (2.0 ** 64)
self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=abs(expected) * 1e-5)
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when scaling tiny numer")
def test_div_scale_f32_result_would_be_denorm(self):
"""V_DIV_SCALE_F32: result would be denorm -> no scaling applied, VCC=1.
When the result of numer/denom would be denormalized, hardware sets VCC=1
but does NOT scale the input (returns it unchanged). The scaling happens
elsewhere in the division sequence.
"""
# If S2/S1 would be denorm, set VCC but don't scale
# Denorm result: exp < 1, i.e., |result| < 2^-126
# Use 1.0 / 2^127 ≈ 5.9e-39 (result would be denorm)
large_denom = 0x7f000000 # 2^127
instructions = [
s_mov_b32(s[0], large_denom),
v_mov_b32_e32(v[0], 1.0), # numer = 1.0 (S2)
v_mov_b32_e32(v[1], s[0]), # denom = 2^127 (S1)
# S0=numer, S1=denom, S2=numer -> check if we need to scale numer
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
]
st = run_program(instructions, n_lanes=1)
# Hardware returns input unchanged but sets VCC=1
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 1.0, places=5)
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when result would be denorm")
class TestVDivFmas(unittest.TestCase):
"""Tests for V_DIV_FMAS_F32 edge cases.
V_DIV_FMAS_F32 performs FMA with optional scaling based on VCC.
The scale direction depends on S2's exponent (the addend):
- If exponent(S2) > 127 (i.e., S2 >= 2.0): scale by 2^+64
- Otherwise: scale by 2^-64
NOTE: The PDF (page 449) incorrectly says just 2^32.
"""
def test_div_fmas_f32_no_scale(self):
"""V_DIV_FMAS_F32: VCC=0 -> normal FMA."""
instructions = [
s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # VCC = 0
v_mov_b32_e32(v[0], 2.0), # S0
v_mov_b32_e32(v[1], 3.0), # S1
v_mov_b32_e32(v[2], 1.0), # S2
v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2*3+1 = 7
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][3]), 7.0, places=5)
def test_div_fmas_f32_scale_up(self):
"""V_DIV_FMAS_F32: VCC=1 with S2 >= 2.0 -> scale by 2^+64."""
instructions = [
s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC = 1
v_mov_b32_e32(v[0], 1.0), # S0
v_mov_b32_e32(v[1], 1.0), # S1
v_mov_b32_e32(v[2], 2.0), # S2 >= 2.0, so scale UP
v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^+64 * (1*1+2) = 2^+64 * 3
]
st = run_program(instructions, n_lanes=1)
expected = 3.0 * (2.0 ** 64)
self.assertAlmostEqual(i2f(st.vgpr[0][3]), expected, delta=abs(expected) * 1e-6)
def test_div_fmas_f32_scale_down(self):
"""V_DIV_FMAS_F32: VCC=1 with S2 < 2.0 -> scale by 2^-64."""
instructions = [
s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC = 1
v_mov_b32_e32(v[0], 2.0), # S0
v_mov_b32_e32(v[1], 3.0), # S1
v_mov_b32_e32(v[2], 1.0), # S2 < 2.0, so scale DOWN
v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^-64 * (2*3+1) = 2^-64 * 7
]
st = run_program(instructions, n_lanes=1)
expected = 7.0 * (2.0 ** -64)
self.assertAlmostEqual(i2f(st.vgpr[0][3]), expected, delta=abs(expected) * 1e-6)
def test_div_fmas_f32_per_lane_vcc(self):
"""V_DIV_FMAS_F32: different VCC per lane with S2 < 2.0."""
instructions = [
s_mov_b32(s[SrcEnum.VCC_LO - 128], 0b0101), # VCC: lanes 0,2 set
v_mov_b32_e32(v[0], 1.0),
v_mov_b32_e32(v[1], 1.0),
v_mov_b32_e32(v[2], 1.0), # S2 < 2.0, so scale DOWN
v_div_fmas_f32(v[3], v[0], v[1], v[2]), # fma(1,1,1) = 2, scaled = 2^-64 * 2
]
st = run_program(instructions, n_lanes=4)
scaled = 2.0 * (2.0 ** -64)
unscaled = 2.0
self.assertAlmostEqual(i2f(st.vgpr[0][3]), scaled, delta=abs(scaled) * 1e-6) # lane 0: VCC=1
self.assertAlmostEqual(i2f(st.vgpr[1][3]), unscaled, places=5) # lane 1: VCC=0
self.assertAlmostEqual(i2f(st.vgpr[2][3]), scaled, delta=abs(scaled) * 1e-6) # lane 2: VCC=1
self.assertAlmostEqual(i2f(st.vgpr[3][3]), unscaled, places=5) # lane 3: VCC=0
class TestVDivFixup(unittest.TestCase):
"""Tests for V_DIV_FIXUP_F32 edge cases.
V_DIV_FIXUP_F32 is the final step of Newton-Raphson division.
It handles special cases: NaN, Inf, zero, overflow, underflow.
Args: S0=quotient from NR iteration, S1=denominator, S2=numerator
"""
def test_div_fixup_f32_normal(self):
"""V_DIV_FIXUP_F32: normal division passes through quotient."""
# 6.0 / 2.0 = 3.0
instructions = [
v_mov_b32_e32(v[0], 3.0), # S0 = quotient
v_mov_b32_e32(v[1], 2.0), # S1 = denominator
v_mov_b32_e32(v[2], 6.0), # S2 = numerator
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
self.assertAlmostEqual(i2f(st.vgpr[0][3]), 3.0, places=5)
def test_div_fixup_f32_nan_numer(self):
"""V_DIV_FIXUP_F32: NaN numerator -> quiet NaN."""
nan = 0x7fc00000 # quiet NaN
instructions = [
s_mov_b32(s[0], nan),
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], 1.0), # S1 = denominator
v_mov_b32_e32(v[2], s[0]), # S2 = numerator = NaN
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "Should be NaN")
def test_div_fixup_f32_nan_denom(self):
"""V_DIV_FIXUP_F32: NaN denominator -> quiet NaN."""
nan = 0x7fc00000 # quiet NaN
instructions = [
s_mov_b32(s[0], nan),
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], s[0]), # S1 = denominator = NaN
v_mov_b32_e32(v[2], 1.0), # S2 = numerator
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "Should be NaN")
def test_div_fixup_f32_zero_div_zero(self):
"""V_DIV_FIXUP_F32: 0/0 -> NaN (0xffc00000)."""
instructions = [
v_mov_b32_e32(v[0], 1.0), # S0 = quotient (doesn't matter)
v_mov_b32_e32(v[1], 0.0), # S1 = denominator = 0
v_mov_b32_e32(v[2], 0.0), # S2 = numerator = 0
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "0/0 should be NaN")
def test_div_fixup_f32_inf_div_inf(self):
"""V_DIV_FIXUP_F32: inf/inf -> NaN."""
pos_inf = 0x7f800000
instructions = [
s_mov_b32(s[0], pos_inf),
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], s[0]), # S1 = denominator = +inf
v_mov_b32_e32(v[2], s[0]), # S2 = numerator = +inf
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "inf/inf should be NaN")
def test_div_fixup_f32_x_div_zero(self):
"""V_DIV_FIXUP_F32: x/0 -> +/-inf based on sign."""
instructions = [
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], 0.0), # S1 = denominator = 0
v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "x/0 should be inf")
self.assertGreater(i2f(st.vgpr[0][3]), 0, "1/0 should be +inf")
def test_div_fixup_f32_neg_x_div_zero(self):
"""V_DIV_FIXUP_F32: -x/0 -> -inf."""
instructions = [
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], 0.0), # S1 = denominator = 0
v_mov_b32_e32(v[2], -1.0), # S2 = numerator = -1.0
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "-x/0 should be inf")
self.assertLess(i2f(st.vgpr[0][3]), 0, "-1/0 should be -inf")
def test_div_fixup_f32_zero_div_x(self):
"""V_DIV_FIXUP_F32: 0/x -> 0."""
instructions = [
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], 2.0), # S1 = denominator = 2.0
v_mov_b32_e32(v[2], 0.0), # S2 = numerator = 0
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(i2f(st.vgpr[0][3]), 0.0, "0/x should be 0")
def test_div_fixup_f32_x_div_inf(self):
"""V_DIV_FIXUP_F32: x/inf -> 0."""
pos_inf = 0x7f800000
instructions = [
s_mov_b32(s[0], pos_inf),
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], s[0]), # S1 = denominator = +inf
v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(i2f(st.vgpr[0][3]), 0.0, "x/inf should be 0")
def test_div_fixup_f32_inf_div_x(self):
"""V_DIV_FIXUP_F32: inf/x -> inf."""
pos_inf = 0x7f800000
instructions = [
s_mov_b32(s[0], pos_inf),
v_mov_b32_e32(v[0], 1.0), # S0 = quotient
v_mov_b32_e32(v[1], 1.0), # S1 = denominator = 1.0
v_mov_b32_e32(v[2], s[0]), # S2 = numerator = +inf
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "inf/x should be inf")
def test_div_fixup_f32_sign_propagation(self):
"""V_DIV_FIXUP_F32: sign is XOR of numer and denom signs."""
instructions = [
v_mov_b32_e32(v[0], 3.0), # S0 = |quotient|
v_mov_b32_e32(v[1], -2.0), # S1 = denominator (negative)
v_mov_b32_e32(v[2], 6.0), # S2 = numerator (positive)
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
# pos / neg = neg
self.assertAlmostEqual(i2f(st.vgpr[0][3]), -3.0, places=5)
def test_div_fixup_f32_neg_neg(self):
"""V_DIV_FIXUP_F32: neg/neg -> positive."""
instructions = [
v_mov_b32_e32(v[0], 3.0), # S0 = |quotient|
v_mov_b32_e32(v[1], -2.0), # S1 = denominator (negative)
v_mov_b32_e32(v[2], -6.0), # S2 = numerator (negative)
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
# neg / neg = pos
self.assertAlmostEqual(i2f(st.vgpr[0][3]), 3.0, places=5)
def test_div_fixup_f32_nan_estimate_overflow(self):
"""V_DIV_FIXUP_F32: NaN estimate returns overflow (inf).
PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
This happens when division fails (e.g., denorm denominator in V_DIV_SCALE).
"""
quiet_nan = 0x7fc00000
instructions = [
s_mov_b32(s[0], quiet_nan),
v_mov_b32_e32(v[0], s[0]), # S0 = NaN (failed estimate)
v_mov_b32_e32(v[1], 1.0), # S1 = denominator = 1.0
v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf")
self.assertEqual(st.vgpr[0][3], 0x7f800000, "Should be +inf (pos/pos)")
def test_div_fixup_f32_nan_estimate_sign(self):
"""V_DIV_FIXUP_F32: NaN estimate with negative sign returns -inf."""
quiet_nan = 0x7fc00000
instructions = [
s_mov_b32(s[0], quiet_nan),
v_mov_b32_e32(v[0], s[0]), # S0 = NaN (failed estimate)
v_mov_b32_e32(v[1], -1.0), # S1 = denominator = -1.0
v_mov_b32_e32(v[2], 1.0), # S2 = numerator = 1.0
v_div_fixup_f32(v[3], v[0], v[1], v[2]),
]
st = run_program(instructions, n_lanes=1)
import math
self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf")
self.assertEqual(st.vgpr[0][3], 0xff800000, "Should be -inf (pos/neg)")
class TestVCmpClass(unittest.TestCase):
"""Tests for V_CMP_CLASS_F32 float classification."""