diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index b8dc78eb97..f83ba0d714 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -227,35 +227,35 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result) -def xexp2(x:UOp) -> UOp: +def xexp2(d:UOp) -> UOp: """ Implements a 1.0 ULP approximation for UnaryOps.EXP2 - Paper: https://arxiv.org/pdf/2001.09258 """ - assert x.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES - fp64_p = x.dtype == dtypes.float64 + assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES + fp64_p = d.dtype == dtypes.float64 # mask +=inf/nan as zero. - d = _lazy_map_numbers(x, x.const(0.0), x.const(0.0), x.const(0.0), x) - q = rintk(d) + x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d) + q = rintk(x) # s = d - round(d) - s = d - q.cast(d.dtype) + s = x - q.cast(x.dtype) # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2]. if fp64_p: u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, 0.6931471805599452862e+0, 0.1000000000000000000e+1]) # noqa: E501 else: u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1]) u = ldexp2k(u, q) # u*2^q - upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[d.dtype] - lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[d.dtype] + upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[x.dtype] + lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[x.dtype] # Replace x >= upper with +inf - u = d.ne(upper).where(u, d.const(math.inf)) - u = d.lt(upper).where(u, d.const(math.inf)) + u = x.ne(upper).where(u, x.const(math.inf)) + u = x.lt(upper).where(u, x.const(math.inf)) # Replace x <= lower with zero. - u = d.lt(lower).where(d.const(0.0), u) + u = x.lt(lower).where(x.const(0.0), u) # x=NaN never satisfies x < Inf. (for fastmode) - u = d.lt(math.inf).where(u, u.const(math.nan)) + u = x.lt(math.inf).where(u, u.const(math.nan)) # exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN - return _lazy_map_numbers(x, x.const(math.inf), x.const(0.0), x.const(math.nan), u) + return _lazy_map_numbers(d, d.const(math.inf), d.const(0.0), d.const(math.nan), u) def xlog2(d:UOp) -> UOp: """ diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index b394dc378c..46d7ba26ec 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict +from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable import functools, itertools, heapq, math from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType @@ -78,11 +78,8 @@ float4_folding = PatternMatcher([ # ***** transcendental ***** -transcendental_folding = PatternMatcher([ - (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="x"),), arg=UnaryOps.EXP2), xexp2), - (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.LOG2), xlog2), - (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=UnaryOps.SIN), xsin), -]) +transcendental_folding = PatternMatcher([(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat(name="d"),), arg=k), cast(Callable, v)) + for k,v in ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin))]) # ***** threefry *****