mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
minor refactoring for rdna3 (#13873)
* minor refactoring for rdna3 * fix div scale stuff * more bugfixes
This commit is contained in:
@@ -92,7 +92,7 @@ def _SOP1Op_S_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -115,7 +115,7 @@ def _SOP1Op_S_CTZ_I32_B64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(63)+1):
|
||||
if S0.u64[i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -138,7 +138,7 @@ def _SOP1Op_S_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[31 - i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -161,7 +161,7 @@ def _SOP1Op_S_CLZ_I32_U64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
tmp = Reg(-1)
|
||||
for i in range(0, int(63)+1):
|
||||
if S0.u64[63 - i] == 1:
|
||||
tmp = Reg(i)
|
||||
tmp = Reg(i); break
|
||||
D0.i32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
@@ -3746,7 +3746,7 @@ def _VOP1Op_V_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[31 - i] == 1:
|
||||
D0.i32 = i; break # Stop at first 1 bit found
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -3766,7 +3766,7 @@ def _VOP1Op_V_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[i] == 1:
|
||||
D0.i32 = i; break # Stop at first 1 bit found
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -5588,7 +5588,7 @@ def _VOP3Op_V_CLZ_I32_U32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[31 - i] == 1:
|
||||
D0.i32 = i; break # Stop at first 1 bit found
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -5608,7 +5608,7 @@ def _VOP3Op_V_CTZ_I32_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VG
|
||||
D0.i32 = -1
|
||||
for i in range(0, int(31)+1):
|
||||
if S0.u32[i] == 1:
|
||||
D0.i32 = i; break # Stop at first 1 bit found
|
||||
D0.i32 = i; break
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -7207,7 +7207,7 @@ def _VOP3Op_V_DIV_FIXUP_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
||||
elif exponent(S1.f32) == 255:
|
||||
D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32))
|
||||
else:
|
||||
D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))
|
||||
D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
@@ -7260,7 +7260,7 @@ def _VOP3Op_V_DIV_FIXUP_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
||||
elif exponent(S1.f64) == 2047:
|
||||
D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64))
|
||||
else:
|
||||
D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))
|
||||
D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
result['d0_64'] = True
|
||||
@@ -7280,7 +7280,7 @@ def _VOP3Op_V_DIV_FMAS_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
|
||||
laneId = lane
|
||||
# --- compiled pseudocode ---
|
||||
if VCC.u64[laneId]:
|
||||
D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)
|
||||
D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)
|
||||
else:
|
||||
D0.f32 = fma(S0.f32, S1.f32, S2.f32)
|
||||
# --- end pseudocode ---
|
||||
@@ -7302,7 +7302,7 @@ def _VOP3Op_V_DIV_FMAS_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
|
||||
laneId = lane
|
||||
# --- compiled pseudocode ---
|
||||
if VCC.u64[laneId]:
|
||||
D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)
|
||||
D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)
|
||||
else:
|
||||
D0.f64 = fma(S0.f64, S1.f64, S2.f64)
|
||||
# --- end pseudocode ---
|
||||
@@ -8736,13 +8736,13 @@ def _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
|
||||
# --- compiled pseudocode ---
|
||||
VCC = Reg(0x0)
|
||||
if ((F(S2.f32) == 0.0) or (F(S1.f32) == 0.0)):
|
||||
D0.f32 = float("nan")
|
||||
VCC = Reg(0x1); D0.f32 = float("nan")
|
||||
elif exponent(S2.f32) - exponent(S1.f32) >= 96:
|
||||
VCC = Reg(0x1)
|
||||
if S0.f32 == S1.f32:
|
||||
D0.f32 = ldexp(S0.f32, 64)
|
||||
elif S1.f32 == DENORM.f32:
|
||||
D0.f32 = ldexp(S0.f32, 64)
|
||||
elif False:
|
||||
pass # denorm check moved to end
|
||||
elif ((1.0 / F(S1.f32) == DENORM.f64) and (S2.f32 / S1.f32 == DENORM.f32)):
|
||||
VCC = Reg(0x1)
|
||||
if S0.f32 == S1.f32:
|
||||
@@ -8751,10 +8751,10 @@ def _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
|
||||
D0.f32 = ldexp(S0.f32, -64)
|
||||
elif S2.f32 / S1.f32 == DENORM.f32:
|
||||
VCC = Reg(0x1)
|
||||
if S0.f32 == S2.f32:
|
||||
D0.f32 = ldexp(S0.f32, 64)
|
||||
elif exponent(S2.f32) <= 23:
|
||||
D0.f32 = ldexp(S0.f32, 64)
|
||||
VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)
|
||||
if S1.f32 == DENORM.f32:
|
||||
D0.f32 = float("nan")
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
result['vcc_lane'] = (VCC._val >> lane) & 1
|
||||
@@ -8799,13 +8799,13 @@ def _VOP3SDOp_V_DIV_SCALE_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
|
||||
# --- compiled pseudocode ---
|
||||
VCC = Reg(0x0)
|
||||
if ((S2.f64 == 0.0) or (S1.f64 == 0.0)):
|
||||
D0.f64 = float("nan")
|
||||
VCC = Reg(0x1); D0.f64 = float("nan")
|
||||
elif exponent(S2.f64) - exponent(S1.f64) >= 768:
|
||||
VCC = Reg(0x1)
|
||||
if S0.f64 == S1.f64:
|
||||
D0.f64 = ldexp(S0.f64, 128)
|
||||
elif S1.f64 == DENORM.f64:
|
||||
D0.f64 = ldexp(S0.f64, 128)
|
||||
elif False:
|
||||
pass # denorm check moved to end
|
||||
elif ((1.0 / S1.f64 == DENORM.f64) and (S2.f64 / S1.f64 == DENORM.f64)):
|
||||
VCC = Reg(0x1)
|
||||
if S0.f64 == S1.f64:
|
||||
@@ -8814,10 +8814,10 @@ def _VOP3SDOp_V_DIV_SCALE_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal
|
||||
D0.f64 = ldexp(S0.f64, -128)
|
||||
elif S2.f64 / S1.f64 == DENORM.f64:
|
||||
VCC = Reg(0x1)
|
||||
if S0.f64 == S2.f64:
|
||||
D0.f64 = ldexp(S0.f64, 128)
|
||||
elif exponent(S2.f64) <= 53:
|
||||
D0.f64 = ldexp(S0.f64, 128)
|
||||
if S1.f64 == DENORM.f64:
|
||||
D0.f64 = float("nan")
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
result['vcc_lane'] = (VCC._val >> lane) & 1
|
||||
@@ -9258,6 +9258,60 @@ def _VOP3POp_V_DOT2_F32_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal,
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
def _VOP3POp_V_DOT4_U32_U8(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
# tmp = S2.u32;
|
||||
# tmp += u8_to_u32(S0[7 : 0].u8) * u8_to_u32(S1[7 : 0].u8);
|
||||
# tmp += u8_to_u32(S0[15 : 8].u8) * u8_to_u32(S1[15 : 8].u8);
|
||||
# tmp += u8_to_u32(S0[23 : 16].u8) * u8_to_u32(S1[23 : 16].u8);
|
||||
# tmp += u8_to_u32(S0[31 : 24].u8) * u8_to_u32(S1[31 : 24].u8);
|
||||
# D0.u32 = tmp
|
||||
S0 = Reg(s0)
|
||||
S1 = Reg(s1)
|
||||
S2 = Reg(s2)
|
||||
D0 = Reg(d0)
|
||||
tmp = Reg(0)
|
||||
# --- compiled pseudocode ---
|
||||
tmp = Reg(S2.u32)
|
||||
tmp += u8_to_u32(S0[7 : 0].u8) * u8_to_u32(S1[7 : 0].u8)
|
||||
tmp += u8_to_u32(S0[15 : 8].u8) * u8_to_u32(S1[15 : 8].u8)
|
||||
tmp += u8_to_u32(S0[23 : 16].u8) * u8_to_u32(S1[23 : 16].u8)
|
||||
tmp += u8_to_u32(S0[31 : 24].u8) * u8_to_u32(S1[31 : 24].u8)
|
||||
D0.u32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
def _VOP3POp_V_DOT8_U32_U4(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
# tmp = S2.u32;
|
||||
# tmp += u4_to_u32(S0[3 : 0].u4) * u4_to_u32(S1[3 : 0].u4);
|
||||
# tmp += u4_to_u32(S0[7 : 4].u4) * u4_to_u32(S1[7 : 4].u4);
|
||||
# tmp += u4_to_u32(S0[11 : 8].u4) * u4_to_u32(S1[11 : 8].u4);
|
||||
# tmp += u4_to_u32(S0[15 : 12].u4) * u4_to_u32(S1[15 : 12].u4);
|
||||
# tmp += u4_to_u32(S0[19 : 16].u4) * u4_to_u32(S1[19 : 16].u4);
|
||||
# tmp += u4_to_u32(S0[23 : 20].u4) * u4_to_u32(S1[23 : 20].u4);
|
||||
# tmp += u4_to_u32(S0[27 : 24].u4) * u4_to_u32(S1[27 : 24].u4);
|
||||
# tmp += u4_to_u32(S0[31 : 28].u4) * u4_to_u32(S1[31 : 28].u4);
|
||||
# D0.u32 = tmp
|
||||
S0 = Reg(s0)
|
||||
S1 = Reg(s1)
|
||||
S2 = Reg(s2)
|
||||
D0 = Reg(d0)
|
||||
tmp = Reg(0)
|
||||
# --- compiled pseudocode ---
|
||||
tmp = Reg(S2.u32)
|
||||
tmp += u4_to_u32(S0[3 : 0].u4) * u4_to_u32(S1[3 : 0].u4)
|
||||
tmp += u4_to_u32(S0[7 : 4].u4) * u4_to_u32(S1[7 : 4].u4)
|
||||
tmp += u4_to_u32(S0[11 : 8].u4) * u4_to_u32(S1[11 : 8].u4)
|
||||
tmp += u4_to_u32(S0[15 : 12].u4) * u4_to_u32(S1[15 : 12].u4)
|
||||
tmp += u4_to_u32(S0[19 : 16].u4) * u4_to_u32(S1[19 : 16].u4)
|
||||
tmp += u4_to_u32(S0[23 : 20].u4) * u4_to_u32(S1[23 : 20].u4)
|
||||
tmp += u4_to_u32(S0[27 : 24].u4) * u4_to_u32(S1[27 : 24].u4)
|
||||
tmp += u4_to_u32(S0[31 : 28].u4) * u4_to_u32(S1[31 : 28].u4)
|
||||
D0.u32 = tmp
|
||||
# --- end pseudocode ---
|
||||
result = {'d0': D0._val, 'scc': scc & 1}
|
||||
return result
|
||||
|
||||
VOP3POp_FUNCTIONS = {
|
||||
VOP3POp.V_PK_MAD_I16: _VOP3POp_V_PK_MAD_I16,
|
||||
VOP3POp.V_PK_MUL_LO_U16: _VOP3POp_V_PK_MUL_LO_U16,
|
||||
@@ -9279,6 +9333,8 @@ VOP3POp_FUNCTIONS = {
|
||||
VOP3POp.V_PK_MIN_F16: _VOP3POp_V_PK_MIN_F16,
|
||||
VOP3POp.V_PK_MAX_F16: _VOP3POp_V_PK_MAX_F16,
|
||||
VOP3POp.V_DOT2_F32_F16: _VOP3POp_V_DOT2_F32_F16,
|
||||
VOP3POp.V_DOT4_U32_U8: _VOP3POp_V_DOT4_U32_U8,
|
||||
VOP3POp.V_DOT8_U32_U4: _VOP3POp_V_DOT8_U32_U4,
|
||||
}
|
||||
|
||||
def _VOPCOp_V_CMP_F_F16(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
|
||||
@@ -21,6 +21,7 @@ _VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
|
||||
# Ops with 16-bit types in name (for source/dest handling)
|
||||
_VOP3_16BIT_OPS = {op for op in VOP3Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
_VOP1_16BIT_OPS = {op for op in VOP1Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
_VOP2_16BIT_OPS = {op for op in VOP2Op if any(s in op.name for s in ('_F16', '_B16', '_I16', '_U16'))}
|
||||
# CVT ops with 32/64-bit source (despite 16-bit in name)
|
||||
_CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))} | \
|
||||
{op for op in VOP1Op if op.name.startswith('V_CVT_') and op.name.endswith(('_F32', '_I32', '_U32', '_F64', '_I64', '_U64'))}
|
||||
@@ -28,34 +29,17 @@ _CVT_32_64_SRC_OPS = {op for op in VOP3Op if op.name.startswith('V_CVT_') and op
|
||||
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name}
|
||||
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name}
|
||||
|
||||
# Inline constants for src operands 128-254 (f32 format for most instructions)
|
||||
_INLINE_CONSTS = [0] * 127
|
||||
for _i in range(65): _INLINE_CONSTS[_i] = _i
|
||||
for _i in range(1, 17): _INLINE_CONSTS[64 + _i] = ((-_i) & 0xffffffff)
|
||||
for _k, _v in {SrcEnum.POS_HALF: 0x3f000000, SrcEnum.NEG_HALF: 0xbf000000, SrcEnum.POS_ONE: 0x3f800000, SrcEnum.NEG_ONE: 0xbf800000,
|
||||
SrcEnum.POS_TWO: 0x40000000, SrcEnum.NEG_TWO: 0xc0000000, SrcEnum.POS_FOUR: 0x40800000, SrcEnum.NEG_FOUR: 0xc0800000,
|
||||
SrcEnum.INV_2PI: 0x3e22f983}.items(): _INLINE_CONSTS[_k - 128] = _v
|
||||
|
||||
# Inline constants for VOP3P packed f16 operations (f16 value in low 16 bits only, high 16 bits are 0)
|
||||
# Hardware does NOT replicate the constant - opsel_hi controls which half is used for the hi result
|
||||
_INLINE_CONSTS_F16 = [0] * 127
|
||||
for _i in range(65): _INLINE_CONSTS_F16[_i] = _i # Integer constants in low 16 bits only
|
||||
for _i in range(1, 17): _INLINE_CONSTS_F16[64 + _i] = (-_i) & 0xffff # Negative integers in low 16 bits
|
||||
for _k, _v in {SrcEnum.POS_HALF: 0x3800, SrcEnum.NEG_HALF: 0xb800, SrcEnum.POS_ONE: 0x3c00, SrcEnum.NEG_ONE: 0xbc00,
|
||||
SrcEnum.POS_TWO: 0x4000, SrcEnum.NEG_TWO: 0xc000, SrcEnum.POS_FOUR: 0x4400, SrcEnum.NEG_FOUR: 0xc400,
|
||||
SrcEnum.INV_2PI: 0x3118}.items(): _INLINE_CONSTS_F16[_k - 128] = _v # f16 values in low 16 bits
|
||||
|
||||
# Inline constants for 64-bit operations (f64 format)
|
||||
# Integer constants 0-64 are zero-extended to 64 bits; -1 to -16 are sign-extended
|
||||
# Float constants are the f64 representation of the value
|
||||
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
|
||||
import struct as _struct
|
||||
_INLINE_CONSTS_F64 = [0] * 127
|
||||
for _i in range(65): _INLINE_CONSTS_F64[_i] = _i # Integer constants 0-64 zero-extended
|
||||
for _i in range(1, 17): _INLINE_CONSTS_F64[64 + _i] = ((-_i) & 0xffffffffffffffff) # -1 to -16 sign-extended
|
||||
for _k, _v in {SrcEnum.POS_HALF: 0.5, SrcEnum.NEG_HALF: -0.5, SrcEnum.POS_ONE: 1.0, SrcEnum.NEG_ONE: -1.0,
|
||||
SrcEnum.POS_TWO: 2.0, SrcEnum.NEG_TWO: -2.0, SrcEnum.POS_FOUR: 4.0, SrcEnum.NEG_FOUR: -4.0,
|
||||
SrcEnum.INV_2PI: 0.15915494309189535}.items():
|
||||
_INLINE_CONSTS_F64[_k - 128] = _struct.unpack('<Q', _struct.pack('<d', _v))[0]
|
||||
_FLOAT_CONSTS = {SrcEnum.POS_HALF: 0.5, SrcEnum.NEG_HALF: -0.5, SrcEnum.POS_ONE: 1.0, SrcEnum.NEG_ONE: -1.0,
|
||||
SrcEnum.POS_TWO: 2.0, SrcEnum.NEG_TWO: -2.0, SrcEnum.POS_FOUR: 4.0, SrcEnum.NEG_FOUR: -4.0, SrcEnum.INV_2PI: 0.15915494309189535}
|
||||
def _build_inline_consts(neg_mask, float_to_bits):
|
||||
tbl = list(range(65)) + [((-i) & neg_mask) for i in range(1, 17)] + [0] * (127 - 81)
|
||||
for k, v in _FLOAT_CONSTS.items(): tbl[k - 128] = float_to_bits(v)
|
||||
return tbl
|
||||
_INLINE_CONSTS = _build_inline_consts(0xffffffff, lambda f: _struct.unpack('<I', _struct.pack('<f', f))[0])
|
||||
_INLINE_CONSTS_F16 = _build_inline_consts(0xffff, lambda f: _struct.unpack('<H', _struct.pack('<e', f))[0])
|
||||
_INLINE_CONSTS_F64 = _build_inline_consts(0xffffffffffffffff, lambda f: _struct.unpack('<Q', _struct.pack('<d', f))[0])
|
||||
|
||||
# Memory access
|
||||
_valid_mem_ranges: list[tuple[int, int]] = []
|
||||
@@ -519,8 +503,10 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
|
||||
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
|
||||
# 16-bit source ops: use precomputed sets instead of string checks
|
||||
has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS
|
||||
has_16bit_type = op in _VOP3_16BIT_OPS or op in _VOP1_16BIT_OPS or op in _VOP2_16BIT_OPS
|
||||
is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
|
||||
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
|
||||
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
|
||||
|
||||
if is_shift_64:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0) # shift amount is 32-bit
|
||||
@@ -544,6 +530,11 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
s0 = ((s0_raw >> 16) & 0xffff) if (opsel & 1) else (s0_raw & 0xffff)
|
||||
s1 = ((s1_raw >> 16) & 0xffff) if (opsel & 2) else (s1_raw & 0xffff)
|
||||
s2 = ((s2_raw >> 16) & 0xffff) if (opsel & 4) else (s2_raw & 0xffff)
|
||||
elif is_vop2_16bit:
|
||||
# VOP2 16-bit ops: src0 can use f16 inline constants, vsrc1 is always a VGPR (no inline constants)
|
||||
s0 = mod_src(st.rsrc_f16(src0, lane), 0)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
else:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
|
||||
@@ -102,6 +102,8 @@ def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
|
||||
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
|
||||
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
|
||||
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
|
||||
def u8_to_u32(v): return int(v) & 0xff
|
||||
def u4_to_u32(v): return int(v) & 0xf
|
||||
def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
|
||||
def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff if not (math.isinf(f) or math.isnan(f)) else 0
|
||||
def _ldexp(m, e): return math.ldexp(m, e)
|
||||
@@ -141,9 +143,17 @@ def _ctz64(v):
|
||||
while (v & 1) == 0: v >>= 1; n += 1
|
||||
return n
|
||||
def _exponent(f):
|
||||
# Handle TypedView (f16/f32/f64) to get correct exponent for that type
|
||||
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
|
||||
raw = f._val
|
||||
if f._bits == 16: return (raw >> 10) & 0x1f # f16: 5-bit exponent
|
||||
if f._bits == 32: return (raw >> 23) & 0xff # f32: 8-bit exponent
|
||||
if f._bits == 64: return (raw >> 52) & 0x7ff # f64: 11-bit exponent
|
||||
# Fallback: convert to f32 and get exponent
|
||||
f = float(f)
|
||||
if math.isinf(f) or math.isnan(f): return 255
|
||||
if f == 0.0: return 0
|
||||
try: bits = struct.unpack("<I", struct.pack("<f", float(f)))[0]; return (bits >> 23) & 0xff
|
||||
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
|
||||
except: return 0
|
||||
def _is_denorm_f32(f):
|
||||
if not isinstance(f, float): f = _f32(int(f) & 0xffffffff)
|
||||
@@ -229,7 +239,7 @@ __all__ = [
|
||||
'f32_to_i32', 'f32_to_u32', 'f64_to_i32', 'f64_to_u32', 'f32_to_f16', 'f16_to_f32',
|
||||
'i16_to_f16', 'u16_to_f16', 'f16_to_i16', 'f16_to_u16', 'u32_to_u16', 'i32_to_i16',
|
||||
'f16_to_snorm', 'f16_to_unorm', 'f32_to_snorm', 'f32_to_unorm', 'v_cvt_i16_f32', 'v_cvt_u16_f32',
|
||||
'SAT8', 'f32_to_u8',
|
||||
'SAT8', 'f32_to_u8', 'u8_to_u32', 'u4_to_u32',
|
||||
# Math functions
|
||||
'trunc', 'floor', 'ceil', 'sqrt', 'log2', 'sin', 'cos', 'pow', 'fract', 'isEven', 'mantissa',
|
||||
# Min/max functions
|
||||
@@ -698,7 +708,7 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
|
||||
# Patterns that can't be handled by the DSL (require special handling in emu.py)
|
||||
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'BYTE_PERMUTE', 'FATAL_HALT', 'HW_REGISTERS',
|
||||
'PC =', 'PC=', 'PC+', '= PC', 'v_sad', '+:', 'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
|
||||
'CVT_OFF_TABLE', '.bf16', 'ThreadMask', 'u8_to_u32', 'u4_to_u32',
|
||||
'CVT_OFF_TABLE', '.bf16', 'ThreadMask',
|
||||
'S1[i', 'C.i32', 'v_msad_u8', 'S[i]', 'in[', '2.0 / PI',
|
||||
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
|
||||
|
||||
@@ -797,9 +807,72 @@ from extra.assembly.rdna3.pcode import *
|
||||
try:
|
||||
code = compile_pseudocode(pc)
|
||||
# CLZ/CTZ: The PDF pseudocode searches for the first 1 bit but doesn't break.
|
||||
# Hardware stops at first match, so we need to add break after D0.i32 = i
|
||||
# Hardware stops at first match. SOP1 uses tmp=i, VOP1/VOP3 use D0.i32=i
|
||||
if 'CLZ' in op.name or 'CTZ' in op.name:
|
||||
code = code.replace('D0.i32 = i', 'D0.i32 = i; break # Stop at first 1 bit found')
|
||||
code = code.replace('tmp = Reg(i)', 'tmp = Reg(i); break')
|
||||
code = code.replace('D0.i32 = i', 'D0.i32 = i; break')
|
||||
# V_DIV_FMAS_F32/F64: PDF page 449 says 2^32/2^64 but hardware behavior is more complex.
|
||||
# The scale direction depends on S2 (the addend): if exponent(S2) > 127 (i.e., S2 >= 2.0),
|
||||
# scale by 2^+64 (to unscale a numerator that was scaled). Otherwise scale by 2^-64
|
||||
# (to unscale a denominator that was scaled).
|
||||
if op.name == 'V_DIV_FMAS_F32':
|
||||
code = code.replace(
|
||||
'D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
|
||||
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
|
||||
if op.name == 'V_DIV_FMAS_F64':
|
||||
code = code.replace(
|
||||
'D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
|
||||
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
|
||||
# V_DIV_SCALE_F32/F64: PDF page 463-464 has several bugs vs hardware behavior:
|
||||
# 1. Zero case: hardware sets VCC=1 (PDF doesn't)
|
||||
# 2. Denorm denom: hardware returns NaN (PDF says scale). VCC is set independently by exp diff check.
|
||||
# 3. Tiny numer (exp<=23): hardware sets VCC=1 (PDF doesn't)
|
||||
# 4. Result would be denorm: hardware doesn't scale, just sets VCC=1
|
||||
if op.name == 'V_DIV_SCALE_F32':
|
||||
# Fix 1: Set VCC=1 when zero operands produce NaN
|
||||
code = code.replace(
|
||||
'D0.f32 = float("nan")',
|
||||
'VCC = Reg(0x1); D0.f32 = float("nan")')
|
||||
# Fix 2: Denorm denom returns NaN. Must check this AFTER all VCC-setting logic runs.
|
||||
# Insert at end of all branches, before the final result is used
|
||||
code = code.replace(
|
||||
'elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif False:\n pass # denorm check moved to end')
|
||||
# Add denorm check at the very end - this overrides D0 but preserves VCC
|
||||
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
|
||||
# Fix 3: Tiny numer should set VCC=1
|
||||
code = code.replace(
|
||||
'elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
|
||||
# Fix 4: S2/S1 would be denorm - don't scale, just set VCC
|
||||
code = code.replace(
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)',
|
||||
'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
|
||||
if op.name == 'V_DIV_SCALE_F64':
|
||||
# Same fixes for f64 version
|
||||
code = code.replace(
|
||||
'D0.f64 = float("nan")',
|
||||
'VCC = Reg(0x1); D0.f64 = float("nan")')
|
||||
code = code.replace(
|
||||
'elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif False:\n pass # denorm check moved to end')
|
||||
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
|
||||
code = code.replace(
|
||||
'elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
|
||||
code = code.replace(
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)',
|
||||
'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
|
||||
# V_DIV_FIXUP_F32/F64: PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
|
||||
# When division fails (e.g., due to denorm denom), S0 becomes NaN, and fixup should return ±inf.
|
||||
if op.name == 'V_DIV_FIXUP_F32':
|
||||
code = code.replace(
|
||||
'D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
|
||||
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
|
||||
if op.name == 'V_DIV_FIXUP_F64':
|
||||
code = code.replace(
|
||||
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
|
||||
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
|
||||
# Detect flags for result handling
|
||||
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
|
||||
has_d1 = '{ D1' in pc
|
||||
|
||||
@@ -225,7 +225,21 @@ def run_program(instructions: list, n_lanes: int = 1) -> WaveState:
|
||||
|
||||
|
||||
class TestVDivScale(unittest.TestCase):
|
||||
"""Tests for V_DIV_SCALE_F32 VCC handling."""
|
||||
"""Tests for V_DIV_SCALE_F32 edge cases.
|
||||
|
||||
V_DIV_SCALE_F32 is used in the Newton-Raphson division sequence to handle
|
||||
denormals and near-overflow cases. It scales operands and sets VCC when
|
||||
the final result needs to be unscaled.
|
||||
|
||||
Pseudocode cases:
|
||||
1. Zero operands -> NaN
|
||||
2. exp(S2) - exp(S1) >= 96 -> scale denom, VCC=1
|
||||
3. S1 is denorm -> scale by 2^64
|
||||
4. 1/S1 is f64 denorm AND S2/S1 is f32 denorm -> scale denom, VCC=1
|
||||
5. 1/S1 is f64 denorm -> scale by 2^-64
|
||||
6. S2/S1 is f32 denorm -> scale numer, VCC=1
|
||||
7. exp(S2) <= 23 -> scale by 2^64 (tiny numerator)
|
||||
"""
|
||||
|
||||
def test_div_scale_f32_vcc_zero_single_lane(self):
|
||||
"""V_DIV_SCALE_F32 sets VCC=0 when no scaling needed."""
|
||||
@@ -257,6 +271,376 @@ class TestVDivScale(unittest.TestCase):
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 2.0, places=5)
|
||||
|
||||
def test_div_scale_f32_zero_denom_gives_nan(self):
|
||||
"""V_DIV_SCALE_F32: zero denominator -> NaN, VCC=1."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 1.0), # numerator
|
||||
v_mov_b32_e32(v[1], 0.0), # denominator = 0
|
||||
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
import math
|
||||
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Should be NaN for zero denom")
|
||||
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for zero denom")
|
||||
|
||||
def test_div_scale_f32_zero_numer_gives_nan(self):
|
||||
"""V_DIV_SCALE_F32: zero numerator -> NaN, VCC=1."""
|
||||
instructions = [
|
||||
v_mov_b32_e32(v[0], 0.0), # numerator = 0
|
||||
v_mov_b32_e32(v[1], 1.0), # denominator
|
||||
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
import math
|
||||
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Should be NaN for zero numer")
|
||||
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for zero numer")
|
||||
|
||||
def test_div_scale_f32_large_exp_diff_scales_denom(self):
|
||||
"""V_DIV_SCALE_F32: exp(numer) - exp(denom) >= 96 -> scale denom, VCC=1."""
|
||||
# Need exp difference >= 96. Use MAX_FLOAT / tiny_normal
|
||||
# MAX_FLOAT exp=254, tiny_normal with exp <= 254-96=158
|
||||
# Let's use exp=127 (1.0) for denom, exp=254 for numer -> diff = 127 (>96)
|
||||
max_float = 0x7f7fffff # 3.4028235e+38, exp=254
|
||||
instructions = [
|
||||
s_mov_b32(s[0], max_float),
|
||||
v_mov_b32_e32(v[0], s[0]), # numer = MAX_FLOAT (S2)
|
||||
v_mov_b32_e32(v[1], 1.0), # denom = 1.0 (S1), exp=127. diff = 254-127 = 127 >= 96
|
||||
# S0=denom (what we're scaling), S1=denom, S2=numer
|
||||
v_div_scale_f32(v[2], VCC, v[1], v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when scaling denom for large exp diff")
|
||||
# Result should be denom * 2^64
|
||||
expected = 1.0 * (2.0 ** 64)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=expected * 1e-6)
|
||||
|
||||
def test_div_scale_f32_denorm_denom(self):
|
||||
"""V_DIV_SCALE_F32: denormalized denominator -> NaN, VCC=1.
|
||||
|
||||
Hardware returns NaN when denominator is denormalized (different from PDF pseudocode).
|
||||
"""
|
||||
# Smallest positive denorm: 0x00000001 = 1.4e-45
|
||||
denorm = 0x00000001
|
||||
instructions = [
|
||||
s_mov_b32(s[0], denorm),
|
||||
v_mov_b32_e32(v[0], 1.0), # numer = 1.0 (S2)
|
||||
v_mov_b32_e32(v[1], s[0]), # denom = denorm (S1)
|
||||
# S0=denom, S1=denom, S2=numer -> scale denom
|
||||
v_div_scale_f32(v[2], VCC, v[1], v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
import math
|
||||
self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Hardware returns NaN for denorm denom")
|
||||
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for denorm denom")
|
||||
|
||||
def test_div_scale_f32_tiny_numer_exp_le_23(self):
|
||||
"""V_DIV_SCALE_F32: exponent(numer) <= 23 -> scale by 2^64, VCC=1."""
|
||||
# exp <= 23 means exponent field is 0..23
|
||||
# exp=23 corresponds to float value around 2^(23-127) = 2^-104 ≈ 4.9e-32
|
||||
# Use exp=1 (smallest normal), which is 2^(1-127) = 2^-126 ≈ 1.18e-38
|
||||
smallest_normal = 0x00800000 # exp=1, mantissa=0
|
||||
instructions = [
|
||||
s_mov_b32(s[0], smallest_normal),
|
||||
v_mov_b32_e32(v[0], s[0]), # numer = smallest_normal (S2), exp=1 <= 23
|
||||
v_mov_b32_e32(v[1], 1.0), # denom = 1.0 (S1)
|
||||
# S0=numer, S1=denom, S2=numer -> scale numer
|
||||
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
# Numer scaled by 2^64, VCC=1 to indicate scaling was done
|
||||
numer_f = i2f(smallest_normal)
|
||||
expected = numer_f * (2.0 ** 64)
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=abs(expected) * 1e-5)
|
||||
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when scaling tiny numer")
|
||||
|
||||
def test_div_scale_f32_result_would_be_denorm(self):
|
||||
"""V_DIV_SCALE_F32: result would be denorm -> no scaling applied, VCC=1.
|
||||
|
||||
When the result of numer/denom would be denormalized, hardware sets VCC=1
|
||||
but does NOT scale the input (returns it unchanged). The scaling happens
|
||||
elsewhere in the division sequence.
|
||||
"""
|
||||
# If S2/S1 would be denorm, set VCC but don't scale
|
||||
# Denorm result: exp < 1, i.e., |result| < 2^-126
|
||||
# Use 1.0 / 2^127 ≈ 5.9e-39 (result would be denorm)
|
||||
large_denom = 0x7f000000 # 2^127
|
||||
instructions = [
|
||||
s_mov_b32(s[0], large_denom),
|
||||
v_mov_b32_e32(v[0], 1.0), # numer = 1.0 (S2)
|
||||
v_mov_b32_e32(v[1], s[0]), # denom = 2^127 (S1)
|
||||
# S0=numer, S1=denom, S2=numer -> check if we need to scale numer
|
||||
v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
# Hardware returns input unchanged but sets VCC=1
|
||||
self.assertAlmostEqual(i2f(st.vgpr[0][2]), 1.0, places=5)
|
||||
self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when result would be denorm")
|
||||
|
||||
|
||||
class TestVDivFmas(unittest.TestCase):
    """Tests for V_DIV_FMAS_F32 edge cases.

    V_DIV_FMAS_F32 performs FMA with optional scaling based on VCC.
    The scale direction depends on S2's exponent (the addend):
    - If exponent(S2) > 127 (i.e., S2 >= 2.0): scale by 2^+64
    - Otherwise: scale by 2^-64

    NOTE: The PDF (page 449) incorrectly says just 2^32.
    """

    def test_div_fmas_f32_no_scale(self):
        """V_DIV_FMAS_F32: VCC=0 -> normal FMA."""
        prog = [
            s_mov_b32(s[SrcEnum.VCC_LO - 128], 0),  # clear VCC
            v_mov_b32_e32(v[0], 2.0),  # S0
            v_mov_b32_e32(v[1], 3.0),  # S1
            v_mov_b32_e32(v[2], 1.0),  # S2
            v_div_fmas_f32(v[3], v[0], v[1], v[2]),  # fma(2, 3, 1) = 7
        ]
        st = run_program(prog, n_lanes=1)
        self.assertAlmostEqual(i2f(st.vgpr[0][3]), 7.0, places=5)

    def test_div_fmas_f32_scale_up(self):
        """V_DIV_FMAS_F32: VCC=1 with S2 >= 2.0 -> scale by 2^+64."""
        prog = [
            s_mov_b32(s[SrcEnum.VCC_LO - 128], 1),  # set VCC
            v_mov_b32_e32(v[0], 1.0),  # S0
            v_mov_b32_e32(v[1], 1.0),  # S1
            v_mov_b32_e32(v[2], 2.0),  # S2 >= 2.0 selects the upward scale
            v_div_fmas_f32(v[3], v[0], v[1], v[2]),  # 2^+64 * (1*1+2) = 2^+64 * 3
        ]
        st = run_program(prog, n_lanes=1)
        want = 3.0 * (2.0 ** 64)
        self.assertAlmostEqual(i2f(st.vgpr[0][3]), want, delta=abs(want) * 1e-6)

    def test_div_fmas_f32_scale_down(self):
        """V_DIV_FMAS_F32: VCC=1 with S2 < 2.0 -> scale by 2^-64."""
        prog = [
            s_mov_b32(s[SrcEnum.VCC_LO - 128], 1),  # set VCC
            v_mov_b32_e32(v[0], 2.0),  # S0
            v_mov_b32_e32(v[1], 3.0),  # S1
            v_mov_b32_e32(v[2], 1.0),  # S2 < 2.0 selects the downward scale
            v_div_fmas_f32(v[3], v[0], v[1], v[2]),  # 2^-64 * (2*3+1) = 2^-64 * 7
        ]
        st = run_program(prog, n_lanes=1)
        want = 7.0 * (2.0 ** -64)
        self.assertAlmostEqual(i2f(st.vgpr[0][3]), want, delta=abs(want) * 1e-6)

    def test_div_fmas_f32_per_lane_vcc(self):
        """V_DIV_FMAS_F32: different VCC per lane with S2 < 2.0."""
        prog = [
            s_mov_b32(s[SrcEnum.VCC_LO - 128], 0b0101),  # VCC set for lanes 0 and 2 only
            v_mov_b32_e32(v[0], 1.0),
            v_mov_b32_e32(v[1], 1.0),
            v_mov_b32_e32(v[2], 1.0),  # S2 < 2.0, so scale DOWN
            v_div_fmas_f32(v[3], v[0], v[1], v[2]),  # fma(1,1,1) = 2, scaled = 2^-64 * 2
        ]
        st = run_program(prog, n_lanes=4)
        with_scale = 2.0 * (2.0 ** -64)
        without_scale = 2.0
        # Lanes whose VCC bit is set get the 2^-64 factor; the rest do not.
        self.assertAlmostEqual(i2f(st.vgpr[0][3]), with_scale, delta=abs(with_scale) * 1e-6)    # lane 0: VCC=1
        self.assertAlmostEqual(i2f(st.vgpr[1][3]), without_scale, places=5)                     # lane 1: VCC=0
        self.assertAlmostEqual(i2f(st.vgpr[2][3]), with_scale, delta=abs(with_scale) * 1e-6)    # lane 2: VCC=1
        self.assertAlmostEqual(i2f(st.vgpr[3][3]), without_scale, places=5)                     # lane 3: VCC=0
class TestVDivFixup(unittest.TestCase):
    """Tests for V_DIV_FIXUP_F32 edge cases.

    V_DIV_FIXUP_F32 is the final step of Newton-Raphson division.
    It handles special cases: NaN, Inf, zero, overflow, underflow.

    Args: S0=quotient from NR iteration, S1=denominator, S2=numerator
    """

    def test_div_fixup_f32_normal(self):
        """V_DIV_FIXUP_F32: normal division passes through quotient."""
        # 6.0 / 2.0 = 3.0
        prog = [
            v_mov_b32_e32(v[0], 3.0),  # S0 = quotient
            v_mov_b32_e32(v[1], 2.0),  # S1 = denominator
            v_mov_b32_e32(v[2], 6.0),  # S2 = numerator
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        self.assertAlmostEqual(i2f(st.vgpr[0][3]), 3.0, places=5)

    def test_div_fixup_f32_nan_numer(self):
        """V_DIV_FIXUP_F32: NaN numerator -> quiet NaN."""
        qnan = 0x7fc00000  # quiet NaN bit pattern
        prog = [
            s_mov_b32(s[0], qnan),
            v_mov_b32_e32(v[0], 1.0),   # S0 = quotient
            v_mov_b32_e32(v[1], 1.0),   # S1 = denominator
            v_mov_b32_e32(v[2], s[0]),  # S2 = numerator = NaN
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "Should be NaN")

    def test_div_fixup_f32_nan_denom(self):
        """V_DIV_FIXUP_F32: NaN denominator -> quiet NaN."""
        qnan = 0x7fc00000  # quiet NaN bit pattern
        prog = [
            s_mov_b32(s[0], qnan),
            v_mov_b32_e32(v[0], 1.0),   # S0 = quotient
            v_mov_b32_e32(v[1], s[0]),  # S1 = denominator = NaN
            v_mov_b32_e32(v[2], 1.0),   # S2 = numerator
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "Should be NaN")

    def test_div_fixup_f32_zero_div_zero(self):
        """V_DIV_FIXUP_F32: 0/0 -> NaN (0xffc00000)."""
        prog = [
            v_mov_b32_e32(v[0], 1.0),  # S0 = quotient (doesn't matter)
            v_mov_b32_e32(v[1], 0.0),  # S1 = denominator = 0
            v_mov_b32_e32(v[2], 0.0),  # S2 = numerator = 0
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "0/0 should be NaN")

    def test_div_fixup_f32_inf_div_inf(self):
        """V_DIV_FIXUP_F32: inf/inf -> NaN."""
        inf_bits = 0x7f800000  # +inf bit pattern
        prog = [
            s_mov_b32(s[0], inf_bits),
            v_mov_b32_e32(v[0], 1.0),   # S0 = quotient
            v_mov_b32_e32(v[1], s[0]),  # S1 = denominator = +inf
            v_mov_b32_e32(v[2], s[0]),  # S2 = numerator = +inf
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "inf/inf should be NaN")

    def test_div_fixup_f32_x_div_zero(self):
        """V_DIV_FIXUP_F32: x/0 -> +/-inf based on sign."""
        prog = [
            v_mov_b32_e32(v[0], 1.0),  # S0 = quotient
            v_mov_b32_e32(v[1], 0.0),  # S1 = denominator = 0
            v_mov_b32_e32(v[2], 1.0),  # S2 = numerator = 1.0
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        result = i2f(st.vgpr[0][3])
        self.assertTrue(math.isinf(result), "x/0 should be inf")
        self.assertGreater(result, 0, "1/0 should be +inf")

    def test_div_fixup_f32_neg_x_div_zero(self):
        """V_DIV_FIXUP_F32: -x/0 -> -inf."""
        prog = [
            v_mov_b32_e32(v[0], 1.0),   # S0 = quotient
            v_mov_b32_e32(v[1], 0.0),   # S1 = denominator = 0
            v_mov_b32_e32(v[2], -1.0),  # S2 = numerator = -1.0
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        result = i2f(st.vgpr[0][3])
        self.assertTrue(math.isinf(result), "-x/0 should be inf")
        self.assertLess(result, 0, "-1/0 should be -inf")

    def test_div_fixup_f32_zero_div_x(self):
        """V_DIV_FIXUP_F32: 0/x -> 0."""
        prog = [
            v_mov_b32_e32(v[0], 1.0),  # S0 = quotient
            v_mov_b32_e32(v[1], 2.0),  # S1 = denominator = 2.0
            v_mov_b32_e32(v[2], 0.0),  # S2 = numerator = 0
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        self.assertEqual(i2f(st.vgpr[0][3]), 0.0, "0/x should be 0")

    def test_div_fixup_f32_x_div_inf(self):
        """V_DIV_FIXUP_F32: x/inf -> 0."""
        inf_bits = 0x7f800000  # +inf bit pattern
        prog = [
            s_mov_b32(s[0], inf_bits),
            v_mov_b32_e32(v[0], 1.0),   # S0 = quotient
            v_mov_b32_e32(v[1], s[0]),  # S1 = denominator = +inf
            v_mov_b32_e32(v[2], 1.0),   # S2 = numerator = 1.0
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        self.assertEqual(i2f(st.vgpr[0][3]), 0.0, "x/inf should be 0")

    def test_div_fixup_f32_inf_div_x(self):
        """V_DIV_FIXUP_F32: inf/x -> inf."""
        inf_bits = 0x7f800000  # +inf bit pattern
        prog = [
            s_mov_b32(s[0], inf_bits),
            v_mov_b32_e32(v[0], 1.0),   # S0 = quotient
            v_mov_b32_e32(v[1], 1.0),   # S1 = denominator = 1.0
            v_mov_b32_e32(v[2], s[0]),  # S2 = numerator = +inf
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "inf/x should be inf")

    def test_div_fixup_f32_sign_propagation(self):
        """V_DIV_FIXUP_F32: sign is XOR of numer and denom signs."""
        prog = [
            v_mov_b32_e32(v[0], 3.0),   # S0 = |quotient|
            v_mov_b32_e32(v[1], -2.0),  # S1 = denominator (negative)
            v_mov_b32_e32(v[2], 6.0),   # S2 = numerator (positive)
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        # pos / neg = neg
        self.assertAlmostEqual(i2f(st.vgpr[0][3]), -3.0, places=5)

    def test_div_fixup_f32_neg_neg(self):
        """V_DIV_FIXUP_F32: neg/neg -> positive."""
        prog = [
            v_mov_b32_e32(v[0], 3.0),   # S0 = |quotient|
            v_mov_b32_e32(v[1], -2.0),  # S1 = denominator (negative)
            v_mov_b32_e32(v[2], -6.0),  # S2 = numerator (negative)
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        # neg / neg = pos
        self.assertAlmostEqual(i2f(st.vgpr[0][3]), 3.0, places=5)

    def test_div_fixup_f32_nan_estimate_overflow(self):
        """V_DIV_FIXUP_F32: NaN estimate returns overflow (inf).

        PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
        This happens when division fails (e.g., denorm denominator in V_DIV_SCALE).
        """
        qnan = 0x7fc00000  # quiet NaN bit pattern
        prog = [
            s_mov_b32(s[0], qnan),
            v_mov_b32_e32(v[0], s[0]),  # S0 = NaN (failed estimate)
            v_mov_b32_e32(v[1], 1.0),   # S1 = denominator = 1.0
            v_mov_b32_e32(v[2], 1.0),   # S2 = numerator = 1.0
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf")
        self.assertEqual(st.vgpr[0][3], 0x7f800000, "Should be +inf (pos/pos)")

    def test_div_fixup_f32_nan_estimate_sign(self):
        """V_DIV_FIXUP_F32: NaN estimate with negative sign returns -inf."""
        qnan = 0x7fc00000  # quiet NaN bit pattern
        prog = [
            s_mov_b32(s[0], qnan),
            v_mov_b32_e32(v[0], s[0]),  # S0 = NaN (failed estimate)
            v_mov_b32_e32(v[1], -1.0),  # S1 = denominator = -1.0
            v_mov_b32_e32(v[2], 1.0),   # S2 = numerator = 1.0
            v_div_fixup_f32(v[3], v[0], v[1], v[2]),
        ]
        st = run_program(prog, n_lanes=1)
        import math
        self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf")
        self.assertEqual(st.vgpr[0][3], 0xff800000, "Should be -inf (pos/neg)")
class TestVCmpClass(unittest.TestCase):
|
||||
"""Tests for V_CMP_CLASS_F32 float classification."""
|
||||
|
||||
Reference in New Issue
Block a user