more xlog2 cleanups (#7451)

following the notations in the paper closer
This commit is contained in:
chenyu
2024-10-31 13:52:31 -04:00
committed by GitHub
parent 4065c3dec8
commit 5648b9788e

View File

@@ -215,19 +215,18 @@ def xexp2(d:UOp) -> UOp:
def xlog2(d:UOp) -> UOp:
"""
Implements a 1.0 ULP approximation for UnaryOps.LOG2
Paper: https://arxiv.org/pdf/2001.09258
Paper: https://arxiv.org/pdf/2001.09258 5.5
"""
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
# TODO: float16 denormal need float32 to achieve precision
if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
d_orig = d
denormal_map = d.lt(FLT_MIN)
for _ in range(8): d = denormal_map.where(d * (2 ** 8), d)
is_denormal = d.lt(FLT_MIN)
a = is_denormal.where(d * (2 ** 64), d)
e = ilogb2k(d * (1.0 / 0.75)).cast(d.dtype)
m = ldexp3k(d, -e)
e = denormal_map.where(e + (-64), e)
e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
m = ldexp3k(a, -e)
e = is_denormal.where(e - 64, e)
x = (m - 1.0) / (m + 1.0)
x2 = x * x
@@ -241,14 +240,14 @@ def xlog2(d:UOp) -> UOp:
r = t * (x * x2) + (s_hi + s_lo)
# log2(Inf) = Inf
r = d_orig.ne(math.inf).where(r, r.const_like(math.inf))
# log2(x=-0.01) = NaN. where x < 0
r = d_orig.lt(-0.0).where(r.const_like(math.nan), r)
r = d.ne(math.inf).where(r, r.const_like(math.inf))
# log2(x) = NaN for x < 0
r = d.lt(-0.0).where(r.const_like(math.nan), r)
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
# log2_zero = the value of unmasked xlog2(0.0).
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype]
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
# log2(NaN) = NaN
r = d_orig.ne(d_orig).where(r.const_like(math.nan), r)
r = d.ne(d).where(r.const_like(math.nan), r)
# log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
return d_orig.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))