mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
cleanup xlog2 and remove unneeded functions (#7446)
denormal_map still looks wrong but a lot cleaner
This commit is contained in:
@@ -9,14 +9,7 @@ TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float6
|
||||
def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
|
||||
"""replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
|
||||
return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
|
||||
# *** helper functions for double/quad precision arithmetics ***
|
||||
def dfadd2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx + yx, xy + yy
|
||||
def dfmul2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx * yx, xx * yy + xy * yx
|
||||
def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
|
||||
t = dx.reciprocal()
|
||||
qx = nx * t
|
||||
qy = (ny - qx * dy) * t
|
||||
return qx, qy
|
||||
|
||||
# *** helper functions for bit manipulation ***
|
||||
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
|
||||
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
|
||||
@@ -238,22 +231,17 @@ def xlog2(d:UOp) -> UOp:
|
||||
m = ldexp3k(d, -e)
|
||||
e = denormal_map.where(e + (-64), e)
|
||||
|
||||
x = (m - 1.0) / (m + 1.0)
|
||||
x2 = x * x
|
||||
if d.dtype == dtypes.float64:
|
||||
x = (m - 1.0) * (m + 1.0).reciprocal()
|
||||
x2 = x * x
|
||||
t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
|
||||
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
|
||||
s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const_like(0), *dfmul2_f2_f2_f2(t.const_like(2.885390081777926774), t.const_like(0), x, x.const_like(0)))
|
||||
r = t * (x * x2) + (s_hi + s_lo)
|
||||
s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
|
||||
else:
|
||||
xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const_like(-1), m.const_like(0), m, m.const_like(0)),
|
||||
*dfadd2_f2_f2_f2(m.const_like(1), m.const_like(0), m, m.const_like(0)))
|
||||
x2 = xx * xx
|
||||
t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
|
||||
sx, sy = dfadd2_f2_f2_f2(e, e.const_like(0),
|
||||
*dfmul2_f2_f2_f2(xx, xy, xx.const_like(2.8853900432586669922), xy.const_like(3.2734474483568488616e-08)))
|
||||
sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const_like(0), (x2 * xx) * t)
|
||||
r = sx + sy
|
||||
s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
|
||||
r = t * (x * x2) + (s_hi + s_lo)
|
||||
|
||||
# log2(Inf) = Inf
|
||||
r = d_orig.ne(math.inf).where(r, r.const_like(math.inf))
|
||||
# log2(x=-0.01) = NaN. where x < 0
|
||||
|
||||
Reference in New Issue
Block a user