assembly/amd: add dtype tests to AMD IDE CI (#13899)

* add dtype tests to AMD IDE CI

* more tests

* add trig preop

* regen done

* split to amd autogen

* simpler
George Hotz
2025-12-30 11:09:51 -05:00
committed by GitHub
parent 9c89be5235
commit 69cdc8066d
8 changed files with 701 additions and 191 deletions


@@ -654,7 +654,7 @@ jobs:
- name: Run process replay tests
uses: ./.github/actions/process-replay
testrdna3:
testamdasm:
name: AMD ASM IDE
runs-on: ubuntu-24.04
timeout-minutes: 10
@@ -679,8 +679,23 @@ jobs:
run: python -m pytest -n=auto extra/assembly/amd/ --durations 20
- name: Run RDNA3 emulator tests (AMD_LLVM=1)
run: AMD_LLVM=1 python -m pytest -n=auto extra/assembly/amd/ --durations 20
- name: Install pdfplumber
run: pip install pdfplumber
- name: Run RDNA3 dtype tests
run: PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
- name: Run RDNA3 dtype tests (AMD_LLVM=1)
run: PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=auto test/test_dtype_alu.py test/test_dtype.py
testamdautogen:
name: AMD autogen
runs-on: ubuntu-24.04
timeout-minutes: 10
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Setup Environment
uses: ./.github/actions/setup-tinygrad
with:
key: rdna3-autogen
pydeps: "pdfplumber"
- name: Verify AMD autogen is up to date
run: |
python -m extra.assembly.amd.dsl --arch all


@@ -18284,6 +18284,37 @@ def _VOP3AOp_V_ASHRREV_I64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, V
result['d0_64'] = True
return result
def _VOP3AOp_V_TRIG_PREOP_F64(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):
# shift = 32'I(S1[4 : 0].u32) * 53;
# if exponent(S0.f64) > 1077 then
# shift += exponent(S0.f64) - 1077
# endif;
# // (2.0/PI) == 0.{b_1200, b_1199, b_1198, ..., b_1, b_0}
# // b_1200 is the MSB of the fractional part of 2.0/PI
# // Left shift operation indicates which bits are brought
# result = 64'F((1201'B(2.0 / PI)[1200 : 0] << shift.u32) & 1201'0x1fffffffffffff);
# scale = -53 - shift;
# if exponent(S0.f64) >= 1968 then
# scale += 128
# endif;
# D0.f64 = ldexp(result, scale)
S0 = Reg(s0)
S1 = Reg(s1)
D0 = Reg(d0)
# --- compiled pseudocode ---
shift = (S1[4 : 0].u32) * 53
if exponent(S0.f64) > 1077:
shift += exponent(S0.f64) - 1077
result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)
scale = -53 - shift
if exponent(S0.f64) >= 1968:
scale += 128
D0.f64 = ldexp(result, scale)
# --- end pseudocode ---
result = {'d0': D0._val, 'scc': scc & 1}
result['d0_64'] = True
return result
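For reference, a minimal standalone sketch of the chunk extraction this compiles to. biased_exp is a stand-in for the pseudocode's exponent(), the raw 11-bit f64 exponent field; the helper names are hypothetical, and the constant argument would be int(TWO_OVER_PI_1201) from pcode.

import math, struct

def biased_exp(x: float) -> int:
    # raw 11-bit exponent field of an f64 (1023-biased)
    return (struct.unpack('<Q', struct.pack('<d', x))[0] >> 52) & 0x7ff

def trig_preop_f64(two_over_pi_1201: int, x: float, index: int) -> float:
    # index (0, 1, or 2) selects which 53-bit window of 2/PI to return
    shift = (index & 0x1f) * 53
    if biased_exp(x) > 1077: shift += biased_exp(x) - 1077
    # after the left shift, the 53 interesting bits sit at the top of the 1201-bit field
    chunk = ((two_over_pi_1201 << shift) >> (1201 - 53)) & 0x1fffffffffffff
    scale = -53 - shift
    if biased_exp(x) >= 1968: scale += 128
    return math.ldexp(float(chunk), scale)

Summing the index 0, 1 and 2 results for x = 1.0 reconstructs 2/PI to f64 precision, which is what the TestVTrigPreopF64 tests below verify.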
def _VOP3AOp_V_BFM_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):
# D0.u32 = (((1U << S0[4 : 0].u32) - 1U) << S1[4 : 0].u32)
S0 = Reg(s0)
@@ -18940,6 +18971,7 @@ VOP3AOp_FUNCTIONS = {
VOP3AOp.V_LSHLREV_B64: _VOP3AOp_V_LSHLREV_B64,
VOP3AOp.V_LSHRREV_B64: _VOP3AOp_V_LSHRREV_B64,
VOP3AOp.V_ASHRREV_I64: _VOP3AOp_V_ASHRREV_I64,
VOP3AOp.V_TRIG_PREOP_F64: _VOP3AOp_V_TRIG_PREOP_F64,
VOP3AOp.V_BFM_B32: _VOP3AOp_V_BFM_B32,
VOP3AOp.V_CVT_PKNORM_I16_F32: _VOP3AOp_V_CVT_PKNORM_I16_F32,
VOP3AOp.V_CVT_PKNORM_U16_F32: _VOP3AOp_V_CVT_PKNORM_U16_F32,

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -283,6 +283,10 @@ class Inst:
from extra.assembly.amd.autogen.rdna3 import VOP3Op
try: op_name = VOP3Op(op).name
except ValueError: pass
if op_name is None and self.__class__.__name__ == 'VOPC':
from extra.assembly.amd.autogen.rdna3 import VOPCOp
try: op_name = VOPCOp(op).name
except ValueError: pass
if op_name is None: return False
# V_LDEXP_F64 has 32-bit integer exponent in src1, so literal is 32-bit
if op_name == 'V_LDEXP_F64': return False
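With the VOPC lookup added, literal width is decided from the resolved op name alone; a sketch of the rule as a hypothetical standalone helper (the V_TRIG_PREOP_F64 exception is inferred from the emulator comment further down, which says _is_64bit_op() returns False for it too):

def literal_is_64bit(op_name: str) -> bool:
    # ops whose src1 is a 32-bit integer (exponent / window index) keep a 32-bit literal
    if op_name in ('V_LDEXP_F64', 'V_TRIG_PREOP_F64'): return False
    return op_name.endswith(('_F64', '_B64', '_I64', '_U64'))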


@@ -17,6 +17,7 @@ VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, Sr
# VOP3 ops that use 64-bit operands (and thus 64-bit literals when src is 255)
# Exception: V_LDEXP_F64 has 32-bit integer src1, so literal should NOT be 64-bit when src1=255
_VOP3_64BIT_OPS = {op.value for op in VOP3Op if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
_VOPC_64BIT_OPS = {op.value for op in VOPCOp if op.name.endswith(('_F64', '_B64', '_I64', '_U64'))}
# Ops where src1 is 32-bit (exponent/shift amount) even though the op name suggests 64-bit
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value}
# Ops with 16-bit types in name (for source/dest handling)
@@ -185,7 +186,7 @@ def decode_program(data: bytes) -> Program:
# Exception: some ops have mixed src sizes (e.g., V_LDEXP_F64 has 32-bit src1)
op_val = inst._values.get('op')
if hasattr(op_val, 'value'): op_val = op_val.value
is_64bit = inst_class is VOP3 and op_val in _VOP3_64BIT_OPS
is_64bit = (inst_class is VOP3 and op_val in _VOP3_64BIT_OPS) or (inst_class is VOPC and op_val in _VOPC_64BIT_OPS)
# Don't treat literal as 64-bit if the op has 32-bit src1 and src1 is the literal
if is_64bit and op_val in _VOP3_64BIT_OPS_32BIT_SRC1 and getattr(inst, 'src1', None) == 255:
is_64bit = False
@@ -336,14 +337,22 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
op = VOP3SDOp(inst.op)
fn = compiled.get(VOP3SDOp, {}).get(op)
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
# For 64-bit src2 ops (V_MAD_U64_U32, V_MAD_I64_I32), read from consecutive registers
# VOP3SD has both 32-bit ops (V_ADD_CO_CI_U32, etc.) and 64-bit ops (V_DIV_SCALE_F64, V_MAD_U64_U32, etc.)
div_scale_64_ops = (VOP3SDOp.V_DIV_SCALE_F64,)
mad64_ops = (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32)
if op in mad64_ops:
if op in div_scale_64_ops:
# V_DIV_SCALE_F64: all sources are 64-bit
s0, s1, s2 = st.rsrc64(inst.src0, lane), st.rsrc64(inst.src1, lane), st.rsrc64(inst.src2, lane)
elif op in mad64_ops:
# V_MAD_U64_U32, V_MAD_I64_I32: src0/src1 are 32-bit, src2 is 64-bit
s0, s1 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane)
if inst.src2 >= 256: # VGPR
s2 = V[inst.src2 - 256] | (V[inst.src2 - 256 + 1] << 32)
else: # SGPR - read 64-bit from consecutive SGPRs
s2 = st.rsgpr64(inst.src2)
else:
# Default: 32-bit sources
s0, s1, s2 = st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane)
d0 = V[inst.vdst]
# For carry-in operations (V_*_CO_CI_*), src2 register contains the carry bitmask (not VCC).
# The pseudocode uses VCC but in VOP3SD encoding, the actual carry source is inst.src2.
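The mad64 branch exists because V_MAD_U64_U32 mixes widths: a 32x32-bit multiply plus a 64-bit addend into a 64-bit result. A sketch of the dataflow it sets up (function name hypothetical):

def v_mad_u64_u32(s0: int, s1: int, s2: int) -> int:
    # D.u64 = S0.u32 * S1.u32 + S2.u64
    return ((s0 & 0xffffffff) * (s1 & 0xffffffff) + s2) & 0xffffffffffffffff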
@@ -516,8 +525,9 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
# For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift)
# For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit
is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64'))
# V_LDEXP_F64: src0 is 64-bit float, src1 is 32-bit integer exponent
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64,)
# V_LDEXP_F64, V_TRIG_PREOP_F64, V_CMP_CLASS_F64, V_CMPX_CLASS_F64: src0 is 64-bit, src1 is 32-bit
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64, VOP3Op.V_TRIG_PREOP_F64, VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64,
VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
# 16-bit source ops: use precomputed sets instead of string checks
# Note: must check op_cls to avoid cross-enum value collisions
@@ -531,7 +541,12 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_ldexp_64:
s0 = mod_src64(st.rsrc64(src0, lane), 0) # mantissa is 64-bit float
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0 # exponent is 32-bit int
# src1 is 32-bit int. For 64-bit ops (like V_CMP_CLASS_F64), the literal is stored shifted left by 32.
# For V_LDEXP_F64/V_TRIG_PREOP_F64, _is_64bit_op() returns False so literal is stored as-is.
s1_raw = st.rsrc(src1, lane) if src1 is not None else 0
# Only shift if src1 is literal AND this is a true 64-bit op (V_CMP_CLASS ops, not LDEXP/TRIG_PREOP)
is_class_op = op in (VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64, VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
s1 = mod_src((s1_raw >> 32) if src1 == 255 and is_class_op else s1_raw, 1)
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
elif is_64bit_op:
# 64-bit ops: apply neg/abs modifiers using f64 interpretation for float ops
@@ -651,7 +666,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
is_16bit_dst = (op_cls is VOP3Op and op in _VOP3_16BIT_DST_OPS) or (op_cls is VOP1Op and op in _VOP1_16BIT_DST_OPS)
if writes_to_sgpr:
st.wsgpr(vdst, result['d0'] & 0xffffffff)
elif result.get('d0_64') or is_64bit_op:
elif result.get('d0_64'):
V[vdst] = result['d0'] & 0xffffffff
V[vdst + 1] = (result['d0'] >> 32) & 0xffffffff
elif is_16bit_dst and inst_type is VOP3:
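A sketch of the literal recovery in the is_class_op branch above: for 64-bit ops the decoder holds a 32-bit literal f64-style in the high half of the 64-bit value, so the class mask comes back out with a right shift (mask value hypothetical):

encoded_mask = 0x208                   # hypothetical fp-class mask literal
stored = encoded_mask << 32            # how the decoder holds a 64-bit op's 32-bit literal
assert (stored >> 32) == encoded_mask  # what V_CMP_CLASS_F64 must see in s1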


@@ -280,7 +280,7 @@ def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
def mantissa(f):
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
m, _ = math.frexp(f)
return math.copysign(m * 2.0, f)
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
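math.frexp already follows the AMD convention, so the mantissa needs no rescaling:

import math
m, e = math.frexp(2.0)
assert (m, e) == (0.5, 2)  # 2.0 == 0.5 * 2**2, mantissa in [0.5, 1.0)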
def signext_from_bit(val, bit):
bit = int(bit)
if bit == 0: return 0
@@ -301,6 +301,7 @@ __all__ = [
# Constants
'WAVE32', 'WAVE64', 'MASK32', 'MASK64', 'WAVE_MODE', 'DENORM', 'OVERFLOW_F32', 'UNDERFLOW_F32',
'OVERFLOW_F64', 'UNDERFLOW_F64', 'MAX_FLOAT_F32', 'ROUND_MODE', 'cvtToQuietNAN', 'DST', 'INF', 'PI',
'TWO_OVER_PI_1201',
# Aliases for pseudocode
's_ff1_i32_b32', 's_ff1_i32_b64', 'GT_NEG_ZERO', 'LT_NEG_ZERO',
'isNAN', 'isQuietNAN', 'isSignalNAN', 'fma', 'ldexp', 'sign', 'exponent', 'F', 'signext',
@@ -359,12 +360,14 @@ class _Inf:
f16 = f32 = f64 = float('inf')
def __neg__(self): return _NegInf()
def __pos__(self): return self
def __float__(self): return float('inf')
def __eq__(self, other): return float(other) == float('inf') if not isinstance(other, _NegInf) else False
def __req__(self, other): return self.__eq__(other)
class _NegInf:
f16 = f32 = f64 = float('-inf')
def __neg__(self): return _Inf()
def __pos__(self): return self
def __float__(self): return float('-inf')
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
def __req__(self, other): return self.__eq__(other)
INF = _Inf()
@@ -380,6 +383,31 @@ DST = None # Placeholder, will be set in context
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
# Computed as: int((2/pi) * 2^1201) - this is the fractional part of 2/pi scaled to integer
# The MSB (bit 1200) corresponds to 2^0 position in the fraction 0.b1200 b1199 ... b1 b0
_TWO_OVER_PI_1201_RAW = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6
class _BigInt:
"""Wrapper for large integers that supports bit slicing [high:low]."""
__slots__ = ('_val',)
def __init__(self, val): self._val = val
def __getitem__(self, key):
if isinstance(key, slice):
high, low = key.start, key.stop
if high < low: high, low = low, high # Handle reversed slice
mask = (1 << (high - low + 1)) - 1
return (self._val >> low) & mask
return (self._val >> key) & 1
def __int__(self): return self._val
def __index__(self): return self._val
def __lshift__(self, n): return self._val << int(n)
def __rshift__(self, n): return self._val >> int(n)
def __and__(self, n): return self._val & int(n)
def __or__(self, n): return self._val | int(n)
TWO_OVER_PI_1201 = _BigInt(_TWO_OVER_PI_1201_RAW)
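With the slicing above, extracting the top 53 bits of the fraction is a one-liner:

top53 = TWO_OVER_PI_1201[1200 : 1148]  # == (_TWO_OVER_PI_1201_RAW >> 1148) & (2**53 - 1)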
class _WaveMode:
IEEE = False
WAVE_MODE = _WaveMode()
@@ -693,6 +721,9 @@ def _expr(e: str) -> str:
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
# Special constant: 1201'B(2.0 / PI) -> TWO_OVER_PI_1201 (precomputed 1201-bit 2/pi)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
# Literals: 1'0U -> 0, 32'I(x) -> (x), B(x) -> (x)
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
e = re.sub(r"\d+'[FIBU]\(", "(", e)
@@ -815,7 +846,7 @@ INST_PATTERN = re.compile(r'^([SV]_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[', '2.0 / PI',
'S1[i', 'C.i32', 'S[i]', 'in[',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
def extract_pseudocode(text: str) -> str | None:
@@ -1050,12 +1081,22 @@ from extra.assembly.amd.pcode import *
code = code.replace(
'D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
# V_TRIG_PREOP_F64: AMD pseudocode uses (x << shift) & mask but mask needs to extract TOP bits.
# The PDF shows: result = 64'F((1201'B(2.0/PI)[1200:0] << shift) & 1201'0x1fffffffffffff)
# Issues to fix:
# 1. After left shift, the interesting bits are at the top, not bottom - need >> (1201-53)
# 2. shift.u32 fails because shift is a plain int after * 53 - use int(shift)
# 3. 64'F(...) means convert int to float (not interpret as bit pattern) - use float()
if op.name == 'V_TRIG_PREOP_F64':
code = code.replace(
'result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
# Detect flags for result handling
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
has_d1 = '{ D1' in pc
if has_d1: is_64 = True
is_cmp = cls_name == 'VOPCOp' and 'D0.u64[laneId]' in pc
is_cmpx = cls_name == 'VOPCOp' and 'EXEC.u64[laneId]' in pc # V_CMPX writes to EXEC per-lane
is_cmp = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'D0.u64[laneId]' in pc
is_cmpx = (cls_name == 'VOPCOp' or cls_name == 'VOP3Op') and 'EXEC.u64[laneId]' in pc # V_CMPX writes to EXEC per-lane
# V_DIV_SCALE passes through S0 if no branch taken
is_div_scale = 'DIV_SCALE' in op.name
# VOP3SD instructions that write VCC per-lane (either via VCC.u64[laneId] or by setting VCC = 0/1)


@@ -2454,6 +2454,82 @@ class TestF64Conversions(unittest.TestCase):
result = struct.unpack('<q', struct.pack('<II', lo, hi))[0]
self.assertEqual(result, -8, f"Expected -8, got {result} (lo=0x{lo:08x}, hi=0x{hi:08x})")
def test_v_cvt_i32_f64_writes_32bit_only(self):
"""V_CVT_I32_F64 should only write 32 bits, not 64.
Regression test: V_CVT_I32_F64 has a 64-bit source (f64) but 32-bit destination (i32).
The emulator was incorrectly writing 64 bits (clobbering vdst+1) because
is_64bit_op was True for any op ending in '_F64'.
"""
# Pre-fill v3 with a canary value that should NOT be clobbered
val_bits = f2i64(-1.0)
instructions = [
s_mov_b32(s[0], val_bits & 0xffffffff),
s_mov_b32(s[1], val_bits >> 32),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], 0xDEADBEEF), # Canary value
v_mov_b32_e32(v[3], s[2]), # Put canary in v3
v_cvt_i32_f64_e32(v[2], v[0:2]), # Convert -1.0 -> -1 (0xffffffff)
]
st = run_program(instructions, n_lanes=1)
result = st.vgpr[0][2]
canary = st.vgpr[0][3]
# V_CVT_I32_F64 of -1.0 should produce 0xffffffff (-1)
self.assertEqual(result, 0xffffffff, f"Expected 0xffffffff (-1), got 0x{result:08x}")
# v3 should still contain the canary (not clobbered by 64-bit write)
self.assertEqual(canary, 0xDEADBEEF, f"v3 canary should be 0xDEADBEEF, got 0x{canary:08x} (clobbered!)")
def test_v_frexp_mant_f64_range(self):
"""V_FREXP_MANT_F64 should return mantissa in [0.5, 1.0) range.
Regression test: The mantissa() helper was incorrectly multiplying by 2.0,
returning values in [1.0, 2.0) instead of the correct [0.5, 1.0) range.
"""
# Test with 2.0: frexp(2.0) should give mantissa=0.5, exponent=2
two_f64 = f2i64(2.0)
instructions = [
s_mov_b32(s[0], two_f64 & 0xffffffff),
s_mov_b32(s[1], two_f64 >> 32),
v_frexp_mant_f64_e32(v[0:2], s[0:2]),
v_frexp_exp_i32_f64_e32(v[2], s[0:2]),
]
st = run_program(instructions, n_lanes=1)
mant = i642f(st.vgpr[0][0] | (st.vgpr[0][1] << 32))
exp = st.vgpr[0][2]
if exp >= 0x80000000: exp -= 0x100000000 # sign extend
# frexp(2.0) = 0.5 * 2^2
self.assertAlmostEqual(mant, 0.5, places=10, msg=f"Expected mantissa 0.5, got {mant}")
self.assertEqual(exp, 2, f"Expected exponent 2, got {exp}")
def test_v_div_scale_f64_reads_64bit_sources(self):
"""V_DIV_SCALE_F64 must read all sources as 64-bit values.
Regression test: VOP3SD was reading sources as 32-bit for V_DIV_SCALE_F64,
causing incorrect results when the low 32 bits happened to look like 0 or denorm.
"""
# Set up v0:v1 = sqrt(2) ≈ 1.414, v2:v3 = 1.0
sqrt2_f64 = f2i64(1.4142135623730951)
one_f64 = f2i64(1.0)
instructions = [
s_mov_b32(s[0], sqrt2_f64 & 0xffffffff),
s_mov_b32(s[1], sqrt2_f64 >> 32),
v_mov_b32_e32(v[0], s[0]),
v_mov_b32_e32(v[1], s[1]),
s_mov_b32(s[2], one_f64 & 0xffffffff),
s_mov_b32(s[3], one_f64 >> 32),
v_mov_b32_e32(v[2], s[2]),
v_mov_b32_e32(v[3], s[3]),
# V_DIV_SCALE_F64: src0=v0:v1, src1=v0:v1, src2=v2:v3
# For normal inputs, should pass through src0 unchanged
VOP3SD(VOP3SDOp.V_DIV_SCALE_F64, vdst=v[4], sdst=s[10], src0=v[0], src1=v[0], src2=v[2]),
]
st = run_program(instructions, n_lanes=1)
result = i642f(st.vgpr[0][4] | (st.vgpr[0][5] << 32))
# For normal (non-denorm, non-edge-case) inputs, V_DIV_SCALE_F64 passes through src0
self.assertAlmostEqual(result, 1.4142135623730951, places=10,
msg=f"Expected ~1.414, got {result} (may be nan if 64-bit sources not read correctly)")
class TestNewPcodeHelpers(unittest.TestCase):
"""Tests for newly added pcode helper functions (SAD, BYTE_PERMUTE, BF16)."""
@@ -3650,3 +3726,90 @@ class TestVFmaMixSinCase(unittest.TestCase):
# Result should be approximately -π = -3.14...
# f16 -π ≈ 0xc248 = -3.140625
self.assertAlmostEqual(lo, -3.14159, delta=0.01, msg=f"Expected ~-π, got {lo}")
class TestVTrigPreopF64(unittest.TestCase):
"""Tests for V_TRIG_PREOP_F64 instruction.
V_TRIG_PREOP_F64 extracts chunks of 2/PI for Payne-Hanek trig range reduction.
For input S0 (f64) and index S1 (0, 1, or 2), it returns a portion of 2/PI
scaled appropriately for computing |S0| * (2/PI) in extended precision.
The three chunks (index 0, 1, 2) when summed should equal 2/PI.
"""
def test_trig_preop_f64_index0(self):
"""V_TRIG_PREOP_F64 index=0: primary chunk of 2/PI."""
import math
two_over_pi = 2.0 / math.pi
instructions = [
# S0 = 1.0 (f64), S1 = 0 (index)
s_mov_b32(s[0], 0x00000000), # low bits of 1.0
s_mov_b32(s[1], 0x3ff00000), # high bits of 1.0
v_trig_preop_f64(v[0], abs(s[0]), 0), # index 0
]
st = run_program(instructions, n_lanes=1)
result = i642f(st.vgpr[0][0] | (st.vgpr[0][1] << 32))
# For x=1.0, index=0 should give the main part of 2/PI
self.assertAlmostEqual(result, two_over_pi, places=10, msg=f"Expected ~{two_over_pi}, got {result}")
def test_trig_preop_f64_index1(self):
"""V_TRIG_PREOP_F64 index=1: secondary chunk (extended precision bits)."""
instructions = [
s_mov_b32(s[0], 0x00000000), # low bits of 1.0
s_mov_b32(s[1], 0x3ff00000), # high bits of 1.0
v_trig_preop_f64(v[0], abs(s[0]), 1), # index 1
]
st = run_program(instructions, n_lanes=1)
result = i642f(st.vgpr[0][0] | (st.vgpr[0][1] << 32))
# Index 1 gives the next 53 bits, should be very small (~1e-16)
self.assertLess(abs(result), 1e-15, msg=f"Expected tiny value, got {result}")
self.assertGreater(abs(result), 0, msg="Expected non-zero value")
def test_trig_preop_f64_index2(self):
"""V_TRIG_PREOP_F64 index=2: tertiary chunk (more extended precision bits)."""
instructions = [
s_mov_b32(s[0], 0x00000000), # low bits of 1.0
s_mov_b32(s[1], 0x3ff00000), # high bits of 1.0
v_trig_preop_f64(v[0], abs(s[0]), 2), # index 2
]
st = run_program(instructions, n_lanes=1)
result = i642f(st.vgpr[0][0] | (st.vgpr[0][1] << 32))
# Index 2 gives the next 53 bits after index 1, should be tiny (~1e-32)
self.assertLess(abs(result), 1e-30, msg=f"Expected very tiny value, got {result}")
def test_trig_preop_f64_sum_equals_two_over_pi(self):
"""V_TRIG_PREOP_F64: sum of chunks 0,1,2 should equal 2/PI."""
import math
two_over_pi = 2.0 / math.pi
instructions = [
s_mov_b32(s[0], 0x00000000), # low bits of 1.0
s_mov_b32(s[1], 0x3ff00000), # high bits of 1.0
v_trig_preop_f64(v[0], abs(s[0]), 0), # index 0 -> v[0:1]
v_trig_preop_f64(v[2], abs(s[0]), 1), # index 1 -> v[2:3]
v_trig_preop_f64(v[4], abs(s[0]), 2), # index 2 -> v[4:5]
]
st = run_program(instructions, n_lanes=1)
p0 = i642f(st.vgpr[0][0] | (st.vgpr[0][1] << 32))
p1 = i642f(st.vgpr[0][2] | (st.vgpr[0][3] << 32))
p2 = i642f(st.vgpr[0][4] | (st.vgpr[0][5] << 32))
total = p0 + p1 + p2
self.assertAlmostEqual(total, two_over_pi, places=14, msg=f"Expected {two_over_pi}, got {total} (p0={p0}, p1={p1}, p2={p2})")
def test_trig_preop_f64_large_input(self):
"""V_TRIG_PREOP_F64 with larger input should adjust shift based on exponent."""
import math
# For x=2.0, exponent(2.0)=1024 which is <= 1077, so no adjustment
# But let's test with x=2^60 where exponent > 1077
large_val = 2.0 ** 60 # exponent = 1083 > 1077
large_bits = f2i64(large_val)
instructions = [
s_mov_b32(s[0], large_bits & 0xffffffff),
s_mov_b32(s[1], (large_bits >> 32) & 0xffffffff),
v_trig_preop_f64(v[0], abs(s[0]), 0),
]
st = run_program(instructions, n_lanes=1)
result = i642f(st.vgpr[0][0] | (st.vgpr[0][1] << 32))
# Result should still be a valid float (not NaN or inf)
self.assertFalse(math.isnan(result), "Result should not be NaN")
self.assertFalse(math.isinf(result), "Result should not be inf")