mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
shave some lines from transcend math [run_process_replay] (#5500)
* shave some lines from transcend math [run_process_replay]
* put input_dtype back
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import math
|
||||
import math, functools
|
||||
from typing import Tuple, List
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.codegen.uops import UOp
|
||||
@@ -17,17 +17,9 @@ def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
|
||||
qy = (ny - qx * dy) * t
|
||||
return qx, qy
|
||||
# *** helper functions for bit manipulation ***
|
||||
def significand_bits(d:DType) -> int:
  """Return the number of explicitly stored mantissa (significand) bits for dtype d.

  Values follow the IEEE-754 binary64/binary32/binary16 layouts.
  """
  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
  bits = {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}
  return bits[d]
|
||||
|
||||
def exponent_bias(d:DType) -> int:
  """Return the exponent bias used by this code for dtype d.

  NOTE(review): these are one less than the standard IEEE-754 biases
  (1023/127/15) — presumably intentional for the frexp-style helpers; confirm.
  """
  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
  bias = {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}
  return bias[d]
|
||||
|
||||
def exponent_mask(d:DType) -> int:
  """Return the all-ones bitmask covering the exponent field of dtype d."""
  assert d in TRANSCENDENTAL_SUPPORTED_DTYPES
  mask = {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}
  return mask[d]
|
||||
def significand_bits(d:DType) -> int:
  """Number of stored mantissa bits per the IEEE-754 layout of dtype d."""
  return {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}[d]
|
||||
def exponent_bias(d:DType) -> int:
  """Exponent bias used here for dtype d (one below the standard IEEE bias)."""
  return {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}[d]
|
||||
def exponent_mask(d:DType) -> int:
  """All-ones bitmask for the exponent field of dtype d."""
  return {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}[d]
|
||||
|
||||
def float_to_bits(d:UOp) -> UOp:
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
@@ -96,9 +88,7 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]:
|
||||
|
||||
def mla(x:UOp, y:UOp, z:UOp) -> UOp:
  """Multiply-add: build the UOp expression x*y + z."""
  product = x * y
  return product + z
|
||||
|
||||
def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp:
  """Evaluate a polynomial in s by Horner's rule, starting from accumulator u.

  Each step computes acc = acc*s + coeff via mla.
  """
  acc = u
  for coeff in coeffs:
    acc = mla(acc, s, acc.const(coeff))
  return acc
|
||||
def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp:
  """Horner-rule polynomial evaluation: fold each coefficient into u in order."""
  for c in coeffs:
    u = mla(u, s, u.const(c))
  return u
|
||||
# *** reduction algorithms for sine ***
|
||||
def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
"""
|
||||
@@ -148,16 +138,10 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
|
||||
q = shr(p, 62).cast(dtypes.int32)
|
||||
p = p & 0x3fffffffffffffff
|
||||
r = (p.cast(dtype_via) * (3.4061215800865545e-19)).cast(input_dtype)
|
||||
|
||||
d = p.cast(dtype_via)
|
||||
d = d * (3.4061215800865545e-19)
|
||||
r = d.cast(input_dtype)
|
||||
|
||||
fraction_map = f.lt(0.5)
|
||||
# if fraction >= 0.5, r -= pi/2, q += 1
|
||||
r = fraction_map.where(r, r + r.const(-math.pi / 2))
|
||||
q = fraction_map.where(q, q + 1)
|
||||
return r, q
|
||||
return f.lt(0.5).where(r, r + r.const(-math.pi / 2)), f.lt(0.5).where(q, q + 1)
|
||||
|
||||
def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
"""
|
||||
@@ -231,8 +215,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
|
||||
x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
|
||||
x_abs = x * x_sign
|
||||
r, q = reduction_algo(x_abs)
|
||||
if fast:
|
||||
result = sin_poly_small(r, q)
|
||||
if fast: result = sin_poly_small(r, q)
|
||||
else:
|
||||
# Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
|
||||
switch_over_map = x_abs.lt(switch_over)
|
||||
@@ -314,5 +297,4 @@ def xlog2(d:UOp) -> UOp:
|
||||
# log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
|
||||
r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
|
||||
# log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
|
||||
r = d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
|
||||
return r
|
||||
return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))
|
||||
|
||||
Reference in New Issue
Block a user