update xlog2 fp16 decomp to not use fp32 (#13955)

This commit is contained in:
chenyu
2026-01-01 11:18:29 -05:00
committed by GitHub
parent ce84a23142
commit ed222070f7
2 changed files with 46 additions and 8 deletions

View File

@@ -101,6 +101,44 @@ class TestFromFuzzer(unittest.TestCase):
_test_value(0)
_test_value(0.0000009)
class TestFloat16Log2(unittest.TestCase):
"""Tests for native float16 log2 implementation (no float32 cast)"""
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
def test_float16_log2_basic(self):
# basic values
test_values = [1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0]
with Context(TRANSCENDENTAL=2):
for val in test_values:
result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
expected = np.log2(np.float16(val))
np.testing.assert_allclose(result, expected, rtol=1e-3, err_msg=f"log2({val})")
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan handling differs on Vulkan")
def test_float16_log2_special(self):
# special values: inf, -inf, nan, 0, negative
with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
# log2(inf) = inf
assert np.isinf(Tensor([np.inf], dtype=dtypes.float16).log2().numpy()[0])
# log2(0) = -inf
assert Tensor([0.0], dtype=dtypes.float16).log2().numpy()[0] == -np.inf
# log2(negative) = nan
assert np.isnan(Tensor([-1.0], dtype=dtypes.float16).log2().numpy()[0])
# log2(nan) = nan
assert np.isnan(Tensor([np.nan], dtype=dtypes.float16).log2().numpy()[0])
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
def test_float16_log2_denormal(self):
# test values near and below float16 min normal (6.1e-5)
# these exercise the denormal handling path with 2^10 scaling
test_values = [1e-4, 6e-5, 1e-5]
with Context(TRANSCENDENTAL=2):
for val in test_values:
result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
expected = np.log2(np.float16(val))
# denormals have lower precision due to float16 limitations
np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")
class TestTranscendentalSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_transcendental_sin_fusion(self):