mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
remove special 0 case in frexp (#7450)
we can safely assume input is non-zero, also removed unneeded bitcast
This commit is contained in:
@@ -19,26 +19,24 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
||||
np.testing.assert_equal(q, 12)
|
||||
|
||||
def test_frexp(self):
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 0.0)))
|
||||
np.testing.assert_equal(mantissa, 0.0)
|
||||
np.testing.assert_equal(exponent, 0)
|
||||
for x in (1, -1):
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, x)))
|
||||
np.testing.assert_equal(mantissa, 0.5)
|
||||
np.testing.assert_equal(exponent, 1)
|
||||
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 1.0)))
|
||||
np.testing.assert_equal(mantissa, 0.5)
|
||||
np.testing.assert_equal(exponent, 1)
|
||||
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, -1.0)))
|
||||
np.testing.assert_equal(mantissa, 0.5)
|
||||
np.testing.assert_equal(exponent, 1)
|
||||
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 2.0)))
|
||||
np.testing.assert_equal(mantissa, 0.5)
|
||||
np.testing.assert_equal(exponent, 2)
|
||||
for x in (2, -2):
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 2.0)))
|
||||
np.testing.assert_equal(mantissa, 0.5)
|
||||
np.testing.assert_equal(exponent, 2)
|
||||
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 5.0)))
|
||||
np.testing.assert_equal(mantissa, 0.625)
|
||||
np.testing.assert_equal(exponent, 3)
|
||||
|
||||
mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 1000.0)))
|
||||
np.testing.assert_allclose(mantissa, 0.9765625)
|
||||
np.testing.assert_equal(exponent, 10)
|
||||
|
||||
def test_rintk(self):
|
||||
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 0.0))), 0)
|
||||
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 5.0))), 5)
|
||||
|
||||
@@ -50,7 +50,7 @@ def ldexp2k(d:UOp, e:UOp) -> UOp:
|
||||
return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
|
||||
|
||||
def frexp(v:UOp) -> Tuple[UOp, UOp]:
|
||||
"""frexp(v) -> (mantissa, exponent)"""
|
||||
"""frexp(v) -> (mantissa, exponent) assuming v != 0"""
|
||||
assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
# m1 = masks for mantissa, m2 = masks to normalize the mantissa.
|
||||
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
|
||||
@@ -60,10 +60,6 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
|
||||
# Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
|
||||
mantissa = ((bits & m1) | m2).bitcast(v.dtype)
|
||||
exp = exponent - exponent_bias(v.dtype) + 1
|
||||
# special case of 0 # TODO: can we remove this case?
|
||||
mantissa = exponent.ne(0).where(mantissa, v)
|
||||
exp = exponent.ne(0).where(exp, exp.const_like(0))
|
||||
if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16)
|
||||
return mantissa, exp
|
||||
|
||||
# *** reduction algorithms for sine ***
|
||||
|
||||
Reference in New Issue
Block a user