mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
more xlog2 cleanups (#7451)
following the notations in the paper closer
This commit is contained in:
@@ -215,19 +215,18 @@ def xexp2(d:UOp) -> UOp:
|
||||
def xlog2(d:UOp) -> UOp:
|
||||
"""
|
||||
Implements a 1.0 ULP approximation for UnaryOps.LOG2
|
||||
Paper: https://arxiv.org/pdf/2001.09258
|
||||
Paper: https://arxiv.org/pdf/2001.09258 5.5
|
||||
"""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
# TODO: float16 denormal need float32 to achieve precision
|
||||
if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
|
||||
FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
|
||||
d_orig = d
|
||||
denormal_map = d.lt(FLT_MIN)
|
||||
for _ in range(8): d = denormal_map.where(d * (2 ** 8), d)
|
||||
is_denormal = d.lt(FLT_MIN)
|
||||
a = is_denormal.where(d * (2 ** 64), d)
|
||||
|
||||
e = ilogb2k(d * (1.0 / 0.75)).cast(d.dtype)
|
||||
m = ldexp3k(d, -e)
|
||||
e = denormal_map.where(e + (-64), e)
|
||||
e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
|
||||
m = ldexp3k(a, -e)
|
||||
e = is_denormal.where(e - 64, e)
|
||||
|
||||
x = (m - 1.0) / (m + 1.0)
|
||||
x2 = x * x
|
||||
@@ -241,14 +240,14 @@ def xlog2(d:UOp) -> UOp:
|
||||
r = t * (x * x2) + (s_hi + s_lo)
|
||||
|
||||
# log2(Inf) = Inf
|
||||
r = d_orig.ne(math.inf).where(r, r.const_like(math.inf))
|
||||
# log2(x=-0.01) = NaN. where x < 0
|
||||
r = d_orig.lt(-0.0).where(r.const_like(math.nan), r)
|
||||
r = d.ne(math.inf).where(r, r.const_like(math.inf))
|
||||
# log2(x) = NaN for x < 0
|
||||
r = d.lt(-0.0).where(r.const_like(math.nan), r)
|
||||
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
|
||||
# log2_zero = the value of unmasked xlog2(0.0).
|
||||
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype]
|
||||
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
|
||||
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
|
||||
# log2(NaN) = NaN
|
||||
r = d_orig.ne(d_orig).where(r.const_like(math.nan), r)
|
||||
r = d.ne(d).where(r.const_like(math.nan), r)
|
||||
# log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
|
||||
return d_orig.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
|
||||
return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
|
||||
|
||||
Reference in New Issue
Block a user