Fix bug in exponentiation.

This commit is contained in:
Marcel Keller
2025-12-12 15:19:59 +11:00
parent b47c9bb6f8
commit ce83a3708c
3 changed files with 11 additions and 6 deletions

View File

@@ -1625,10 +1625,14 @@ class sbitfixvec(_fix, _vec, _binary):
"""
return cls._new(cls.int_type.get_input_from(player, size=size,
f=cls.f))
def __init__(self, value=None, *args, **kwargs):
def __init__(self, value=None, k=None, *args, **kwargs):
if isinstance(value, (list, tuple)):
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]))
super(sbitfixvec, self).__init__(None, k=k, *args, **kwargs)
self.int_type = sbitintvec.get_type(self.k)
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]).v)
else:
self.k = k or self.k
self.int_type = sbitintvec.get_type(self.k)
if isinstance(value, sbitvec):
value = self.int_type(value)
super(sbitfixvec, self).__init__(value, *args, **kwargs)

View File

@@ -274,11 +274,12 @@ def exp2_fx(a, zero_output=False, as19=False):
# improve precision
my_fix.set_precision(a.k - 2, a.k)
n_shift = a.k - 2 - a.f
res_k = 2 * a.k - n_shift
x = my_fix._new(frac.v << n_shift)
# evaluates fractional part of a in p_1045
e = p_eval(p_1045, x)
g = a._new(whole_exp.TruncMul(e.v, 2 * a.k, n_shift,
nearest=a.round_nearest), a.k, a.f)
nearest=a.round_nearest), res_k, a.f)
return g
# how many bits to use from integer part
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
@@ -368,7 +369,7 @@ def exp2_fx(a, zero_output=False, as19=False):
pow2_bits = [sint.conv(x) for x in higher_bits]
d = floatingpoint.Pow2_from_bits(pow2_bits)
g = exp_from_parts(d, c)
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits,
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits + 1,
2 ** n_int_bits, signed=False,
nearest=a.round_nearest),
k=a.k, f=a.f)
@@ -376,7 +377,7 @@ def exp2_fx(a, zero_output=False, as19=False):
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
bits_to_check))
small_result = t.if_else(small_result, 0)
return s.if_else(small_result, g)
return s.if_else(small_result, a._new(g.v, k=a.k, f=a.f))
else:
assert not zero_output
# obtain absolute value of a

View File

@@ -4744,7 +4744,7 @@ class _fix(_single):
@classmethod
def _new(cls, other, k=None, f=None):
res = cls(k=k, f=f, initialize=False)
res.v = cls.int_type.conv(other)
res.v = res.int_type.conv(other)
return res
@vectorize_init