From 0739895b4d05439db93e2e9a226a44a628a28ead Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 30 Oct 2024 22:22:48 -0400 Subject: [PATCH] tiny clean up pow2if and payne_hanek_reduction (#7423) --- test/unit/test_transcendental_helpers.py | 13 ++++++++++++- tinygrad/codegen/transcendental.py | 12 +++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/test/unit/test_transcendental_helpers.py b/test/unit/test_transcendental_helpers.py index cca15e0a59..86b93e671d 100644 --- a/test/unit/test_transcendental_helpers.py +++ b/test/unit/test_transcendental_helpers.py @@ -2,7 +2,7 @@ import unittest, math import numpy as np from tinygrad import dtypes from tinygrad.ops import UOp -from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk +from tinygrad.codegen.transcendental import payne_hanek_reduction, cody_waite_reduction, frexp, rintk, pow2if from test.helpers import eval_uop class TestTranscendentalFunctions(unittest.TestCase): @@ -48,5 +48,16 @@ class TestTranscendentalFunctions(unittest.TestCase): np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.5))), -6) np.testing.assert_allclose(eval_uop(rintk(UOp.const(dtypes.float, -5.999))), -6) + def test_pow2if(self): + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 0), dtypes.float)), 1.0) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 1), dtypes.float)), 2.0) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 2), dtypes.float)), 4.0) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 10), dtypes.float)), 1024.0) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, 63), dtypes.float)), 2**63) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -1), dtypes.float)), 0.5) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -2), dtypes.float)), 0.25) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -10), 
dtypes.float)), 2**-10) + np.testing.assert_allclose(eval_uop(pow2if(UOp.const(dtypes.int, -63), dtypes.float)), 2**-63) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index feb4af6909..e2fc9e2705 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -33,9 +33,8 @@ def rintk(d:UOp) -> UOp: def pow2if(q:UOp, float_dtype:DType): """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]""" - assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32) - final_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype] - return shl(q + exponent_bias(final_dtype), mantissa_bits(final_dtype)).bitcast(final_dtype) + out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype] + return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype) def ilogb2k(d:UOp) -> UOp: """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf).""" @@ -96,14 +95,13 @@ def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]: ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64) # extract 96 relevant bits of 2/pi based on magnitude of argument i = shr(e.cast(dtypes.uint64), 5) - e = (e.cast(dtypes.uint64) & 31).cast(dtypes.uint32) - offset = -e + 32 + e = e.cast(dtypes.int32) & 31 + offset = 32 - e - def _eq(arr:UOp, eq_to:int) -> UOp: return arr.ne(eq_to) def _take(an:UOp, offset:int, count:int=0) -> UOp: """an = two_over_pi_f[i+offset]""" if count+offset <= len(two_over_pi_f[0:-2]): - an = _eq(i, count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset])) + an = i.ne(count).where(_take(an, offset, count=count+1), an.const_like(two_over_pi_f[count+offset])) return an def _shl_lazy(x, y): return (x.cast(dtypes.uint64) * 
pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32) def _shr_lazy(x, y): return (x.cast(dtypes.uint64) // pow2if(y, input_dtype).cast(dtypes.uint64)).cast(dtypes.uint32)