remove noqa: E501 from transcendental

This commit is contained in:
George Hotz
2024-08-10 10:30:05 -07:00
parent 350276e947
commit 8302dd6ea4

View File

@@ -188,7 +188,11 @@ def trig_poly(d:UOp, coeff32, coeff64):
u = polyN(s.const(coeff32[0]), s, coeff32[1:])
return mla(s, u * d, d)
# approximate sine on [-pi/2, pi/2]
def sin_poly(d:UOp) -> UOp: return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938], [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10, -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815, -0.166666666666666657414808]) # noqa: E501
def sin_poly(d:UOp) -> UOp:
return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938],
[-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10,
-2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815,
-0.166666666666666657414808])
def sin_poly_small(d:UOp, q:UOp) -> UOp:
def _ifand(n:int): return (q & n).ne(0)
@@ -200,7 +204,9 @@ def sin_poly_large(d:UOp, q:UOp) -> UOp:
d = d + _ifand(1).where(d.const(math.pi / 2), d.const(0))
r = sin_poly(d)
return r * _ifand(2).where(r.const(-1), r.const(1))
# *** toplevel functions for xsin/xlog2/xexp2 ***
def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.SIN.
@@ -241,7 +247,10 @@ def xexp2(d:UOp) -> UOp:
s = x - q.cast(x.dtype)
# a polynomial approximation with 13 non-zero terms in the range of [(log 2)/2,(log 2)/2].
if fp64_p:
u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, 0.6931471805599452862e+0, 0.1000000000000000000e+1]) # noqa: E501
u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5,
0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2,
0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0,
0.6931471805599452862e+0, 0.1000000000000000000e+1])
else:
u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
u = ldexp2k(u, q) # u*2^q
@@ -276,7 +285,8 @@ def xlog2(d:UOp) -> UOp:
if fp64_p:
x = (m - 1.0) * (m + 1.0).recip()
x2 = x * x
t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0, 0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449]) # noqa: E501
t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0)))
r = mla(t, x * x2, s_hi + s_lo)
else: