lil transcendental folding cleanup [run_process_replay] (#5822)

* lil transcendental folding cleanup [run_process_replay]

* idk why function isn't Callable
This commit is contained in:
George Hotz
2024-07-30 14:10:17 -07:00
committed by GitHub
parent 693990a346
commit 3630208a01
2 changed files with 16 additions and 19 deletions

View File

@@ -227,35 +227,35 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
# sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result)
def xexp2(x:UOp) -> UOp:
def xexp2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.EXP2
- Paper: https://arxiv.org/pdf/2001.09258
"""
assert x.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
fp64_p = x.dtype == dtypes.float64
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
fp64_p = d.dtype == dtypes.float64
# mask +-inf/nan as zero.
d = _lazy_map_numbers(x, x.const(0.0), x.const(0.0), x.const(0.0), x)
q = rintk(d)
x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
q = rintk(x)
# s = d - round(d)
s = d - q.cast(d.dtype)
s = x - q.cast(x.dtype)
# a polynomial approximation with 13 non-zero terms in the range of [-(log 2)/2, (log 2)/2].
if fp64_p:
u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, 0.6931471805599452862e+0, 0.1000000000000000000e+1]) # noqa: E501
else:
u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
u = ldexp2k(u, q) # u*2^q
upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[d.dtype]
lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[d.dtype]
upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[x.dtype]
lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype]
# Replace x >= upper with +inf
u = d.ne(upper).where(u, d.const(math.inf))
u = d.lt(upper).where(u, d.const(math.inf))
u = x.ne(upper).where(u, x.const(math.inf))
u = x.lt(upper).where(u, x.const(math.inf))
# Replace x <= lower with zero.
u = d.lt(lower).where(d.const(0.0), u)
u = x.lt(lower).where(x.const(0.0), u)
# x=NaN never satisfies x < Inf. (for fastmode)
u = d.lt(math.inf).where(u, u.const(math.nan))
u = x.lt(math.inf).where(u, u.const(math.nan))
# exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
return _lazy_map_numbers(x, x.const(math.inf), x.const(0.0), x.const(math.nan), u)
return _lazy_map_numbers(d, d.const(math.inf), d.const(0.0), d.const(math.nan), u)
def xlog2(d:UOp) -> UOp:
"""

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict
from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, heapq, math
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
@@ -78,11 +78,8 @@ float4_folding = PatternMatcher([
# ***** transcendental *****
transcendental_folding = PatternMatcher([
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="x"),), arg=UnaryOps.EXP2), xexp2),
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.LOG2), xlog2),
(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.SIN), xsin),
])
transcendental_folding = PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=k), cast(Callable, v))
for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin))])
# ***** threefry *****