mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
clean up cody_waite_reduction magic numbers (#7452)
This commit is contained in:
@@ -119,28 +119,33 @@ def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
0 <= abs(d) <= 39800.0
|
||||
Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
|
||||
"""
|
||||
m_1_pi = 0.318309886183790671537767526745028724
|
||||
qdh = (d * (m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype) * 16777216.0
|
||||
def _quadrant(x:UOp) -> UOp: return rintk(d * m_1_pi -qdh).cast(x.dtype) if x.dtype == dtypes.float64 else rintk(x * m_1_pi).cast(x.dtype)
|
||||
def _reduce_d(x:UOp, q:UOp):
|
||||
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefdp.c#L789-L823
|
||||
if x.dtype == dtypes.float64:
|
||||
d = qdh * -3.1415926218032836914 + x
|
||||
d = q * -3.1415926218032836914 + d
|
||||
d = qdh * -3.1786509424591713469e-08 + d
|
||||
d = q * -3.1786509424591713469e-08 + d
|
||||
d = qdh * -1.2246467864107188502e-16 + d
|
||||
d = q * -1.2246467864107188502e-16 + d
|
||||
d = (qdh + q) * -1.2736634327021899816e-24 + d
|
||||
# https://github.com/shibatch/sleef/blob/f6d8a841fbfddd26ce712834d4da220cd76048fb/src/common/misc.h#L77
|
||||
PI_A, PI_B, PI_C, PI_D = 3.1415926218032836914, 3.1786509424591713469e-08, 1.2246467864107188502e-16, 1.2736634327021899816e-24
|
||||
d = qdh * -PI_A + x
|
||||
d = q * -PI_A + d
|
||||
d = qdh * -PI_B + d
|
||||
d = q * -PI_B + d
|
||||
d = qdh * -PI_C + d
|
||||
d = q * -PI_C + d
|
||||
d = (qdh + q) * -PI_D + d
|
||||
elif x.dtype == dtypes.float16:
|
||||
# [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
|
||||
d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
|
||||
else:
|
||||
# https://github.com/shibatch/sleef/blob/4e08851f59fc2b545f9c393c6a23dfd311a26308/src/libm/sleefsp.c#L464-L503
|
||||
d = q * -3.1414794921875 + x
|
||||
d = q * -0.00011315941810607910156 + d
|
||||
d = q * -1.9841872589410058936e-09 + d
|
||||
d = q * -1.2154201256553420762e-10 + d
|
||||
return d
|
||||
return _reduce_d(d, (q := _quadrant(d))), q.cast(dtypes.int32)
|
||||
|
||||
m_1_pi = 0.318309886183790671537767526745028724
|
||||
qdh = (d * (m_1_pi / 2.0**24)).cast(dtypes.int64).cast(d.dtype) * (2.0**24)
|
||||
quadrant = rintk(d * m_1_pi -qdh) if d.dtype == dtypes.float64 else rintk(d * m_1_pi)
|
||||
return _reduce_d(d, quadrant.cast(d.dtype)), quadrant.cast(dtypes.int32)
|
||||
|
||||
# *** approximate sine on small angle. ***
|
||||
def trig_poly(d:UOp, coeff32, coeff64):
  """
  Evaluate the odd polynomial `d * polyN(d*d, coeffs)`.

  `coeffs` is `coeff64` when `d.dtype` is `dtypes.float64`, and `coeff32` for every other dtype.
  """
  # pick the coefficient table once, then evaluate the polynomial in d^2 and scale by d
  coeffs = coeff64 if d.dtype == dtypes.float64 else coeff32
  return d * polyN(d*d, coeffs)
|
||||
|
||||
Reference in New Issue
Block a user