diff --git a/test/unit/test_transcendental_helpers.py b/test/unit/test_transcendental_helpers.py index 86b93e671d..30891868ce 100644 --- a/test/unit/test_transcendental_helpers.py +++ b/test/unit/test_transcendental_helpers.py @@ -19,26 +19,24 @@ class TestTranscendentalFunctions(unittest.TestCase): np.testing.assert_equal(q, 12) def test_frexp(self): - mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 0.0))) - np.testing.assert_equal(mantissa, 0.0) - np.testing.assert_equal(exponent, 0) + for x in (1, -1): + mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, x))) + np.testing.assert_equal(mantissa, 0.5) + np.testing.assert_equal(exponent, 1) - mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 1.0))) - np.testing.assert_equal(mantissa, 0.5) - np.testing.assert_equal(exponent, 1) - - mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, -1.0))) - np.testing.assert_equal(mantissa, 0.5) - np.testing.assert_equal(exponent, 1) - - mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 2.0))) - np.testing.assert_equal(mantissa, 0.5) - np.testing.assert_equal(exponent, 2) + for x in (2, -2): + mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 2.0))) + np.testing.assert_equal(mantissa, 0.5) + np.testing.assert_equal(exponent, 2) mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 5.0))) np.testing.assert_equal(mantissa, 0.625) np.testing.assert_equal(exponent, 3) + mantissa, exponent = (eval_uop(u) for u in frexp(UOp.const(dtypes.float64, 1000.0))) + np.testing.assert_allclose(mantissa, 0.9765625) + np.testing.assert_equal(exponent, 10) + def test_rintk(self): np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 0.0))), 0) np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 5.0))), 5) diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 80e2137fe9..22702a20dc 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -50,7 +50,7 @@ def ldexp2k(d:UOp, e:UOp) -> UOp: return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype) def frexp(v:UOp) -> Tuple[UOp, UOp]: - """frexp(v) -> (mantissa, exponent)""" + """frexp(v) -> (mantissa, exponent) assuming v != 0""" assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES # m1 = masks for mantissa, m2 = masks to normalize the mantissa. m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype] @@ -60,10 +60,6 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]: # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0). mantissa = ((bits & m1) | m2).bitcast(v.dtype) exp = exponent - exponent_bias(v.dtype) + 1 - # special case of 0 # TODO: can we remove this case? - mantissa = exponent.ne(0).where(mantissa, v) - exp = exponent.ne(0).where(exp, exp.const_like(0)) - if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16) return mantissa, exp # *** reduction algorithms for sine ***