mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
lil transcendental folding cleanup [run_process_replay] (#5822)
* lil transcendental folding cleanup [run_process_replay] * idk why function isn't Callable
This commit is contained in:
@@ -227,35 +227,35 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
|
||||
# sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
|
||||
return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result)
|
||||
|
||||
def xexp2(x:UOp) -> UOp:
|
||||
def xexp2(d:UOp) -> UOp:
|
||||
"""
|
||||
Implements a 1.0 ULP approximation for UnaryOps.EXP2
|
||||
- Paper: https://arxiv.org/pdf/2001.09258
|
||||
"""
|
||||
assert x.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
fp64_p = x.dtype == dtypes.float64
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
fp64_p = d.dtype == dtypes.float64
|
||||
# mask +=inf/nan as zero.
|
||||
d = _lazy_map_numbers(x, x.const(0.0), x.const(0.0), x.const(0.0), x)
|
||||
q = rintk(d)
|
||||
x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
|
||||
q = rintk(x)
|
||||
# s = d - round(d)
|
||||
s = d - q.cast(d.dtype)
|
||||
s = x - q.cast(x.dtype)
|
||||
# a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
|
||||
if fp64_p:
|
||||
u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, 0.6931471805599452862e+0, 0.1000000000000000000e+1]) # noqa: E501
|
||||
else:
|
||||
u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
|
||||
u = ldexp2k(u, q) # u*2^q
|
||||
upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[d.dtype]
|
||||
lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[d.dtype]
|
||||
upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[x.dtype]
|
||||
lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype]
|
||||
# Replace x >= upper with +inf
|
||||
u = d.ne(upper).where(u, d.const(math.inf))
|
||||
u = d.lt(upper).where(u, d.const(math.inf))
|
||||
u = x.ne(upper).where(u, x.const(math.inf))
|
||||
u = x.lt(upper).where(u, x.const(math.inf))
|
||||
# Replace x <= lower with zero.
|
||||
u = d.lt(lower).where(d.const(0.0), u)
|
||||
u = x.lt(lower).where(x.const(0.0), u)
|
||||
# x=NaN never satisfies x < Inf. (for fastmode)
|
||||
u = d.lt(math.inf).where(u, u.const(math.nan))
|
||||
u = x.lt(math.inf).where(u, u.const(math.nan))
|
||||
# exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
|
||||
return _lazy_map_numbers(x, x.const(math.inf), x.const(0.0), x.const(math.nan), u)
|
||||
return _lazy_map_numbers(d, d.const(math.inf), d.const(0.0), d.const(math.nan), u)
|
||||
|
||||
def xlog2(d:UOp) -> UOp:
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict
|
||||
from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
||||
import functools, itertools, heapq, math
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
||||
@@ -78,11 +78,8 @@ float4_folding = PatternMatcher([
|
||||
|
||||
# ***** transcendental *****
|
||||
|
||||
transcendental_folding = PatternMatcher([
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="x"),), arg=UnaryOps.EXP2), xexp2),
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.LOG2), xlog2),
|
||||
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.SIN), xsin),
|
||||
])
|
||||
transcendental_folding = PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=k), cast(Callable, v))
|
||||
for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin))])
|
||||
|
||||
# ***** threefry *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user