clean up cody_waite_reduction magic numbers (#7452)

This commit is contained in:
chenyu
2024-10-31 14:45:04 -04:00
committed by GitHub
parent 5648b9788e
commit 5777fca904

View File

@@ -119,28 +119,33 @@ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
0 <= abs(d) <= 39800.0
Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
"""
m_1_pi = 0.318309886183790671537767526745028724
qdh = (d * (m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype) * 16777216.0
def _quadrant(x:UOp) -> UOp: return rintk(d * m_1_pi -qdh).cast(x.dtype) if x.dtype == dtypes.float64 else rintk(x * m_1_pi).cast(x.dtype)
def _reduce_d(x:UOp, q:UOp):
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
if x.dtype == dtypes.float64:
d = qdh * -3.1415926218032836914 + x
d = q * -3.1415926218032836914 + d
d = qdh * -3.1786509424591713469e-08 + d
d = q * -3.1786509424591713469e-08 + d
d = qdh * -1.2246467864107188502e-16 + d
d = q * -1.2246467864107188502e-16 + d
d = (qdh + q) * -1.2736634327021899816e-24 + d
# https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
d = qdh * -PI_A + x
d = q * -PI_A + d
d = qdh * -PI_B + d
d = q * -PI_B + d
d = qdh * -PI_C + d
d = q * -PI_C + d
d = (qdh + q) * -PI_D + d
elif x.dtype == dtypes.float16:
# [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
else:
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
d = q * -3.1414794921875 + x
d = q * -0.00011315941810607910156 + d
d = q * -1.9841872589410058936e-09 + d
d = q * -1.2154201256553420762e-10 + d
return d
return _reduce_d(d, (q := _quadrant(d))), q.cast(dtypes.int32)
m_1_pi = 0.318309886183790671537767526745028724
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
quadrant = rintk(d * m_1_pi -qdh) if d.dtype == dtypes.float64 else rintk(d * m_1_pi)
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
# *** approximate sine on small angle. ***
def trig_poly(d:UOp, coeff32, coeff64):
  """
  Return `d * polyN(d*d, coeffs)`.

  `coeffs` is `coeff64` when `d.dtype` is `dtypes.float64` and `coeff32` for every other dtype.
  """
  # pick the coefficient table first so polyN is called from a single place
  coeffs = coeff64 if d.dtype == dtypes.float64 else coeff32
  return d * polyN(d*d, coeffs)