mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
tiny clena up pow2if and payne_hanek_reduction (#7423)
This commit is contained in:
@@ -2,7 +2,7 @@ import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk
|
||||
from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if
|
||||
from test.helpers import eval_uop
|
||||
|
||||
class TestTranscendentalFunctions(unittest.TestCase):
|
||||
@@ -48,5 +48,16 @@ class TestTranscendentalFunctions(unittest.TestCase):
|
||||
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.5))), -6)
|
||||
np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.999))), -6)
|
||||
|
||||
def test_pow2if(self):
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 0), dtypes.float)), 1.0)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 1), dtypes.float)), 2.0)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 2), dtypes.float)), 4.0)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 10), dtypes.float)), 1024.0)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 63), dtypes.float)), 2**63)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -1), dtypes.float)), 0.5)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -2), dtypes.float)), 0.25)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), dtypes.float)), 2**-10)
|
||||
np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -33,9 +33,8 @@ def rintk(d:UOp) -> UOp:
|
||||
|
||||
def pow2if(q:UOp, float_dtype:DType):
|
||||
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
|
||||
assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32)
|
||||
final_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype]
|
||||
return shl(q + exponent_bias(final_dtype), mantissa_bits(final_dtype)).bitcast(final_dtype)
|
||||
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
|
||||
return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
|
||||
|
||||
def ilogb2k(d:UOp) -> UOp:
|
||||
"""calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
|
||||
@@ -96,14 +95,13 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
|
||||
ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
|
||||
# extract 96 relevant bits of 2/pi based on magnitude of argument
|
||||
i = shr(e.cast(dtypes.uint64), 5)
|
||||
e = (e.cast(dtypes.uint64) & 31).cast(dtypes.uint32)
|
||||
offset = -e + 32
|
||||
e = e.cast(dtypes.int32) & 31
|
||||
offset = 32 - e
|
||||
|
||||
def _eq(arr:UOp, eq_to:int) -> UOp: return arr.ne(eq_to)
|
||||
def _take(an:UOp, offset:int, count:int=0) -> UOp:
|
||||
"""an = two_over_pi_f[i+offset]"""
|
||||
if count+offset <= len(two_over_pi_f[0:-2]):
|
||||
an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
|
||||
an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset]))
|
||||
return an
|
||||
def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
||||
def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32)
|
||||
|
||||
Reference in New Issue
Block a user