diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 142f6b6311..5c0f753b67 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -9,14 +9,7 @@ TRANSCENDENTAL_SUPPORTED_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float6 def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp): """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio""" return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf) -# *** helper functions for double/quad precision arithmetics *** -def dfadd2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx + yx, xy + yy -def dfmul2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx * yx, xx * yy + xy * yx -def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]: - t = dx.reciprocal() - qx = nx * t - qy = (ny - qx * dy) * t - return qx, qy + # *** helper functions for bit manipulation *** def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1] def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d] @@ -238,22 +231,17 @@ def xlog2(d:UOp) -> UOp: m = ldexp3k(d, -e) e = denormal_map.where(e + (-64), e) + x = (m - 1.0) / (m + 1.0) + x2 = x * x if d.dtype == dtypes.float64: - x = (m - 1.0) * (m + 1.0).reciprocal() - x2 = x * x t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0, 0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449]) - s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const_like(0), *dfmul2_f2_f2_f2(t.const_like(2.885390081777926774), t.const_like(0), x, x.const_like(0))) - r = t * (x * x2) + (s_hi + s_lo) + s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0) else: - xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const_like(-1), m.const_like(0), m, m.const_like(0)), - *dfadd2_f2_f2_f2(m.const_like(1), m.const_like(0), m, m.const_like(0))) - x2 = xx * xx t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120]) - sx, sy = dfadd2_f2_f2_f2(e, e.const_like(0), - *dfmul2_f2_f2_f2(xx, xy, xx.const_like(2.8853900432586669922), xy.const_like(3.2734474483568488616e-08))) - sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const_like(0), (x2 * xx) * t) - r = sx + sy + s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08 + r = t * (x * x2) + (s_hi + s_lo) + # log2(Inf) = Inf r = d_orig.ne(math.inf).where(r, r.const_like(math.inf)) # log2(x=-0.01) = NaN. where x < 0