diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py
index e4229d718c..b8dc78eb97 100644
--- a/tinygrad/codegen/transcendental.py
+++ b/tinygrad/codegen/transcendental.py
@@ -1,4 +1,4 @@
-import math
+import math, functools
 from typing import Tuple, List
 from tinygrad.dtype import dtypes, DType
 from tinygrad.codegen.uops import UOp
@@ -17,17 +17,9 @@ def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
   qy = (ny - qx * dy) * t
   return qx, qy
 # *** helper functions for bit manipulation ***
-def significand_bits(d:DType) -> int:
-  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
-  return {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}[d]
-
-def exponent_bias(d:DType) -> int:
-  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
-  return {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}[d]
-
-def exponent_mask(d:DType) -> int:
-  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
-  return {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}[d]
+def significand_bits(d:DType) -> int: return {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}[d]
+def exponent_bias(d:DType) -> int: return {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}[d]
+def exponent_mask(d:DType) -> int: return {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}[d]
 
 def float_to_bits(d:UOp) -> UOp:
   assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -96,9 +88,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
 
 def mla(x:UOp, y:UOp, z:UOp) -> UOp: return x * y + z
 
-def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp:
-  for c in coeffs: u = mla(u, s, u.const(c))
-  return u
+def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp: return functools.reduce(lambda u,c: mla(u, s, u.const(c)), coeffs, u)
 # *** reduction algorithms for sine ***
 def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
   """
@@ -148,16 +138,10 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
   q = shr(p, 62).cast(dtypes.int32)
   p = p & 0x3fffffffffffffff
+  r = (p.cast(dtype_via) * (3.4061215800865545e-19)).cast(input_dtype)
 
-  d = p.cast(dtype_via)
-  d = d * (3.4061215800865545e-19)
-  r = d.cast(input_dtype)
-
-  fraction_map = f.lt(0.5) # if fraction >= 0.5, r -= pi/2, q += 1
-  r = fraction_map.where(r, r + r.const(-math.pi / 2))
-  q = fraction_map.where(q, q + 1)
-  return r, q
+  return f.lt(0.5).where(r, r + r.const(-math.pi / 2)), f.lt(0.5).where(q, q + 1)
 
 def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
   """
@@ -231,8 +215,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
   x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
   x_abs = x * x_sign
   r, q = reduction_algo(x_abs)
-  if fast:
-    result = sin_poly_small(r, q)
+  if fast: result = sin_poly_small(r, q)
   else:
     # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
     switch_over_map = x_abs.lt(switch_over)
@@ -314,5 +297,4 @@ def xlog2(d:UOp) -> UOp:
   # log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
   r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
   # log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
-  r = d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
-  return r
+  return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
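
Aside (not part of the patch): a minimal standalone sketch of the polyN change above. Both the old loop and the new functools.reduce form evaluate a polynomial in Horner form; plain floats stand in for UOp here, and polyN_loop/polyN_reduce are hypothetical names used only for this comparison.

import functools

def mla(x, y, z): return x * y + z  # multiply-add, mirroring the file's mla helper

def polyN_loop(u, s, coeffs):  # pre-patch shape: explicit Horner loop
  for c in coeffs: u = mla(u, s, c)
  return u

# post-patch shape: the same Horner evaluation written as a left fold
def polyN_reduce(u, s, coeffs): return functools.reduce(lambda u, c: mla(u, s, c), coeffs, u)

# both compute (2.0*0.5 + 1.0)*0.5 + 3.0 = 4.0
assert polyN_loop(2.0, 0.5, [1.0, 3.0]) == polyN_reduce(2.0, 0.5, [1.0, 3.0]) == 4.0

The two forms apply mla in the same order with the same operands, so the rewrite changes style, not behavior.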