From 5648b9788ef9568b7aca80c930aae120f63d9b85 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 31 Oct 2024 13:52:31 -0400 Subject: [PATCH] more xlog2 cleanups (#7451) following the notations in the paper closer --- tinygrad/codegen/transcendental.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 22702a20dc..bc4bdd6a57 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -215,19 +215,18 @@ def xexp2(d:UOp) -> UOp: def xlog2(d:UOp) -> UOp: """ Implements a 1.0 ULP approximation for UnaryOps.LOG2 - Paper: https://arxiv.org/pdf/2001.09258 + Paper: https://arxiv.org/pdf/2001.09258 5.5 """ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES # TODO: float16 denormal need float32 to achieve precision if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16) FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4) - d_orig = d - denormal_map = d.lt(FLT_MIN) - for _ in range(8): d = denormal_map.where(d * (2 ** 8), d) + is_denormal = d.lt(FLT_MIN) + a = is_denormal.where(d * (2 ** 64), d) - e = ilogb2k(d * (1.0 / 0.75)).cast(d.dtype) - m = ldexp3k(d, -e) - e = denormal_map.where(e + (-64), e) + e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype) + m = ldexp3k(a, -e) + e = is_denormal.where(e - 64, e) x = (m - 1.0) / (m + 1.0) x2 = x * x @@ -241,14 +240,14 @@ def xlog2(d:UOp) -> UOp: r = t * (x * x2) + (s_hi + s_lo) # log2(Inf) = Inf - r = d_orig.ne(math.inf).where(r, r.const_like(math.inf)) - # log2(x=-0.01) = NaN. where x < 0 - r = d_orig.lt(-0.0).where(r.const_like(math.nan), r) + r = d.ne(math.inf).where(r, r.const_like(math.inf)) + # log2(x) = NaN for x < 0 + r = d.lt(-0.0).where(r.const_like(math.nan), r) # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true. # log2_zero = the value of unmasked xlog2(0.0). - log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype] + log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype] r = r.ne(log2_zero).where(r, r.const_like(-math.inf)) # log2(NaN) = NaN - r = d_orig.ne(d_orig).where(r.const_like(math.nan), r) + r = d.ne(d).where(r.const_like(math.nan), r) # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal. - return d_orig.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf)) + return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))