mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Fix bug in exponentiation.
This commit is contained in:
@@ -1625,10 +1625,14 @@ class sbitfixvec(_fix, _vec, _binary):
|
||||
"""
|
||||
return cls._new(cls.int_type.get_input_from(player, size=size,
|
||||
f=cls.f))
|
||||
def __init__(self, value=None, *args, **kwargs):
|
||||
def __init__(self, value=None, k=None, *args, **kwargs):
|
||||
if isinstance(value, (list, tuple)):
|
||||
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]))
|
||||
super(sbitfixvec, self).__init__(None, k=k, *args, **kwargs)
|
||||
self.int_type = sbitintvec.get_type(self.k)
|
||||
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]).v)
|
||||
else:
|
||||
self.k = k or self.k
|
||||
self.int_type = sbitintvec.get_type(self.k)
|
||||
if isinstance(value, sbitvec):
|
||||
value = self.int_type(value)
|
||||
super(sbitfixvec, self).__init__(value, *args, **kwargs)
|
||||
|
||||
@@ -274,11 +274,12 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
# improve precision
|
||||
my_fix.set_precision(a.k - 2, a.k)
|
||||
n_shift = a.k - 2 - a.f
|
||||
res_k = 2 * a.k - n_shift
|
||||
x = my_fix._new(frac.v << n_shift)
|
||||
# evaluates fractional part of a in p_1045
|
||||
e = p_eval(p_1045, x)
|
||||
g = a._new(whole_exp.TruncMul(e.v, 2 * a.k, n_shift,
|
||||
nearest=a.round_nearest), a.k, a.f)
|
||||
nearest=a.round_nearest), res_k, a.f)
|
||||
return g
|
||||
# how many bits to use from integer part
|
||||
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
|
||||
@@ -368,7 +369,7 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
pow2_bits = [sint.conv(x) for x in higher_bits]
|
||||
d = floatingpoint.Pow2_from_bits(pow2_bits)
|
||||
g = exp_from_parts(d, c)
|
||||
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits,
|
||||
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits + 1,
|
||||
2 ** n_int_bits, signed=False,
|
||||
nearest=a.round_nearest),
|
||||
k=a.k, f=a.f)
|
||||
@@ -376,7 +377,7 @@ def exp2_fx(a, zero_output=False, as19=False):
|
||||
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
|
||||
bits_to_check))
|
||||
small_result = t.if_else(small_result, 0)
|
||||
return s.if_else(small_result, g)
|
||||
return s.if_else(small_result, a._new(g.v, k=a.k, f=a.f))
|
||||
else:
|
||||
assert not zero_output
|
||||
# obtain absolute value of a
|
||||
|
||||
@@ -4744,7 +4744,7 @@ class _fix(_single):
|
||||
@classmethod
|
||||
def _new(cls, other, k=None, f=None):
|
||||
res = cls(k=k, f=f, initialize=False)
|
||||
res.v = cls.int_type.conv(other)
|
||||
res.v = res.int_type.conv(other)
|
||||
return res
|
||||
|
||||
@vectorize_init
|
||||
|
||||
Reference in New Issue
Block a user