update xlog2 fp16 decomp to not use fp32 (#13955)

chenyu
2026-01-01 11:18:29 -05:00
committed by GitHub
parent ce84a23142
commit ed222070f7
2 changed files with 46 additions and 8 deletions
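
The change keeps the float16 log2 decomposition in float16 end to end instead of casting through float32; the denormal scaling is the part that has to change for that to work. Below is a minimal numpy sketch (my own, not part of the diff) of why the float16 scale has to be 2**10 rather than 2**64:

import math
import numpy as np

# 2**64 is far above the largest finite float16 (65504), so scaling a denormal
# input by 2**64 would overflow once the intermediate float32 cast is removed.
assert 2.0**64 > float(np.finfo(np.float16).max)
# 2**10 = 1024 is exactly representable in float16 and lifts every float16
# denormal (smallest is 2**-24) up to at least the min normal 2**-14.
assert np.float16(2.0**10) == 1024.0
# log2 is still recovered exactly from the scaled value: log2(d * 2**k) = log2(d) + k
d = float(np.float16(1e-5))   # below the float16 min normal (~6.1e-5)
assert abs((math.log2(d * 2.0**10) - 10) - math.log2(d)) < 1e-12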


@@ -101,6 +101,44 @@ class TestFromFuzzer(unittest.TestCase):
     _test_value(0)
     _test_value(0.0000009)
 
+class TestFloat16Log2(unittest.TestCase):
+  """Tests for native float16 log2 implementation (no float32 cast)"""
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  def test_float16_log2_basic(self):
+    # basic values
+    test_values = [1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0]
+    with Context(TRANSCENDENTAL=2):
+      for val in test_values:
+        result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
+        expected = np.log2(np.float16(val))
+        np.testing.assert_allclose(result, expected, rtol=1e-3, err_msg=f"log2({val})")
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan handling differs on Vulkan")
+  def test_float16_log2_special(self):
+    # special values: inf, -inf, nan, 0, negative
+    with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
+      # log2(inf) = inf
+      assert np.isinf(Tensor([np.inf], dtype=dtypes.float16).log2().numpy()[0])
+      # log2(0) = -inf
+      assert Tensor([0.0], dtype=dtypes.float16).log2().numpy()[0] == -np.inf
+      # log2(negative) = nan
+      assert np.isnan(Tensor([-1.0], dtype=dtypes.float16).log2().numpy()[0])
+      # log2(nan) = nan
+      assert np.isnan(Tensor([np.nan], dtype=dtypes.float16).log2().numpy()[0])
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  def test_float16_log2_denormal(self):
+    # test values near and below float16 min normal (6.1e-5)
+    # these exercise the denormal handling path with 2^10 scaling
+    test_values = [1e-4, 6e-5, 1e-5]
+    with Context(TRANSCENDENTAL=2):
+      for val in test_values:
+        result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
+        expected = np.log2(np.float16(val))
+        # denormals have lower precision due to float16 limitations
+        np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")
+
 class TestTranscendentalSchedule(unittest.TestCase):
   @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
   def test_transcendental_sin_fusion(self):
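
These tests all run under Context(TRANSCENDENTAL=2). As I read that flag (an assumption, the diff itself does not spell it out), it forces the polynomial decomposition even on backends that have a native log2 instruction, so the float16 path changed below is what actually executes. A quick usage sketch:

from tinygrad import Tensor, dtypes
from tinygrad.helpers import Context

# force the software decomposition of log2, as the new tests do
with Context(TRANSCENDENTAL=2):
  print(Tensor([0.5, 2.0, 6e-5, 1e-5], dtype=dtypes.float16).log2().numpy())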


@@ -223,26 +223,26 @@ def xlog2(d:UOp) -> UOp:
   Paper: https://arxiv.org/pdf/2001.09258 5.5
   """
   assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
-  # TODO: float16 denormal need float32 to achieve precision
-  if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
-  FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
+  # float16 uses 2^10 for denormal scaling (2^64 overflows), float32/64 use 2^64
+  denormal_exp = 10 if d.dtype.scalar() == dtypes.float16 else 64
+  FLT_MIN = d.const_like({dtypes.float16: 6.1e-5, dtypes.float32: 1e-4, dtypes.float64: 1e-4}[d.dtype.scalar()])
   is_denormal = d<FLT_MIN
-  a = is_denormal.where(d * (2 ** 64), d)
+  a = is_denormal.where(d * (2 ** denormal_exp), d)
   e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
   m = ldexp3k(a, -e)
-  e = is_denormal.where(e - 64, e)
+  e = is_denormal.where(e - denormal_exp, e)
   x = (m - 1.0) / (m + 1.0)
   x2 = x * x
   if d.dtype.scalar() == dtypes.float64:
     t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
                    0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
-    s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
+    r = t * (x * x2) + e + x * 2.885390081777926774
   else:
     t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
-    s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
-  r = t * (x * x2) + (s_hi + s_lo)
+    # s_lo term (x*3.27e-08) only for float32 - underflows in float16
+    r = t * (x * x2) + e + x * 2.8853900432586669922 + (x * 3.2734474483568488616e-08 if d.dtype.scalar() == dtypes.float32 else 0)
   # log2(Inf) = Inf
   r = d.ne(math.inf).where(r, r.const_like(math.inf))
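
For reference, here is a plain-Python mock-up (my own sketch, not tinygrad code) of the reduction the float32/float16 branch performs, with math.frexp/math.ldexp standing in for ilogb2k/ldexp3k: scale a small input up by 2**denormal_exp, split the result into 2**e * m with m in [0.75, 1.5), evaluate log2(m) through x = (m - 1)/(m + 1), then subtract denormal_exp back out of e.

import math
import numpy as np

def log2_decomp(d: float, denormal_exp: int = 64, flt_min: float = 1e-4) -> float:
  is_denormal = d < flt_min
  a = d * 2.0**denormal_exp if is_denormal else d   # scale small inputs into the normal range
  e = math.frexp(a * (1.0 / 0.75))[1] - 1           # ilogb2k stand-in: floor(log2(a / 0.75))
  m = math.ldexp(a, -e)                             # ldexp3k stand-in: m = a / 2**e, m in [0.75, 1.5)
  if is_denormal: e -= denormal_exp                 # undo the scaling in the exponent
  x = (m - 1.0) / (m + 1.0)                         # log2(m) = (2/ln 2) * atanh(x)
  x2 = x * x
  t = 0.4374550283 * x2 * x2 + 0.5764790177 * x2 + 0.9618012905120   # polyN(x2, [...]) from the diff
  return t * (x * x2) + e + x * 2.8853900432586669922 + x * 3.2734474483568488616e-08

for v in [0.5, 1.0, 3.7, 1024.0, 6e-5, 1e-5]:
  np.testing.assert_allclose(log2_decomp(v), np.log2(v), rtol=1e-6, atol=1e-6)

With denormal_exp=10 and FLT_MIN=6.1e-5 the same reduction stays within float16 range, which is what the new float16 branch relies on; the dropped x*3.27e-08 term is below float16 resolution anyway, as the comment in the diff notes.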