cleanup xlog2 and remove unneeded functions (#7446)

denormal_map still looks wrong but a lot cleaner
This commit is contained in:
chenyu
2024-10-31 09:45:16 -04:00
committed by GitHub
parent 02636bc05e
commit 5085b2fde7

View File

@@ -9,14 +9,7 @@ TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float6
def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
"""replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
# *** helper functions for double/quad precision arithmetics ***
def dfadd2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx + yx, xy + yy
def dfmul2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx * yx, xx * yy + xy * yx
def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
t = dx.reciprocal()
qx = nx * t
qy = (ny - qx * dy) * t
return qx, qy
# *** helper functions for bit manipulation ***
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
@@ -238,22 +231,17 @@ def xlog2(d:UOp) -> UOp:
m = ldexp3k(d, -e)
e = denormal_map.where(e + (-64), e)
x = (m - 1.0) / (m + 1.0)
x2 = x * x
if d.dtype == dtypes.float64:
x = (m - 1.0) * (m + 1.0).reciprocal()
x2 = x * x
t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const_like(0), *dfmul2_f2_f2_f2(t.const_like(2.885390081777926774), t.const_like(0), x, x.const_like(0)))
r = t * (x * x2) + (s_hi + s_lo)
s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
else:
xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const_like(-1), m.const_like(0), m, m.const_like(0)),
*dfadd2_f2_f2_f2(m.const_like(1), m.const_like(0), m, m.const_like(0)))
x2 = xx * xx
t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
sx, sy = dfadd2_f2_f2_f2(e, e.const_like(0),
*dfmul2_f2_f2_f2(xx, xy, xx.const_like(2.8853900432586669922), xy.const_like(3.2734474483568488616e-08)))
sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const_like(0), (x2 * xx) * t)
r = sx + sy
s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
r = t * (x * x2) + (s_hi + s_lo)
# log2(Inf) = Inf
r = d_orig.ne(math.inf).where(r, r.const_like(math.inf))
# log2(x=-0.01) = NaN. where x < 0