diff --git a/test/test_transcendental.py b/test/test_transcendental.py
index 503286773c..92b91a819c 100644
--- a/test/test_transcendental.py
+++ b/test/test_transcendental.py
@@ -101,6 +101,44 @@ class TestFromFuzzer(unittest.TestCase):
     _test_value(0)
     _test_value(0.0000009)
 
+class TestFloat16Log2(unittest.TestCase):
+  """Tests for native float16 log2 implementation (no float32 cast)"""
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  def test_float16_log2_basic(self):
+    # basic values
+    test_values = [1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0]
+    with Context(TRANSCENDENTAL=2):
+      for val in test_values:
+        result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
+        expected = np.log2(np.float16(val))
+        np.testing.assert_allclose(result, expected, rtol=1e-3, err_msg=f"log2({val})")
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "NaN handling differs on Vulkan")
+  def test_float16_log2_special(self):
+    # special values: inf, -inf, nan, 0, negative
+    with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
+      # log2(inf) = inf
+      assert np.isinf(Tensor([np.inf], dtype=dtypes.float16).log2().numpy()[0])
+      # log2(0) = -inf
+      assert Tensor([0.0], dtype=dtypes.float16).log2().numpy()[0] == -np.inf
+      # log2(negative) = nan
+      assert np.isnan(Tensor([-1.0], dtype=dtypes.float16).log2().numpy()[0])
+      # log2(nan) = nan
+      assert np.isnan(Tensor([np.nan], dtype=dtypes.float16).log2().numpy()[0])
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  def test_float16_log2_denormal(self):
+    # test values near and below float16 min normal (6.1e-5)
+    # these exercise the denormal handling path with 2^10 scaling
+    test_values = [1e-4, 6e-5, 1e-5]
+    with Context(TRANSCENDENTAL=2):
+      for val in test_values:
+        result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
+        expected = np.log2(np.float16(val))
+        # denormals have lower precision due to float16 limitations
+        np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")
+
 class TestTranscendentalSchedule(unittest.TestCase):
   @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
   def test_transcendental_sin_fusion(self):
diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py
index 82dfe67316..f37d01bd58 100644
--- a/tinygrad/uop/decompositions.py
+++ b/tinygrad/uop/decompositions.py
@@ -223,26 +223,26 @@ def xlog2(d:UOp) -> UOp:
   Paper: https://arxiv.org/pdf/2001.09258 5.5
   """
   assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
-  # TODO: float16 denormal need float32 to achieve precision
-  if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
-  FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
+  # float16 uses 2^10 for denormal scaling (2^64 overflows), float32/64 use 2^64
+  denormal_exp = 10 if d.dtype.scalar() == dtypes.float16 else 64
+  FLT_MIN = d.const_like({dtypes.float16: 6.1e-5, dtypes.float32: 1e-4, dtypes.float64: 1e-4}[d.dtype.scalar()])
   is_denormal = d
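
Why 2^10 rather than 2^64 (editor's sketch, not part of the patch): xlog2 handles denormal inputs by scaling them up into the normal range, taking log2, and subtracting the scaling exponent again, using the identity log2(d * 2^k) - k == log2(d). The float32/float64 path uses k = 64, but 2^64 is not representable in float16 (whose max finite value is 65504), hence k = 10, which is enough to lift any float16 denormal (all below 2^-14) into the normal range. The numpy snippet below illustrates the identity on a denormal float16; it is a standalone demonstration and does not use tinygrad's UOp machinery.

import numpy as np

d = np.float16(1e-5)              # denormal: below the float16 min normal 2**-14 ~= 6.1e-5
scaled = d * np.float16(2.0**10)  # exact: a power-of-two multiply only shifts the exponent
approx = np.log2(scaled) - 10     # log2(d * 2**10) - 10 == log2(d)
exact = np.log2(np.float64(d))    # float64 reference for the same stored value
print(approx, exact)              # ~ -16.61 vs ~ -16.608, agreeing within float16 precision

The multiply is lossless because a float16 denormal carries at most 10 mantissa bits, so shifting its exponent up by 10 yields a normal float16 with the same mantissa; the only remaining error is the float16 rounding of log2 itself, which is what the denormal test above tolerates with rtol=5e-2.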