tiny clean up pow2if and payne_hanek_reduction (#7423)

This commit is contained in:
chenyu
2024-10-30 22:22:48 -04:00
committed by GitHub
parent 118dd7721f
commit 0739895b4d
2 changed files with 17 additions and 8 deletions

View File

@@ -2,7 +2,7 @@ import unittest, math
import numpy as np
from tinygrad import dtypes
from tinygrad.ops import UOp
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if
from test.helpers import eval_uop
class TestTranscendentalFunctions(unittest.TestCase):
@@ -48,5 +48,16 @@ class TestTranscendentalFunctions(unittest.TestCase):
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.5))), -6)
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.999))), -6)
def test_pow2if(self):
  # Table of (integer exponent, expected 2**exponent) pairs, checked in order:
  # zero, small and large positive exponents first, then the negative ones.
  expectations = (
    (0, 1.0),
    (1, 2.0),
    (2, 4.0),
    (10, 1024.0),
    (63, 2**63),
    (-1, 0.5),
    (-2, 0.25),
    (-10, 2**-10),
    (-63, 2**-63),
  )
  for exponent, expected in expectations:
    # Build a constant int UOp, convert it with pow2if, and evaluate the resulting graph.
    result = eval_uop(pow2if(UOp.const(dtypes.int, exponent), dtypes.float))
    np.testing.assert_allclose(result, expected)
# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
unittest.main()

View File

@@ -33,9 +33,8 @@ def rintk(d:UOp) -> UOp:
def pow2if(q:UOp, float_dtype:DType):
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
# NOTE(review): this listing is a rendered diff with its +/- markers and indentation stripped.
# Judging by the hunk header (-33,9 +33,8), the next three lines look like the pre-change body
# that the commit removes -- confirm against the repository before relying on them.
assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32)
final_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype]
return shl(q + exponent_bias(final_dtype), mantissa_bits(final_dtype)).bitcast(final_dtype)
# Replacement body: the output float type is chosen from q's integer dtype (int64 -> float64,
# int32 -> float32, int16 -> the caller-supplied float_dtype). Any other q.dtype is not a key of
# this dict, so the lookup raises KeyError.
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
# Add the exponent bias to q, shift the sum by mantissa_bits(out_dtype) via shl, and reinterpret
# the resulting bits as out_dtype.
# NOTE(review): presumably this places the biased exponent in the IEEE-754 exponent field with a
# zero mantissa -- confirm against the definitions of shl, exponent_bias and mantissa_bits.
return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
def ilogb2k(d:UOp) -> UOp:
"""calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
@@ -96,14 +95,13 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
# extract 96 relevant bits of 2/pi based on magnitude of argument
i = shr(e.cast(dtypes.uint64), 5)
e = (e.cast(dtypes.uint64) & 31).cast(dtypes.uint32)
offset = -e + 32
e = e.cast(dtypes.int32) & 31
offset = 32 - e
# NOTE(review): despite its name, this helper returns arr.ne(eq_to), i.e. a "not equal" comparison.
# Judging by the hunk line counts (-96,14 +95,13), this line appears to be removed by the commit -- confirm.
def _eq(arr:UOp, eq_to:int) -> UOp: return arr.ne(eq_to)
def _take(an:UOp, offset:int, count:int=0) -> UOp:
"""an = two_over_pi_f[i+offset]"""
# Recurse while count+offset is still within len(two_over_pi_f[0:-2]); each level adds one
# conditional on `i != count` whose first branch is the next recursion level and whose second
# branch is the constant two_over_pi_f[count+offset], built with an.const_like.
# When the bound is exceeded, `an` is returned unchanged.
if count+offset <= len(two_over_pi_f[0:-2]):
# NOTE(review): this listing is a diff with +/- markers stripped; the two assignments below look
# like the pre-change line (using _eq) followed by its replacement (using i.ne directly) -- confirm.
an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
return an
def _shl_lazy(x, y):
  # Multiply x by pow2if(y, input_dtype) in uint64, then narrow the product to uint32.
  widened = x.cast(dtypes.uint64)
  factor = pow2if(y, input_dtype).cast(dtypes.uint64)
  return (widened * factor).cast(dtypes.uint32)
def _shr_lazy(x, y):
  # Floor-divide x by pow2if(y, input_dtype) in uint64, then narrow the quotient to uint32.
  widened = x.cast(dtypes.uint64)
  divisor = pow2if(y, input_dtype).cast(dtypes.uint64)
  return (widened // divisor).cast(dtypes.uint32)