mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-06 21:53:53 -05:00
update xlog2 fp16 decomp to not use fp32 (#13955)
@@ -101,6 +101,44 @@ class TestFromFuzzer(unittest.TestCase):
     _test_value(0)
     _test_value(0.0000009)
 
+class TestFloat16Log2(unittest.TestCase):
+  """Tests for native float16 log2 implementation (no float32 cast)"""
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  def test_float16_log2_basic(self):
+    # basic values
+    test_values = [1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0]
+    with Context(TRANSCENDENTAL=2):
+      for val in test_values:
+        result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
+        expected = np.log2(np.float16(val))
+        np.testing.assert_allclose(result, expected, rtol=1e-3, err_msg=f"log2({val})")
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan handling differs on Vulkan")
+  def test_float16_log2_special(self):
+    # special values: inf, -inf, nan, 0, negative
+    with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
+      # log2(inf) = inf
+      assert np.isinf(Tensor([np.inf], dtype=dtypes.float16).log2().numpy()[0])
+      # log2(0) = -inf
+      assert Tensor([0.0], dtype=dtypes.float16).log2().numpy()[0] == -np.inf
+      # log2(negative) = nan
+      assert np.isnan(Tensor([-1.0], dtype=dtypes.float16).log2().numpy()[0])
+      # log2(nan) = nan
+      assert np.isnan(Tensor([np.nan], dtype=dtypes.float16).log2().numpy()[0])
+
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
+  def test_float16_log2_denormal(self):
+    # test values near and below float16 min normal (6.1e-5)
+    # these exercise the denormal handling path with 2^10 scaling
+    test_values = [1e-4, 6e-5, 1e-5]
+    with Context(TRANSCENDENTAL=2):
+      for val in test_values:
+        result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
+        expected = np.log2(np.float16(val))
+        # denormals have lower precision due to float16 limitations
+        np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")
+
 class TestTranscendentalSchedule(unittest.TestCase):
   @unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
   def test_transcendental_sin_fusion(self):
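Why the denormal test expects 2^10 scaling rather than the 2^64 used for float32/64: the float16 max is ~65504 (~2^16), so a 2^64 multiplier overflows to inf, while 2^10 is enough to lift anything below the float16 min normal (2^-14 ~ 6.1e-5) into the normal range, and the true log2 is recovered by subtracting the scaling exponent afterwards. A standalone numpy sketch of that identity (illustrative only, not code from this commit):

import numpy as np

x = np.float16(1e-5)                # below the float16 min normal (~6.1e-5)
print(np.float16(2.0 ** 64))        # inf: 2^64 is not representable in float16
scaled = x * np.float16(2.0 ** 10)  # 2^10 keeps the product in the normal range
# identity used by the denormal path: log2(x) = log2(x * 2^10) - 10
print(np.log2(np.float32(scaled)) - 10, np.log2(np.float32(x)))  # both ~ -16.6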
@@ -223,26 +223,26 @@ def xlog2(d:UOp) -> UOp:
   Paper: https://arxiv.org/pdf/2001.09258 5.5
   """
   assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
-  # TODO: float16 denormal need float32 to achieve precision
-  if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
-  FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
+  # float16 uses 2^10 for denormal scaling (2^64 overflows), float32/64 use 2^64
+  denormal_exp = 10 if d.dtype.scalar() == dtypes.float16 else 64
+  FLT_MIN = d.const_like({dtypes.float16: 6.1e-5, dtypes.float32: 1e-4, dtypes.float64: 1e-4}[d.dtype.scalar()])
   is_denormal = d<FLT_MIN
-  a = is_denormal.where(d * (2 ** 64), d)
+  a = is_denormal.where(d * (2 ** denormal_exp), d)
 
   e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
   m = ldexp3k(a, -e)
-  e = is_denormal.where(e - 64, e)
+  e = is_denormal.where(e - denormal_exp, e)
 
   x = (m - 1.0) / (m + 1.0)
   x2 = x * x
   if d.dtype.scalar() == dtypes.float64:
     t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
                    0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
-    s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
+    r = t * (x * x2) + e + x * 2.885390081777926774
   else:
     t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
-    s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
-  r = t * (x * x2) + (s_hi + s_lo)
+    # s_lo term (x*3.27e-08) only for float32 - underflows in float16
+    r = t * (x * x2) + e + x * 2.8853900432586669922 + (x * 3.2734474483568488616e-08 if d.dtype.scalar() == dtypes.float32 else 0)
 
   # log2(Inf) = Inf
   r = d.ne(math.inf).where(r, r.const_like(math.inf))
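For context beyond the hunk, here is a minimal numpy reference of the same decomposition (a sketch under assumptions: xlog2_ref is a made-up name, and np.frexp/np.ldexp stand in for tinygrad's ilogb2k/ldexp3k; the real function builds UOps). The scheme: scale denormals by 2^denormal_exp, split a = m * 2^e with m in [0.75, 1.5), substitute x = (m-1)/(m+1) so that log2(m) = (2/ln2)*atanh(x), then sum the series using the minimax polynomial t; the constants 2.8853900432586669922 and 3.2734474483568488616e-08 are a hi/lo split of 2/ln2 ~ 2.885390081777927, and the lo half is exactly the term the diff drops for float16 because it underflows.

import numpy as np

def xlog2_ref(d, denormal_exp=64):
  # float32 reference sketch mirroring the diff's structure
  d = np.asarray(d, dtype=np.float32)
  is_denormal = d < 1e-4                       # FLT_MIN for float32 in the diff
  a = np.where(is_denormal, d * 2.0**denormal_exp, d)
  # a/0.75 = mant * 2^ex with mant in [0.5, 1), so ex - 1 = floor(log2(a/0.75)), i.e. ilogb2k
  _, ex = np.frexp(a * (1.0 / 0.75))
  e = ex - 1
  m = np.ldexp(a, -e)                          # m in ~[0.75, 1.5), a = m * 2^e
  e = np.where(is_denormal, e - denormal_exp, e).astype(np.float32)
  x = (m - 1.0) / (m + 1.0)                    # log2(m) = (2/ln2) * atanh(x)
  x2 = x * x
  t = np.float32(0.4374550283)                 # polyN with the float32 coefficients above
  for c in (0.5764790177, 0.9618012905120):
    t = t * x2 + np.float32(c)
  # hi + lo split of 2/ln2; drop the lo term for a float16 variant
  r = t * (x * x2) + e + x * 2.8853900432586669922 + x * 3.2734474483568488616e-08
  r = np.where(d == np.inf, np.inf, r)         # log2(inf) = inf, as in the hunk
  r = np.where(d == 0.0, -np.inf, r)           # 0/negative are patched below the hunk
  return np.where(d < 0.0, np.nan, r)

Spot check: xlog2_ref(np.array([0.5, 1.0, 2.0, 10.0, 1e-5])) should track np.log2 to float32 precision; the new float16 path has the same shape with denormal_exp=10 and no lo term.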