From 6bf38c35e54a32cd0a06b54c764c31087560e8a3 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 29 Oct 2024 18:51:37 -0400 Subject: [PATCH] clean up transcendental frexp [pr] (#7384) also added some unit tests for frexp --- test/unit/test_transcendental_helpers.py | 23 ++++++++++++++++++++++- tinygrad/codegen/transcendental.py | 20 ++++++++++---------- tinygrad/dtype.py | 3 ++- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/test/unit/test_transcendental_helpers.py b/test/unit/test_transcendental_helpers.py index 93eb806bf8..52c9bb26db 100644 --- a/test/unit/test_transcendental_helpers.py +++ b/test/unit/test_transcendental_helpers.py @@ -3,7 +3,7 @@ import numpy as np from tinygrad import dtypes from tinygrad.ops import UOp, UOps from tinygrad.codegen.uopgraph import full_graph_rewrite -from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction +from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp from tinygrad.codegen.linearize import linearize_uop from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator @@ -28,5 +28,26 @@ class TestReduction(unittest.TestCase): # TODO: should q be in [0, 1, 2, 3]? np.testing.assert_equal(q, 12) + def test_frexp(self): + mantissa, exponent = (self._run_uop(u) for u in frexp(UOp.const(dtypes.float64, 0.0))) + np.testing.assert_equal(mantissa, 0.0) + np.testing.assert_equal(exponent, 0) + + mantissa, exponent = (self._run_uop(u) for u in frexp(UOp.const(dtypes.float64, 1.0))) + np.testing.assert_equal(mantissa, 0.5) + np.testing.assert_equal(exponent, 1) + + mantissa, exponent = (self._run_uop(u) for u in frexp(UOp.const(dtypes.float64, -1.0))) + np.testing.assert_equal(mantissa, 0.5) + np.testing.assert_equal(exponent, 1) + + mantissa, exponent = (self._run_uop(u) for u in frexp(UOp.const(dtypes.float64, 2.0))) + np.testing.assert_equal(mantissa, 0.5) + np.testing.assert_equal(exponent, 2) + + mantissa, exponent = (self._run_uop(u) for u in frexp(UOp.const(dtypes.float64, 5.0))) + np.testing.assert_equal(mantissa, 0.625) + np.testing.assert_equal(exponent, 3) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 26254a006d..97279844e9 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -17,7 +17,7 @@ def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]: qy = (ny - qx * dy) * t return qx, qy # *** helper functions for bit manipulation *** -def significand_bits(d:DType) -> int: return dtypes.finfo(d)[1] +def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1] def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d] def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d] @@ -35,21 +35,21 @@ def pow2if(q:UOp, float_dtype:DType): """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]""" assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32) final_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype] - return shl(q + exponent_bias(final_dtype), significand_bits(final_dtype)).bitcast(final_dtype) + return shl(q + exponent_bias(final_dtype), mantissa_bits(final_dtype)).bitcast(final_dtype) def ilogb2k(d:UOp) -> UOp: """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf).""" assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]) # -1 <= ilog2bk(d) <= 128 - return (shr(dint, significand_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype) + return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype) def ldexp3k(d:UOp, e:UOp) -> UOp: """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number.""" assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16} m1 = d.bitcast(cast_map[d.dtype]) - m2 = shl(e.cast(cast_map[d.dtype]), significand_bits(d.dtype)) + m2 = shl(e.cast(cast_map[d.dtype]), mantissa_bits(d.dtype)) return (m1 + m2).bitcast(d.dtype).cast(d.dtype) def ldexp2k(d:UOp, e:UOp) -> UOp: @@ -64,15 +64,15 @@ def frexp(v:UOp) -> Tuple[UOp, UOp]: m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype] m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype] bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype]) - exponent = shr(bits, significand_bits(v.dtype)) & exponent_mask(v.dtype) - exponent_zero = exponent.ne(0) + exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype) # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0). - result_f = ((bits & m1) | m2).bitcast(v.dtype) - value = exponent_zero.where(result_f, v) + mantissa = ((bits & m1) | m2).bitcast(v.dtype) exp = exponent - exponent_bias(v.dtype) + 1 - exp = exponent_zero.where(exp, exp.const_like(0)) + # special case of 0 # TODO: can we remove this case? + mantissa = exponent.ne(0).where(mantissa, v) + exp = exponent.ne(0).where(exp, exp.const_like(0)) if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16) - return value, exp + return mantissa, exp def polyN(s:UOp, coeffs:List[float]) -> UOp: return functools.reduce(lambda u,c: u*s+c, coeffs, s.const_like(0)) diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 2cdf985b78..f2236324c7 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -79,7 +79,8 @@ class dtypes: if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1 return float("inf") if dtypes.is_float(dtype) else True @staticmethod - def finfo(dtype:DType) -> Tuple[int, int]: # (exponent, mantissa) + def finfo(dtype:DType) -> Tuple[int, int]: + """(exponent, mantissa)""" if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type") return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype] @staticmethod