tests pass
@@ -139,12 +139,30 @@ class WaveState:
     if v == 255: return self.literal
     return self.vgpr[lane][v - 256]._val if v <= 511 else 0

+  def rsrc_reg_f16(self, v: int, lane: int) -> Reg:
+    """Return Reg for VOP3P source. Inline constants are f16 in low 16 bits only."""
+    if v < SGPR_COUNT: return self.sgpr[v]
+    if v == SCC: self._scc_reg._val = self.scc; return self._scc_reg
+    if v < 255: return Reg(_INLINE_CONSTS_F16[v - 128]) # f16 inline constant
+    if v == 255: return Reg(self.literal)
+    return self.vgpr[lane][v - 256] if v <= 511 else Reg(0)
+
+  def rsrc64(self, v: int, lane: int) -> int:
+    """Read 64-bit source operand. For inline constants, returns 64-bit representation."""
+    if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128]
+    if v == 255: return self.literal
+    return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
+
+  def rsrc_reg64(self, v: int, lane: int) -> Reg:
+    """Return Reg for 64-bit source operand. For inline constants, returns 64-bit f64 value."""
+    if 128 <= v < 255: return Reg(_INLINE_CONSTS_F64[v - 128])
+    if v == 255: return Reg(self.literal)
+    if v < SGPR_COUNT: return Reg(self.sgpr[v]._val | (self.sgpr[v+1]._val << 32))
+    if 256 <= v <= 511:
+      vgpr_idx = v - 256
+      return Reg(self.vgpr[lane][vgpr_idx]._val | (self.vgpr[lane][vgpr_idx + 1]._val << 32))
+    return Reg(0)
+
   def pend_sgpr_lane(self, reg: int, lane: int, val: int):
     if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0
     if val: self._pend_sgpr[reg] |= (1 << lane)
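A note on rsrc_reg64 above: it pairs two consecutive 32-bit registers into one 64-bit value, low register first. A minimal standalone sketch of that pairing, with a stand-in Reg class (the real one lives in the emulator):

class Reg:
  def __init__(self, val=0): self._val = val

def pair64(lo: Reg, hi: Reg) -> int:
  # the low register supplies bits 0-31, the next register bits 32-63
  return (lo._val & 0xffffffff) | ((hi._val & 0xffffffff) << 32)

assert pair64(Reg(0xdeadbeef), Reg(0x12345678)) == 0x12345678deadbeef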
@@ -291,8 +309,12 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
   st.exec_mask = EXEC._val
   return 0

-def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
-  """Execute vector instruction for one lane."""
+def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None,
+                d0_override: 'Reg | None' = None, vcc_override: 'Reg | None' = None) -> None:
+  """Execute vector instruction for one lane.
+  d0_override: For VOPC/VOP3-VOPC, use this Reg instead of st.sgpr[vdst] for D0 output.
+  vcc_override: For VOP3SD, use this Reg instead of st.sgpr[sdst] for VCC output.
+  """
   compiled = _get_compiled()
   inst_type, V = type(inst), st.vgpr[lane]
@@ -351,9 +373,12 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No

   # Determine instruction format and get function
   is_vop3_vopc = False
+  is_readlane = False
   if inst_type is VOP1:
     if inst.op == VOP1Op.V_NOP: return
     op_cls, op, src0, src1, src2, vdst = VOP1Op, VOP1Op(inst.op), inst.src0, None, None, inst.vdst
+    # V_READFIRSTLANE_B32 writes to SGPR, not VGPR
+    is_readlane = inst.op == VOP1Op.V_READFIRSTLANE_B32
   elif inst_type is VOP2:
     op_cls, op, src0, src1, src2, vdst = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None, inst.vdst
   elif inst_type is VOP3:
@@ -363,6 +388,8 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
       is_vop3_vopc = True
     else:
       op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
+      # V_READFIRSTLANE_B32 and V_READLANE_B32 write to SGPR
+      is_readlane = inst.op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32)
   elif inst_type is VOP3SD:
     op_cls, op, src0, src1, src2, vdst = VOP3SDOp, VOP3SDOp(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
   elif inst_type is VOPC:
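The `inst.vsrc1 + 256` above maps a raw VGPR number into the shared source-operand space that the rsrc* helpers decode: SGPRs sit in the low range, inline constants at 128-254, the literal at 255, and VGPRs at 256-511, as the WaveState methods in the first hunk show. A hedged sketch of that classification, with a hypothetical bound for illustration:

SGPR_COUNT = 106  # hypothetical value for illustration; the emulator defines the real bound

def classify_src(v: int) -> str:
  if v < SGPR_COUNT: return f"sgpr[{v}]"
  if 128 <= v < 255: return "inline constant"
  if v == 255: return "literal"
  return f"vgpr[{v - 256}]" if v <= 511 else "reserved"

assert classify_src(256 + 3) == "vgpr[3]"  # vsrc1=3 is encoded as 259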
@@ -379,9 +406,51 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
   if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")

   # Build source Regs - get the actual register or create temp for inline constants
-  S0 = st.rsrc_reg(src0, lane)
-  S1 = st.rsrc_reg(src1, lane) if src1 is not None else Reg(0)
-  S2 = st.rsrc_reg(src2, lane) if src2 is not None else Reg(0)
+  # VOP3P uses f16 inline constants (16-bit value in low half only)
+  if inst_type is VOP3P:
+    S0 = st.rsrc_reg_f16(src0, lane)
+    S1 = st.rsrc_reg_f16(src1, lane) if src1 is not None else Reg(0)
+    S2 = st.rsrc_reg_f16(src2, lane) if src2 is not None else Reg(0)
+    # Apply op_sel_hi modifiers: control which half is used for hi-half computation
+    # opsel_hi[0]=0 means src0 hi comes from lo half, =1 means from hi half (default)
+    # opsel_hi[1]=0 means src1 hi comes from lo half, =1 means from hi half (default)
+    # opsel_hi2=0 means src2 hi comes from lo half, =1 means from hi half (default)
+    opsel_hi = getattr(inst, 'opsel_hi', 3) # default 0b11
+    opsel_hi2 = getattr(inst, 'opsel_hi2', 1) # default 1
+    # If opsel_hi bit is 0, replicate lo half to hi half
+    if not (opsel_hi & 1): # src0 hi from lo
+      lo = S0._val & 0xffff
+      S0 = Reg((lo << 16) | lo)
+    if not (opsel_hi & 2): # src1 hi from lo
+      lo = S1._val & 0xffff
+      S1 = Reg((lo << 16) | lo)
+    if not opsel_hi2: # src2 hi from lo
+      lo = S2._val & 0xffff
+      S2 = Reg((lo << 16) | lo)
+  else:
+    # Check if this is a 64-bit F64 op - needs 64-bit source reads for f64 operands
+    # V_LDEXP_F64: S0 is f64, S1 is i32 (exponent)
+    # V_ADD_F64, V_MUL_F64, etc: S0 and S1 are f64
+    # VOP1 F64 ops (V_TRUNC_F64, V_FLOOR_F64, etc): S0 is f64
+    is_f64_op = hasattr(op, 'name') and '_F64' in op.name
+    is_ldexp_f64 = hasattr(op, 'name') and op.name == 'V_LDEXP_F64'
+    if is_f64_op:
+      S0 = st.rsrc_reg64(src0, lane)
+      # V_LDEXP_F64: S1 is i32 exponent, not f64
+      if is_ldexp_f64:
+        S1 = st.rsrc_reg(src1, lane) if src1 is not None else Reg(0)
+      else:
+        S1 = st.rsrc_reg64(src1, lane) if src1 is not None else Reg(0)
+      S2 = st.rsrc_reg64(src2, lane) if src2 is not None else Reg(0)
+    else:
+      S0 = st.rsrc_reg(src0, lane)
+      S1 = st.rsrc_reg(src1, lane) if src1 is not None else Reg(0)
+      S2 = st.rsrc_reg(src2, lane) if src2 is not None else Reg(0)
+    # VOP3SD V_MAD_U64_U32 and V_MAD_I64_I32 need S2 as 64-bit from VGPR pair
+    if inst_type is VOP3SD and op in (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32) and src2 is not None:
+      if 256 <= src2 <= 511: # VGPR
+        vgpr_idx = src2 - 256
+        S2 = Reg(V[vgpr_idx]._val | (V[vgpr_idx + 1]._val << 32))

   # Apply source modifiers (neg, abs) for VOP3/VOP3SD
   if inst_type in (VOP3, VOP3SD):
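The opsel_hi replication above duplicates the low f16 lane into the high half before the packed op runs. A tiny standalone illustration with plain ints, no emulator types:

def replicate_lo(val32: int) -> int:
  # keep only the low 16-bit half and mirror it into the high half
  lo = val32 & 0xffff
  return (lo << 16) | lo

# packed f16 pair: hi=0x4200 (3.0), lo=0x3c00 (1.0)
assert replicate_lo(0x42003c00) == 0x3c003c00  # both halves now hold 1.0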
@@ -399,16 +468,37 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
   if neg & 2 or abs_mod & 2: S1 = apply_mods(S1, neg & 2, abs_mod & 2)
   if neg & 4 or abs_mod & 4: S2 = apply_mods(S2, neg & 4, abs_mod & 4)

+  # Apply opsel for VOP3 f16 operations - select which half to use
+  # opsel[0]: src0, opsel[1]: src1, opsel[2]: src2 (0=lo, 1=hi)
+  if inst_type is VOP3:
+    opsel = getattr(inst, 'opsel', 0)
+    if opsel:
+      # If opsel bit is set, swap lo and hi so that .f16 reads the hi half
+      if opsel & 1: # src0 from hi
+        S0 = Reg(((S0._val >> 16) & 0xffff) | (S0._val << 16))
+      if opsel & 2: # src1 from hi
+        S1 = Reg(((S1._val >> 16) & 0xffff) | (S1._val << 16))
+      if opsel & 4: # src2 from hi
+        S2 = Reg(((S2._val >> 16) & 0xffff) | (S2._val << 16))

   # For VOPC and VOP3-encoded VOPC, D0 is an SGPR (VCC_LO for VOPC, vdst for VOP3 VOPC)
   # V_READFIRSTLANE_B32 and V_READLANE_B32 also write to SGPR
+  # Use d0_override if provided (for batch execution with shared output register)
   is_vopc = inst_type is VOPC or (inst_type is VOP3 and is_vop3_vopc)
-  D0 = st.sgpr[VCC_LO if inst_type is VOPC else vdst] if is_vopc else V[vdst]
+  if is_vopc:
+    D0 = d0_override if d0_override is not None else st.sgpr[VCC_LO if inst_type is VOPC else vdst]
+  elif is_readlane:
+    D0 = st.sgpr[vdst]
+  else:
+    D0 = V[vdst]

   # Execute compiled function - D0 is modified in place
   st._scc_reg._val = st.scc
   # For VOP3SD, pass sdst register as VCC parameter (carry-out destination)
+  # Use vcc_override if provided (for batch execution with shared output register)
   # For VOP3 V_CNDMASK_B32, src2 specifies the condition selector (not VCC)
   if inst_type is VOP3SD:
-    vcc_reg = st.sgpr[inst.sdst]
+    vcc_reg = vcc_override if vcc_override is not None else st.sgpr[inst.sdst]
   elif inst_type is VOP3 and op == VOP3Op.V_CNDMASK_B32 and src2 is not None:
     vcc_reg = st.rsrc_reg(src2, lane) # Use src2 as condition
   else:
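The opsel swap above moves the high f16 half into the low position, where the compiled pseudocode's .f16 view reads it. A standalone sketch with explicit 32-bit masking, since a bare Python int has no register width (the Reg class handles that in the emulator):

def swap_halves(val32: int) -> int:
  # exchange the 16-bit halves of a 32-bit register value
  return ((val32 >> 16) & 0xffff) | ((val32 << 16) & 0xffffffff)

assert swap_halves(0x42003c00) == 0x3c004200  # hi half 0x4200 now readable as lo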
@@ -423,19 +513,13 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
   if 'vgpr_write' in result:
     wr_lane, wr_idx, wr_val = result['vgpr_write']
     st.vgpr[wr_lane][wr_idx]._val = wr_val
   if 'vcc_lane' in result:
     # VOP3SD writes to sdst; VOP3-encoded VOPC writes to vdst; VOPC writes to VCC_LO
     if inst_type is VOP3SD:
       sgpr_dst = inst.sdst
     elif is_vop3_vopc:
       sgpr_dst = vdst
     else:
       sgpr_dst = VCC_LO
     st.pend_sgpr_lane(sgpr_dst, lane, result['vcc_lane'])
-  # 64-bit destination: write high 32 bits to next VGPR
-  if result.get('d0_64') and not is_vopc:
-    V[vdst + 1]._val = (D0._val >> 32) & 0xffffffff
-    D0._val = D0._val & 0xffffffff # Keep only low 32 bits in D0
+
+  # 64-bit destination: write high 32 bits to next VGPR (determined from op name)
+  is_64bit_dst = not is_vopc and not is_readlane and hasattr(op, 'name') and \
+      any(s in op.name for s in ('_B64', '_I64', '_U64', '_F64'))
+  if is_64bit_dst:
+    V[vdst + 1]._val = (D0._val >> 32) & 0xffffffff
+    D0._val = D0._val & 0xffffffff # Keep only low 32 bits in D0

 # ═══════════════════════════════════════════════════════════════════════════════
 # WMMA (Wave Matrix Multiply-Accumulate)
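A 64-bit result is split across a VGPR pair: the high word goes to vdst+1 and D0 keeps the low word. A standalone sketch of that split, with plain ints standing in for the Reg values:

def split64(result: int) -> tuple[int, int]:
  # returns (lo32 for V[vdst], hi32 for V[vdst + 1])
  return result & 0xffffffff, (result >> 32) & 0xffffffff

lo, hi = split64(0x0123456789abcdef)
assert (lo, hi) == (0x89abcdef, 0x01234567)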
@@ -574,9 +658,38 @@ def exec_vector_batch(st: WaveState, inst: Inst, exec_mask: int, n_lanes: int, l
     else: raise NotImplementedError(f"DS op {op}")
     return

+  # For VOPC, VOP3-encoded VOPC, and VOP3SD, we write per-lane bits to an SGPR.
+  # The pseudocode does D0.u64[laneId] = bit or VCC.u64[laneId] = bit.
+  # To avoid corrupting reads from the same SGPR, use a shared output Reg(0).
+  # Exception: CMPX instructions write to EXEC (not D0/VCC).
+  d0_override, vcc_override = None, None
+  vopc_dst, vop3sd_dst = None, None
+  is_cmpx = False
+  if inst_type is VOPC:
+    op = VOPCOp(inst.op)
+    is_cmpx = 'CMPX' in op.name
+    if not is_cmpx: # Regular CMP writes to VCC
+      d0_override, vopc_dst = Reg(0), VCC_LO
+    else: # CMPX writes to EXEC - clear it first, accumulate per-lane
+      st.sgpr[EXEC_LO]._val = 0
+  elif inst_type is VOP3 and inst.op < 256: # VOP3-encoded VOPC
+    op = VOPCOp(inst.op)
+    is_cmpx = 'CMPX' in op.name
+    if not is_cmpx: # Regular CMP writes to destination SGPR
+      d0_override, vopc_dst = Reg(0), inst.vdst
+    else: # CMPX writes to EXEC - clear it first, accumulate per-lane
+      st.sgpr[EXEC_LO]._val = 0
+  if inst_type is VOP3SD:
+    vcc_override, vop3sd_dst = Reg(0), inst.sdst
+
   # For other vector ops, dispatch to exec_vector per lane (can optimize later)
   for lane in range(n_lanes):
-    if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds)
+    if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds, d0_override, vcc_override)
+
+  # Write accumulated per-lane bit results to destination SGPRs
+  # (CMPX writes directly to EXEC in the pseudocode, so no separate write needed)
+  if vopc_dst is not None: st.sgpr[vopc_dst]._val = d0_override._val
+  if vop3sd_dst is not None: st.sgpr[vop3sd_dst]._val = vcc_override._val

 def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
   inst = program.get(st.pc)
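The shared Reg(0) above collects one comparison bit per lane before a single SGPR write-back, so per-lane reads of the same SGPR stay consistent during the loop. A minimal standalone sketch of that accumulation pattern, with a hypothetical per-lane result list in place of real compare execution:

class Reg:
  def __init__(self, val=0): self._val = val

def run_cmp(lane_results: list[bool], exec_mask: int) -> int:
  acc = Reg(0)  # shared output, starts clear so stale bits cannot leak in
  for lane, taken in enumerate(lane_results):
    if exec_mask & (1 << lane) and taken: acc._val |= (1 << lane)
  return acc._val  # written back to the destination SGPR once, at the end

assert run_cmp([True, False, True, True], 0b1011) == 0b1001  # lane 2 masked off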
@@ -992,21 +992,9 @@ from extra.assembly.amd.pcode import *
       lines.append(f"  {line}")
       has_code = True
   lines.append("  # --- end pseudocode ---")
-  # Return flags dict (Reg objects are modified in place)
-  if has_sdst or is_cmpx or is_cmp or is_64 or has_d1:
-    lines.append("  flags = {}")
-    if has_sdst:
-      lines.append("  flags['vcc_lane'] = (VCC._val >> laneId) & 1")
-    if is_cmpx:
-      lines.append("  flags['exec_lane'] = (EXEC._val >> laneId) & 1")
-    if is_cmp:
-      lines.append("  flags['vcc_lane'] = (D0._val >> laneId) & 1")
-    if is_64:
-      lines.append("  flags['d0_64'] = True")
-    if has_d1:
-      lines.append("  flags['d1'] = D1._val & 1")
-    lines.append("  return flags")
-  elif not has_code:
+
+  # All Reg objects (D0, SCC, VCC, EXEC) are modified in place
+  # The emulator determines 64-bit ops from the opcode name
+  if not has_code:
     lines.append("  pass")
   lines.append("")
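This hunk drops the generated flags dict: the generator assembles Python source for each opcode, and outputs now land in mutated Reg objects rather than a returned dict. A toy sketch of that build-then-compile pattern, with hypothetical names that are not the real generator's:

class Reg:
  def __init__(self, val=0): self._val = val

lines = ["def _toy_add(S0, S1, D0):", "  D0._val = (S0._val + S1._val) & 0xffffffff"]
ns = {}
exec("\n".join(lines), ns)  # compile the generated source into a callable

D0 = Reg(0)
ns["_toy_add"](Reg(2), Reg(3), D0)
assert D0._val == 5  # D0 modified in place, nothing returned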
@@ -315,11 +315,15 @@ class TestVDivScale(unittest.TestCase):
     self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=expected * 1e-6)

   def test_div_scale_f32_denorm_denom(self):
-    """V_DIV_SCALE_F32: denormalized denominator -> NaN, VCC=1.
+    """V_DIV_SCALE_F32: denormalized denominator with large exp diff -> scale by 2^64, VCC=1.

-    Hardware returns NaN when denominator is denormalized (different from PDF pseudocode).
+    Per PDF pseudocode: when numer/denom has exp diff >= 96, set VCC=1.
+    If S0==S1 (scaling denom), scale by 2^64.
+    The denorm check (S1==DENORM) comes after exp diff check, so denorm denoms
+    with normal numerators hit the exp diff branch first.
     """
     # Smallest positive denorm: 0x00000001 = 1.4e-45
     # exp(1.0) - exp(denorm) = 127 - 0 = 127 >= 96
     denorm = 0x00000001
     instructions = [
       s_mov_b32(s[0], denorm),
@@ -329,9 +333,12 @@ class TestVDivScale(unittest.TestCase):
       v_div_scale_f32(v[2], VCC, v[1], v[1], v[0]),
     ]
     st = run_program(instructions, n_lanes=1)
-    import math
-    self.assertTrue(math.isnan(i2f(st.vgpr[0][2])), "Hardware returns NaN for denorm denom")
-    self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for denorm denom")
+    # Per PDF: exp diff >= 96, S0==S1 (denom), scale by 2^64
+    from extra.assembly.amd.pcode import _f32
+    denorm_f = _f32(denorm)
+    expected = denorm_f * (2.0 ** 64)
+    self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=abs(expected) * 1e-5)
+    self.assertEqual(st.vcc & 1, 1, "VCC should be 1 for large exp diff")
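The exp-diff rule in that docstring can be checked directly from the float bit patterns, since the biased exponent of an IEEE-754 binary32 value is bits 23-30. A quick standalone check using the test's own values:

def biased_exp(bits: int) -> int:
  # IEEE-754 binary32: exponent field is bits 23..30
  return (bits >> 23) & 0xff

one, denorm = 0x3f800000, 0x00000001
assert biased_exp(one) - biased_exp(denorm) == 127  # 127 >= 96 -> VCC=1, scale by 2^64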
   def test_div_scale_f32_tiny_numer_exp_le_23(self):
     """V_DIV_SCALE_F32: exponent(numer) <= 23 -> scale by 2^64, VCC=1."""
@@ -354,13 +361,12 @@ class TestVDivScale(unittest.TestCase):
     self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when scaling tiny numer")

   def test_div_scale_f32_result_would_be_denorm(self):
-    """V_DIV_SCALE_F32: result would be denorm -> no scaling applied, VCC=1.
+    """V_DIV_SCALE_F32: result would be denorm -> scale by 2^64, VCC=1.

-    When the result of numer/denom would be denormalized, hardware sets VCC=1
-    but does NOT scale the input (returns it unchanged). The scaling happens
-    elsewhere in the division sequence.
+    Per PDF pseudocode: when S2.f32 / S1.f32 would be denormalized and S0==S2
+    (checking numerator), scale the numerator by 2^64 and set VCC=1.
     """
-    # If S2/S1 would be denorm, set VCC but don't scale
+    # If S2/S1 would be denorm, scale and set VCC
     # Denorm result: exp < 1, i.e., |result| < 2^-126
     # Use 1.0 / 2^127 ≈ 5.9e-39 (result would be denorm)
     large_denom = 0x7f000000 # 2^127
@@ -368,12 +374,13 @@ class TestVDivScale(unittest.TestCase):
       s_mov_b32(s[0], large_denom),
       v_mov_b32_e32(v[0], 1.0), # numer = 1.0 (S2)
       v_mov_b32_e32(v[1], s[0]), # denom = 2^127 (S1)
-      # S0=numer, S1=denom, S2=numer -> check if we need to scale numer
+      # S0=numer, S1=denom, S2=numer -> scale numer
       v_div_scale_f32(v[2], VCC, v[0], v[1], v[0]),
     ]
     st = run_program(instructions, n_lanes=1)
-    # Hardware returns input unchanged but sets VCC=1
-    self.assertAlmostEqual(i2f(st.vgpr[0][2]), 1.0, places=5)
+    # Per PDF: scale by 2^64, VCC=1
+    expected = 1.0 * (2.0 ** 64)
+    self.assertAlmostEqual(i2f(st.vgpr[0][2]), expected, delta=expected * 1e-6)
     self.assertEqual(st.vcc & 1, 1, "VCC should be 1 when result would be denorm")


@@ -401,43 +408,44 @@ class TestVDivFmas(unittest.TestCase):
     self.assertAlmostEqual(i2f(st.vgpr[0][3]), 7.0, places=5)

   def test_div_fmas_f32_scale_up(self):
-    """V_DIV_FMAS_F32: VCC=1 with S2 >= 2.0 -> scale by 2^+64."""
+    """V_DIV_FMAS_F32: VCC=1 -> scale by 2^32."""
     instructions = [
-      s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC = 1
+      s_mov_b32(s[106], 1), # VCC_LO = 1
       v_mov_b32_e32(v[0], 1.0), # S0
       v_mov_b32_e32(v[1], 1.0), # S1
-      v_mov_b32_e32(v[2], 2.0), # S2 >= 2.0, so scale UP
-      v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^+64 * (1*1+2) = 2^+64 * 3
+      v_mov_b32_e32(v[2], 2.0), # S2
+      v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^32 * fma(1,1,2) = 2^32 * 3
     ]
     st = run_program(instructions, n_lanes=1)
-    expected = 3.0 * (2.0 ** 64)
+    expected = 3.0 * (2.0 ** 32)
     self.assertAlmostEqual(i2f(st.vgpr[0][3]), expected, delta=abs(expected) * 1e-6)

   def test_div_fmas_f32_scale_down(self):
-    """V_DIV_FMAS_F32: VCC=1 with S2 < 2.0 -> scale by 2^-64."""
+    """V_DIV_FMAS_F32: VCC=1 -> scale by 2^32 (not dependent on S2)."""
     instructions = [
-      s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC = 1
+      s_mov_b32(s[106], 1), # VCC_LO = 1
       v_mov_b32_e32(v[0], 2.0), # S0
       v_mov_b32_e32(v[1], 3.0), # S1
-      v_mov_b32_e32(v[2], 1.0), # S2 < 2.0, so scale DOWN
-      v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^-64 * (2*3+1) = 2^-64 * 7
+      v_mov_b32_e32(v[2], 1.0), # S2
+      v_div_fmas_f32(v[3], v[0], v[1], v[2]), # 2^32 * fma(2,3,1) = 2^32 * 7
     ]
     st = run_program(instructions, n_lanes=1)
-    expected = 7.0 * (2.0 ** -64)
+    expected = 7.0 * (2.0 ** 32)
     self.assertAlmostEqual(i2f(st.vgpr[0][3]), expected, delta=abs(expected) * 1e-6)

   def test_div_fmas_f32_per_lane_vcc(self):
-    """V_DIV_FMAS_F32: different VCC per lane with S2 < 2.0."""
+    """V_DIV_FMAS_F32: different VCC per lane.
+    When VCC=1, scales UP by 2^32. When VCC=0, no scaling."""
     instructions = [
-      s_mov_b32(s[SrcEnum.VCC_LO - 128], 0b0101), # VCC: lanes 0,2 set
+      s_mov_b32(s[106], 0b0101), # VCC_LO: lanes 0,2 set
       v_mov_b32_e32(v[0], 1.0),
       v_mov_b32_e32(v[1], 1.0),
-      v_mov_b32_e32(v[2], 1.0), # S2 < 2.0, so scale DOWN
-      v_div_fmas_f32(v[3], v[0], v[1], v[2]), # fma(1,1,1) = 2, scaled = 2^-64 * 2
+      v_mov_b32_e32(v[2], 1.0),
+      v_div_fmas_f32(v[3], v[0], v[1], v[2]), # fma(1,1,1) = 2, scaled = 2^32 * 2 when VCC=1
     ]
     st = run_program(instructions, n_lanes=4)
-    scaled = 2.0 * (2.0 ** -64)
-    unscaled = 2.0
+    scaled = 2.0 * (2.0 ** 32) # VCC=1: scale UP by 2^32
+    unscaled = 2.0 # VCC=0: no scaling
     self.assertAlmostEqual(i2f(st.vgpr[0][3]), scaled, delta=abs(scaled) * 1e-6) # lane 0: VCC=1
     self.assertAlmostEqual(i2f(st.vgpr[1][3]), unscaled, places=5) # lane 1: VCC=0
     self.assertAlmostEqual(i2f(st.vgpr[2][3]), scaled, delta=abs(scaled) * 1e-6) # lane 2: VCC=1
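The expected values in these tests follow from an fma and then a 2^32 scale applied only when the lane's VCC bit is set. A quick standalone check of that arithmetic (plain multiply-add stands in for fma, which is exact for these small test values):

def div_fmas_expected(s0: float, s1: float, s2: float, vcc_bit: int) -> float:
  r = s0 * s1 + s2  # fma(S0, S1, S2)
  return r * 2.0 ** 32 if vcc_bit else r

assert div_fmas_expected(1.0, 1.0, 2.0, 1) == 3.0 * 2.0 ** 32
assert div_fmas_expected(2.0, 3.0, 1.0, 0) == 7.0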
@@ -608,10 +616,10 @@ class TestVDivFixup(unittest.TestCase):
     self.assertAlmostEqual(i2f(st.vgpr[0][3]), 3.0, places=5)

   def test_div_fixup_f32_nan_estimate_overflow(self):
-    """V_DIV_FIXUP_F32: NaN estimate returns overflow (inf).
+    """V_DIV_FIXUP_F32: NaN estimate passes through as NaN per PDF pseudocode.

-    PDF doesn't check isNAN(S0), but hardware returns OVERFLOW if S0 is NaN.
-    This happens when division fails (e.g., denorm denominator in V_DIV_SCALE).
+    PDF pseudocode only checks isNAN(S1) and isNAN(S2), not S0.
+    When S0 is NaN but S1/S2 are valid, it falls through to: D0 = abs(S0) = NaN.
     """
     quiet_nan = 0x7fc00000
     instructions = [
@@ -623,11 +631,10 @@ class TestVDivFixup(unittest.TestCase):
     ]
     st = run_program(instructions, n_lanes=1)
     import math
-    self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf")
-    self.assertEqual(st.vgpr[0][3], 0x7f800000, "Should be +inf (pos/pos)")
+    self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "NaN estimate should pass through as NaN per PDF")

   def test_div_fixup_f32_nan_estimate_sign(self):
-    """V_DIV_FIXUP_F32: NaN estimate with negative sign returns -inf."""
+    """V_DIV_FIXUP_F32: NaN estimate passes through per PDF pseudocode."""
     quiet_nan = 0x7fc00000
     instructions = [
       s_mov_b32(s[0], quiet_nan),
@@ -638,8 +645,8 @@ class TestVDivFixup(unittest.TestCase):
     ]
     st = run_program(instructions, n_lanes=1)
     import math
-    self.assertTrue(math.isinf(i2f(st.vgpr[0][3])), "NaN estimate should return inf")
-    self.assertEqual(st.vgpr[0][3], 0xff800000, "Should be -inf (pos/neg)")
+    # PDF pseudocode: D0 = -abs(S0) when sign_out=1. abs(NaN) is NaN, -NaN is NaN.
+    self.assertTrue(math.isnan(i2f(st.vgpr[0][3])), "NaN estimate should pass through as NaN per PDF")


 class TestVCmpClass(unittest.TestCase):
@@ -225,17 +225,17 @@ class TestPseudocodeRegressions(unittest.TestCase):
   """Regression tests for pseudocode instruction emulation bugs."""

   def test_v_div_scale_f32_vcc_always_returned(self):
-    """V_DIV_SCALE_F32 must always return vcc_lane, even when VCC=0 (no scaling needed).
-    Bug: when VCC._val == vcc (both 0), vcc_lane wasn't returned, so VCC bits weren't written.
-    This caused division to produce wrong results for multiple lanes."""
+    """V_DIV_SCALE_F32 must set VCC bit for the lane when scaling is needed.
+    The new calling convention uses Reg objects and modifies VCC in place."""
     # Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
-    s0 = 0x3f800000 # 1.0
-    s1 = 0x40400000 # 3.0
-    s2 = 0x3f800000 # 1.0 (numerator)
-    result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, {})
-    # Must always have vcc_lane in result
-    self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
-    self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")
+    S0 = Reg(0x3f800000) # 1.0
+    S1 = Reg(0x40400000) # 3.0
+    S2 = Reg(0x3f800000) # 1.0 (numerator)
+    D0 = Reg(0)
+    VCC = Reg(0)
+    _VOP3SDOp_V_DIV_SCALE_F32(S0, S1, S2, D0, Reg(0), VCC, 0, Reg(0xffffffff), Reg(0), None, Reg(0), Reg(0))
+    # VCC bit 0 should be 0 when no scaling needed
+    self.assertEqual(VCC._val & 1, 0, "VCC bit should be 0 when no scaling needed")

   def test_v_cmp_class_f32_detects_quiet_nan(self):
     """V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
@@ -244,18 +244,22 @@ class TestPseudocodeRegressions(unittest.TestCase):
     signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
     # Test quiet NaN detection (bit 1 in mask)
     s1_quiet = 0b0000000010 # bit 1 = quiet NaN
-    result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
-    self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
+    D0 = Reg(0)
+    _VOPCOp_V_CMP_CLASS_F32(Reg(quiet_nan), Reg(s1_quiet), Reg(0), D0, Reg(0), Reg(0), 0, Reg(0xffffffff), Reg(0), None, Reg(0), Reg(0))
+    self.assertEqual(D0._val & 1, 1, "Should detect quiet NaN with quiet NaN mask")
     # Test signaling NaN detection (bit 0 in mask)
     s1_signal = 0b0000000001 # bit 0 = signaling NaN
-    result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
-    self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
+    D0 = Reg(0)
+    _VOPCOp_V_CMP_CLASS_F32(Reg(signal_nan), Reg(s1_signal), Reg(0), D0, Reg(0), Reg(0), 0, Reg(0xffffffff), Reg(0), None, Reg(0), Reg(0))
+    self.assertEqual(D0._val & 1, 1, "Should detect signaling NaN with signaling NaN mask")
     # Test that quiet NaN doesn't match signaling NaN mask
-    result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
-    self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
+    D0 = Reg(0)
+    _VOPCOp_V_CMP_CLASS_F32(Reg(quiet_nan), Reg(s1_signal), Reg(0), D0, Reg(0), Reg(0), 0, Reg(0xffffffff), Reg(0), None, Reg(0), Reg(0))
+    self.assertEqual(D0._val & 1, 0, "Quiet NaN should not match signaling NaN mask")
     # Test that signaling NaN doesn't match quiet NaN mask
-    result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
-    self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")
+    D0 = Reg(0)
+    _VOPCOp_V_CMP_CLASS_F32(Reg(signal_nan), Reg(s1_quiet), Reg(0), D0, Reg(0), Reg(0), 0, Reg(0xffffffff), Reg(0), None, Reg(0), Reg(0))
+    self.assertEqual(D0._val & 1, 0, "Signaling NaN should not match quiet NaN mask")

   def test_isnan_with_typed_view(self):
     """_isnan must work with TypedView objects, not just Python floats.
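These regression tests exercise the new Reg-based calling convention: outputs land in mutated Reg objects instead of a returned dict. A minimal sketch of the convention, with a hypothetical compare helper in place of a real generated opcode function:

class Reg:
  def __init__(self, val=0): self._val = val

def _toy_cmp_gt(S0: Reg, S1: Reg, D0: Reg, laneId: int) -> None:
  # set this lane's bit in D0 when S0 > S1; nothing is returned
  if S0._val > S1._val: D0._val |= (1 << laneId)
  else: D0._val &= ~(1 << laneId)

D0 = Reg(0)
_toy_cmp_gt(Reg(5), Reg(3), D0, laneId=2)
assert (D0._val >> 2) & 1 == 1  # lane 2's bit set in place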