shave some lines from transcend math [run_process_replay] (#5500)

* shave some lines from transcend math [run_process_replay]

* put input_dtype back
This commit is contained in:
chenyu
2024-07-15 21:02:24 -04:00
committed by GitHub
parent 63990705b5
commit fd43d33b7d

View File

@@ -1,4 +1,4 @@
import math
import math, functools
from typing import Tuple, List
from tinygrad.dtype import dtypes, DType
from tinygrad.codegen.uops import UOp
@@ -17,17 +17,9 @@ def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
qy = (ny - qx * dy) * t
return qx, qy
# *** helper functions for bit manipulation ***
def significand_bits(d:DType) -> int:
  """Number of significand (mantissa) bits in the IEEE 754 layout of dtype d."""
  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
  bits = {dtypes.float16: 10, dtypes.float32: 23, dtypes.float64: 52}
  return bits[d]
def exponent_bias(d:DType) -> int:
  """Exponent bias used by these kernels for dtype d.
  NOTE(review): values are 1022/126/14 — one less than the conventional IEEE bias; presumably intentional for frexp-style normalization."""
  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
  bias = {dtypes.float16: 14, dtypes.float32: 126, dtypes.float64: 1022}
  return bias[d]
def exponent_mask(d:DType) -> int:
  """Bitmask covering the exponent field of dtype d (after shifting past the significand)."""
  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
  mask = {dtypes.float16: 0x1F, dtypes.float32: 0xFF, dtypes.float64: 0x7FF}
  return mask[d]
def significand_bits(d:DType) -> int:
  """Number of significand (mantissa) bits in the IEEE 754 layout of dtype d (KeyError on unsupported dtypes)."""
  return {dtypes.float16: 10, dtypes.float32: 23, dtypes.float64: 52}[d]
def exponent_bias(d:DType) -> int:
  """Exponent bias used by these kernels for dtype d.
  NOTE(review): 14/126/1022 is one below the conventional IEEE bias — presumably intentional; confirm against frexp usage."""
  return {dtypes.float16: 14, dtypes.float32: 126, dtypes.float64: 1022}[d]
def exponent_mask(d:DType) -> int:
  """Bitmask covering the exponent field of dtype d (KeyError on unsupported dtypes)."""
  return {dtypes.float16: 0x1F, dtypes.float32: 0xFF, dtypes.float64: 0x7FF}[d]
def float_to_bits(d:UOp) -> UOp:
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -96,9 +88,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
def mla(x:UOp, y:UOp, z:UOp) -> UOp:
  """Multiply-add: x*y + z, built from a separate mul and add on UOps."""
  product = x * y
  return product + z
def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp:
  """Horner-style polynomial accumulation: fold each coefficient c into u as u = u*s + c."""
  for coeff in coeffs:
    u = mla(u, s, u.const(coeff))
  return u
def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp:
  """Horner-style polynomial accumulation: repeatedly acc = acc*s + c for each c in coeffs, starting from u."""
  return functools.reduce(lambda acc, c: mla(acc, s, acc.const(c)), coeffs, u)
# *** reduction algorithms for sine ***
def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
"""
@@ -148,16 +138,10 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
q = shr(p, 62).cast(dtypes.int32)
p = p & 0x3fffffffffffffff
r = (p.cast(dtype_via) * (3.4061215800865545e-19)).cast(input_dtype)
d = p.cast(dtype_via)
d = d * (3.4061215800865545e-19)
r = d.cast(input_dtype)
fraction_map = f.lt(0.5)
# if fraction >= 0.5, r -= pi/2, q += 1
r = fraction_map.where(r, r + r.const(-math.pi / 2))
q = fraction_map.where(q, q + 1)
return r, q
return f.lt(0.5).where(r, r + r.const(-math.pi / 2)), f.lt(0.5).where(q, q + 1)
def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
"""
@@ -231,8 +215,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
x_abs = x * x_sign
r, q = reduction_algo(x_abs)
if fast:
result = sin_poly_small(r, q)
if fast: result = sin_poly_small(r, q)
else:
# Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
switch_over_map = x_abs.lt(switch_over)
@@ -314,5 +297,4 @@ def xlog2(d:UOp) -> UOp:
# log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
# log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
r = d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
return r
return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))