This commit is contained in:
George Hotz
2026-01-04 15:48:31 -08:00
parent 7f7f12d5b4
commit b52ff63896
4 changed files with 218 additions and 44 deletions

View File

@@ -1096,5 +1096,60 @@ class TestFuzzFailure(unittest.TestCase):
rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
assert num==rn, f"{num} != {rn}"
class TestBitcast(unittest.TestCase):
def test_bitcast_preserves_signaling_nan_bits(self):
# Signaling NaN in f32: exponent all 1s, mantissa non-zero with MSB clear
snan_bits = 0x7f800001
bits = UOp.const(dtypes.uint32, snan_bits)
# BITCAST uint32 -> float32 should NOT fold (would corrupt NaN bits)
bf = bits.bitcast(dtypes.float32)
result = bf.simplify()
self.assertEqual(result.op, Ops.BITCAST, "signaling NaN bitcast should not fold to CONST")
self.assertEqual(result.src[0].arg, snan_bits)
def test_bitcast_double_preserves_bits(self):
# BITCAST(BITCAST(x)) where outer dtype == x.dtype should fold back to x
snan_bits = 0x7f800001
bits = UOp.const(dtypes.uint32, snan_bits)
bf1 = bits.bitcast(dtypes.float32)
bf2 = bf1.bitcast(dtypes.uint32)
result = bf2.simplify()
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, snan_bits, "double bitcast should preserve original bits")
def test_bitcast_quiet_nan_folds(self):
# Quiet NaN in f32: can be folded since Python's nan preserves these bits
qnan_bits = 0x7fc00000
bits = UOp.const(dtypes.uint32, qnan_bits)
bf = bits.bitcast(dtypes.float32)
result = bf.simplify()
self.assertEqual(result.op, Ops.CONST)
self.assertTrue(math.isnan(result.arg))
def test_bitcast_normal_float_folds(self):
# Normal float values should fold
bits = UOp.const(dtypes.uint32, 0x40490fdb) # pi
bf = bits.bitcast(dtypes.float32)
result = bf.simplify()
self.assertEqual(result.op, Ops.CONST)
self.assertAlmostEqual(result.arg, 3.14159265, places=5)
def test_bitcast_f64_signaling_nan(self):
# Signaling NaN in f64
snan_bits = 0x7ff0000000000001
bits = UOp.const(dtypes.uint64, snan_bits)
bf = bits.bitcast(dtypes.float64)
result = bf.simplify()
self.assertEqual(result.op, Ops.BITCAST, "f64 signaling NaN bitcast should not fold")
def test_bitcast_f64_double_preserves_bits(self):
snan_bits = 0x7ff0000000000001
bits = UOp.const(dtypes.uint64, snan_bits)
bf1 = bits.bitcast(dtypes.float64)
bf2 = bf1.bitcast(dtypes.uint64)
result = bf2.simplify()
self.assertEqual(result.op, Ops.CONST)
self.assertEqual(result.arg, snan_bits)
if __name__ == '__main__':
unittest.main()