clean up transcendental.rintk [pr] (#7422)

added unit tests and updated the comment. it's rounding away from 0 for negatives
This commit is contained in:
chenyu
2024-10-30 20:37:28 -04:00
committed by GitHub
parent fb694a63eb
commit 118dd7721f
2 changed files with 14 additions and 6 deletions

View File

@@ -2,10 +2,10 @@ import unittest, math
import numpy as np
from tinygrad import dtypes
from tinygrad.ops import UOp
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk
from test.helpers import eval_uop
class TestReduction(unittest.TestCase):
class TestTranscendentalFunctions(unittest.TestCase):
def test_payne_hanek_reduction(self):
r, q = (eval_uop(u) for u in payne_hanek_reduction(UOp.const(dtypes.float64, 12 * math.pi + 0.1)))
# TODO: should r be in [0, pi/2) per doc?
@@ -39,5 +39,14 @@ class TestReduction(unittest.TestCase):
np.testing.assert_equal(mantissa, 0.625)
np.testing.assert_equal(exponent, 3)
def test_rintk(self):
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 0.0))), 0)
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 5.0))), 5)
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 5.5))), 6)
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, 5.999))), 6)
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.0))), -5)
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.5))), -6)
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.999))), -6)
if __name__ == '__main__':
unittest.main()