mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fixes
This commit is contained in:
@@ -1096,5 +1096,60 @@ class TestFuzzFailure(unittest.TestCase):
|
||||
rn = expr.substitute({v1:v1_val, v2:v2_val, v3:v3_val}).ssimplify()
|
||||
assert num==rn, f"{num} != {rn}"
|
||||
|
||||
class TestBitcast(unittest.TestCase):
|
||||
def test_bitcast_preserves_signaling_nan_bits(self):
|
||||
# Signaling NaN in f32: exponent all 1s, mantissa non-zero with MSB clear
|
||||
snan_bits = 0x7f800001
|
||||
bits = UOp.const(dtypes.uint32, snan_bits)
|
||||
# BITCAST uint32 -> float32 should NOT fold (would corrupt NaN bits)
|
||||
bf = bits.bitcast(dtypes.float32)
|
||||
result = bf.simplify()
|
||||
self.assertEqual(result.op, Ops.BITCAST, "signaling NaN bitcast should not fold to CONST")
|
||||
self.assertEqual(result.src[0].arg, snan_bits)
|
||||
|
||||
def test_bitcast_double_preserves_bits(self):
|
||||
# BITCAST(BITCAST(x)) where outer dtype == x.dtype should fold back to x
|
||||
snan_bits = 0x7f800001
|
||||
bits = UOp.const(dtypes.uint32, snan_bits)
|
||||
bf1 = bits.bitcast(dtypes.float32)
|
||||
bf2 = bf1.bitcast(dtypes.uint32)
|
||||
result = bf2.simplify()
|
||||
self.assertEqual(result.op, Ops.CONST)
|
||||
self.assertEqual(result.arg, snan_bits, "double bitcast should preserve original bits")
|
||||
|
||||
def test_bitcast_quiet_nan_folds(self):
|
||||
# Quiet NaN in f32: can be folded since Python's nan preserves these bits
|
||||
qnan_bits = 0x7fc00000
|
||||
bits = UOp.const(dtypes.uint32, qnan_bits)
|
||||
bf = bits.bitcast(dtypes.float32)
|
||||
result = bf.simplify()
|
||||
self.assertEqual(result.op, Ops.CONST)
|
||||
self.assertTrue(math.isnan(result.arg))
|
||||
|
||||
def test_bitcast_normal_float_folds(self):
|
||||
# Normal float values should fold
|
||||
bits = UOp.const(dtypes.uint32, 0x40490fdb) # pi
|
||||
bf = bits.bitcast(dtypes.float32)
|
||||
result = bf.simplify()
|
||||
self.assertEqual(result.op, Ops.CONST)
|
||||
self.assertAlmostEqual(result.arg, 3.14159265, places=5)
|
||||
|
||||
def test_bitcast_f64_signaling_nan(self):
|
||||
# Signaling NaN in f64
|
||||
snan_bits = 0x7ff0000000000001
|
||||
bits = UOp.const(dtypes.uint64, snan_bits)
|
||||
bf = bits.bitcast(dtypes.float64)
|
||||
result = bf.simplify()
|
||||
self.assertEqual(result.op, Ops.BITCAST, "f64 signaling NaN bitcast should not fold")
|
||||
|
||||
def test_bitcast_f64_double_preserves_bits(self):
|
||||
snan_bits = 0x7ff0000000000001
|
||||
bits = UOp.const(dtypes.uint64, snan_bits)
|
||||
bf1 = bits.bitcast(dtypes.float64)
|
||||
bf2 = bf1.bitcast(dtypes.uint64)
|
||||
result = bf2.simplify()
|
||||
self.assertEqual(result.op, Ops.CONST)
|
||||
self.assertEqual(result.arg, snan_bits)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user