mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 05:27:56 -05:00
955 lines
29 KiB
Python
955 lines
29 KiB
Python
"""
|
|
Module for math operations.
|
|
|
|
Most of the functionality is due to `Aly and Smart
|
|
<https://eprint.iacr.org/2019/354>`_ with some optimizations by
|
|
`Keller and Sun <https://eprint.iacr.org/2022/933>`_.
|
|
|
|
This has to imported explicitly.
|
|
"""
|
|
|
|
|
|
import math
|
|
import operator
|
|
from functools import reduce
|
|
from Compiler import floatingpoint
|
|
from Compiler import types
|
|
from Compiler import comparison
|
|
from Compiler import program
|
|
from Compiler import instructions_base
|
|
from Compiler import library, util
|
|
|
|
# polynomials as enumerated on Hart's book
|
|
##
|
|
# @private
|
|
p_3307 = [1.57079632679489000000000, -0.64596409750624600000000,
|
|
0.07969262624616700000000, -0.00468175413531868000000,
|
|
0.00016044118478735800000, -0.00000359884323520707000,
|
|
0.00000005692172920657320, -0.00000000066880348849204,
|
|
0.00000000000606691056085, -0.00000000000004375295071,
|
|
0.00000000000000025002854]
|
|
##
|
|
# @private
|
|
p_3508 = [1.00000000000000000000, -0.50000000000000000000,
|
|
0.04166666666666667129, -0.00138888888888888873,
|
|
0.00002480158730158702, -0.00000027557319223933,
|
|
0.00000000208767569817, -0.00000000001147074513,
|
|
0.00000000000004779454, -0.00000000000000015612,
|
|
0.00000000000000000040]
|
|
##
|
|
# @private
|
|
p_1045 = [math.log(2) ** i / math.factorial(i) for i in range(100)]
|
|
|
|
p_2508 = [-4.585323876456, 18.351352559641, -51.525644374262,
|
|
111.76784165654, -174.170840774074, 191.731001033848,
|
|
-145.61191979671, 72.650082977468, -21.447349196774,
|
|
2.840799797315]
|
|
|
|
##
|
|
# @private
|
|
p_2524 = [-2.05466671951, -8.8626599391,
|
|
+6.10585199015, +4.81147460989]
|
|
##
|
|
# @private
|
|
q_2524 = [+0.353553425277, +4.54517087629,
|
|
+6.42784209029, +1]
|
|
##
|
|
# @private
|
|
p_5102 = [+21514.05962602441933193254468, +73597.43380288444240814980706,
|
|
+100272.5618306302784970511863, +69439.29750032252337059765503,
|
|
+25858.09739719099025716567793, +5038.63918550126655793779119,
|
|
+460.1588804635351471161727227, +15.08767735870030987717455528,
|
|
+0.07523052818757628444510729539]
|
|
##
|
|
# @private
|
|
q_5102 = [+21514.05962602441933193298234, +80768.78701155924885176713209,
|
|
+122892.6789092784776298743322, +97323.20349053555680260434387,
|
|
+42868.57652046408093184006664, +10401.13491566890057005103878,
|
|
+1289.75056911611097141145955, +68.51937831018968013114024294,
|
|
+1]
|
|
##
|
|
# @private
|
|
p_4737 = [-9338.550897341021522505385079, +43722.68009378241623148489754,
|
|
-86008.12066370804865047446067, +92190.57592175496843898184959,
|
|
-58360.27724533928122075635101, +22081.61324178027161353562222,
|
|
-4805.541226761699661564427739, +542.2148323255220943742314911,
|
|
-24.94928894422502466205102672, 0.2222361619461131578797029272]
|
|
##
|
|
# @private
|
|
q_4737 =[-9338.550897341021522505384935, +45279.10524333925315190231067,
|
|
-92854.24688696401422824346529, +104687.2504366298224257408682,
|
|
-70581.74909396877350961227976, +28972.22947326672977624954443,
|
|
-7044.002024719172700685571406, +935.7104153502806086331621628,
|
|
-56.83369358538071475796209327, 1]
|
|
##
|
|
# @private
|
|
p_4754 = [-6.90859801, +12.85564644, -5.94939208]
|
|
|
|
##
|
|
# @private
|
|
q_4754 = [-6.92529156, +14.20305096, -8.27925501, 1]
|
|
|
|
# all inputs are calcualted in radians hence we need some conversion.
|
|
pi = math.radians(180)
|
|
pi_over_2 = math.radians(90)
|
|
|
|
##
|
|
# truncates values regardless of the input type. (It always rounds down)
|
|
# @param x: coefficient to be truncated.
|
|
#
|
|
# @return truncated sint value of x
|
|
def trunc(x):
|
|
if isinstance(x, types._fix):
|
|
return x.v.right_shift(x.f, x.k, signed=True)
|
|
elif type(x) is types.sfloat:
|
|
v, p, z, s = floatingpoint.FLRound(x, 0)
|
|
#return types.sfloat(v, p, z, s, x.err)
|
|
return types.sfloat(v, p, z, s)
|
|
return x
|
|
|
|
|
|
##
|
|
# evaluates a Polynomial to a given x in a privacy preserving manner.
|
|
# Inputs can be of any kind of register, secret or otherwise.
|
|
#
|
|
# @param p_c: Polynomial coefficients. (Array)
|
|
#
|
|
# @param x: Value to which the polynomial p_c is evaluated to.(register)
|
|
#
|
|
# @return the evaluation of the polynomial. return type depends on inputs.
|
|
def p_eval(p_c, x):
|
|
degree = len(p_c) - 1
|
|
if isinstance(x, types._fix):
|
|
# ignore coefficients smaller than precision
|
|
for c in reversed(p_c):
|
|
if abs(c) < 2 ** -(x.f + 1):
|
|
degree -= 1
|
|
else:
|
|
break
|
|
pre_mults = floatingpoint.PreOpL(lambda a,b,_: a * b,
|
|
[x] * degree)
|
|
local_aggregation = 0
|
|
# Evaluation of the Polynomial
|
|
for i, pre_mult in zip(p_c[1:], pre_mults):
|
|
local_aggregation += pre_mult.mul_no_reduce(i)
|
|
return local_aggregation.reduce_after_mul() + p_c[0]
|
|
|
|
|
|
##
|
|
# reduces the input to [0,90) and returns whether the reduced value is
|
|
# greater than \Pi and greater than Pi over 2
|
|
# @param x: value of any type to be reduced to the [0,90) interval
|
|
#
|
|
# @return w: reduced angle in either fixed or floating point .
|
|
#
|
|
# @return b1: \{0,1\} value. Returns one when reduction to 2*\pi
|
|
# is greater than \pi
|
|
#
|
|
# @return b2: \{0,1\} value. Returns one when reduction to
|
|
# \pi is greater than \pi/2.
|
|
def sTrigSub(x):
|
|
library.get_program().reading('trigonometric functions', 'AS19',
|
|
'Section 4')
|
|
# reduction to 2* \pi
|
|
f = x * (1.0 / (2 * pi))
|
|
f = trunc(f)
|
|
y = x - (f) * x.coerce(2 * pi)
|
|
# reduction to \pi
|
|
b1 = y > pi
|
|
w = b1.if_else(2 * pi - y, y)
|
|
# reduction to \pi/2
|
|
b2 = w > pi_over_2
|
|
w = b2.if_else(pi - w, w)
|
|
# returns scaled angle and boolean flags
|
|
return w, b1, b2
|
|
|
|
# kernel method calls -- they are built in a generic way
|
|
|
|
|
|
##
|
|
# Kernel sin. Returns the sin of a given angle on the [0, \pi/2) interval and
|
|
# adjust the sign in case the angle was reduced on the [0,360) interval
|
|
#
|
|
# @param w: fractional value for an angle on the [0, \pi) interval.
|
|
#
|
|
# @return returns the sin of w.
|
|
def ssin(w, s):
|
|
# calculates the v of w for polynomial evaluation
|
|
v = w * (1.0 / pi_over_2)
|
|
v_2 = v ** 2
|
|
# adjust sign according to the movement in the reduction
|
|
# calculate the sin using polynomial evaluation
|
|
local_sin = s.if_else(-v, v) * p_eval(p_3307, v_2)
|
|
return local_sin
|
|
|
|
|
|
##
|
|
# Kernel cos. Returns the cos of a given angle on the [0.pi/2)
|
|
# interval and adjust
|
|
# the sign in case the angle was reduced on the [0,360) interval.
|
|
#
|
|
# @param w: fractional value for an angle on the [0,\pi) interval.
|
|
#
|
|
# @param s: \{0,1\} value. Corresponding to b2. Returns 1 if the angle
|
|
# was reduced from an angle in the [\pi/2,\pi) interval.
|
|
#
|
|
# @return returns the cos of w (sfix).
|
|
def scos(w, s):
|
|
# calculates the v of the w.
|
|
v = w
|
|
v_2 = v ** 2
|
|
# calculate the cos using polynomial evaluation
|
|
tmp = p_eval(p_3508, v_2)
|
|
# adjust sign according to the movement in the reduction
|
|
local_cos = s.if_else(-tmp, tmp)
|
|
return local_cos
|
|
|
|
|
|
# facade method calls --it is built in a generic way
|
|
|
|
@instructions_base.sfix_cisc
|
|
def sin(x):
|
|
"""
|
|
Returns the sine of any given fractional value.
|
|
|
|
:param x: fractional input (sfix, sfloat)
|
|
|
|
:return: sin of :py:obj:`x` (sfix, sfloat)
|
|
"""
|
|
# reduces the angle to the [0,\pi/2) interval.
|
|
w, b1, b2 = sTrigSub(x)
|
|
# returns the sin with sign correction
|
|
return ssin(w, b1)
|
|
|
|
|
|
@instructions_base.sfix_cisc
|
|
def cos(x):
|
|
"""
|
|
Returns the cosine of any given fractional value.
|
|
|
|
:param x: fractional input (sfix, sfloat)
|
|
|
|
:return: cos of :py:obj:`x` (sfix, sfloat)
|
|
"""
|
|
# reduces the angle to the [0,\pi/2) interval.
|
|
w, b1, b2 = sTrigSub(x)
|
|
|
|
# returns the sin with sign correction
|
|
return scos(w, b2)
|
|
|
|
|
|
@instructions_base.sfix_cisc
|
|
def tan(x):
|
|
"""
|
|
Returns the tangent of any given fractional value.
|
|
|
|
:param x: fractional input (sfix, sfloat)
|
|
|
|
:return: tan of :py:obj:`x` (sfix, sfloat)
|
|
"""
|
|
# reduces the angle to the [0,\pi/2) interval.
|
|
w, b1, b2 = sTrigSub(x)
|
|
# calculates the sin and the cos.
|
|
local_sin = ssin(w, b1)
|
|
local_cos = scos(w, b2)
|
|
# obtains the local tan
|
|
local_tan = local_sin/local_cos
|
|
return local_tan
|
|
|
|
|
|
@types.vectorize
|
|
@instructions_base.sfix_cisc
|
|
def exp2_fx(a, zero_output=False, as19=False):
|
|
"""
|
|
Power of two for fixed-point numbers.
|
|
|
|
:param a: exponent for :math:`2^a` (sfix)
|
|
:param zero_output: whether to output zero for very small values. If not, the result will be undefined.
|
|
|
|
:return: :math:`2^a` if it is within the range. Undefined otherwise
|
|
"""
|
|
library.get_program().reading('exponential', 'AS19', 'Protocol 6')
|
|
def exp_from_parts(whole_exp, frac):
|
|
class my_fix(type(a)):
|
|
pass
|
|
# improve precision
|
|
my_fix.set_precision(a.k - 2, a.k)
|
|
n_shift = a.k - 2 - a.f
|
|
res_k = 2 * a.k - n_shift
|
|
x = my_fix._new(frac.v << n_shift)
|
|
# evaluates fractional part of a in p_1045
|
|
e = p_eval(p_1045, x)
|
|
g = a._new(whole_exp.TruncMul(e.v, 2 * a.k, n_shift,
|
|
nearest=a.round_nearest), res_k, a.f)
|
|
return g
|
|
# how many bits to use from integer part
|
|
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
|
|
n_bits = a.f + n_int_bits
|
|
sint = a.int_type
|
|
if types.program.options.ring and not as19:
|
|
intbitint = types.intbitint
|
|
n_shift = int(types.program.options.ring) - a.k
|
|
if types.program.use_split():
|
|
from Compiler.GC.types import sbitvec
|
|
if types.program.use_split() == 3:
|
|
x = a.v.split_to_two_summands(a.k)
|
|
bits = types._bitint.carry_lookahead_adder(x[0], x[1],
|
|
fewer_inv=False)
|
|
# converting MSB first reduces the number of rounds
|
|
s = sint.conv(bits[-1])
|
|
lower_overflow = sint.conv(x[0][a.f]) + \
|
|
sint.conv(x[0][a.f] ^ x[1][a.f] ^ bits[a.f])
|
|
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
|
|
elif types.program.use_split() == 4:
|
|
x = list(zip(*a.v.split_to_n_summands(a.k, 4)))
|
|
bi = types._bitint
|
|
red = bi.wallace_reduction
|
|
sums1, carries1 = red(*x[:3], get_carry=False)
|
|
sums2, carries2 = red(x[3], sums1, carries1, False)
|
|
bits = bi.carry_lookahead_adder(sums2, carries2,
|
|
fewer_inv=False)
|
|
overflows = bi.full_adder(carries1[a.f], carries2[a.f],
|
|
bits[a.f] ^ sums2[a.f] ^ carries2[a.f])
|
|
overflows = reversed(list((sint.conv(x)
|
|
for x in reversed(overflows))))
|
|
lower_overflow = sint.bit_compose(sint.conv(x)
|
|
for x in overflows)
|
|
s = sint.conv(bits[-1])
|
|
lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
|
|
else:
|
|
bits = sbitvec(a.v, a.k).v
|
|
s = sint.conv(bits[-1])
|
|
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
|
|
higher_bits = bits[a.f:n_bits]
|
|
bits_to_check = bits[n_bits:-1]
|
|
else:
|
|
if types.program.use_edabit():
|
|
l = sint.get_edabit(a.f, True)
|
|
u = sint.get_edabit(a.k - a.f, True)
|
|
r_bits = l[1] + u[1]
|
|
r = l[0] + (u[0] << a.f)
|
|
lower_r = l[0]
|
|
else:
|
|
r_bits = [sint.get_random_bit() for i in range(a.k)]
|
|
r = sint.bit_compose(r_bits)
|
|
lower_r = sint.bit_compose(r_bits[:a.f])
|
|
shifted = ((a.v - r) << n_shift).reveal(False)
|
|
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
|
|
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
|
|
r_bits[a.f-1::-1])
|
|
lower_masked = sint.bit_compose(masked_bits[:a.f])
|
|
lower = lower_r + lower_masked - \
|
|
(sint.conv(lower_overflow) << (a.f))
|
|
higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
|
|
masked_bits[a.f:n_bits],
|
|
carry_in=lower_overflow,
|
|
get_carry=True)
|
|
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
|
|
r_bits[n_bits:-1],
|
|
higher_bits[-1])
|
|
if zero_output:
|
|
# should be for free
|
|
highest_bits = r_bits[0].ripple_carry_adder(
|
|
masked_bits[n_bits:-1], [0] * (a.k - n_bits - 1),
|
|
carry_in=higher_bits[-1])
|
|
bits_to_check = [x.bit_xor(y)
|
|
for x, y in zip(highest_bits[:-1],
|
|
r_bits[n_bits:-1])]
|
|
# sign
|
|
s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
|
|
del higher_bits[-1]
|
|
else:
|
|
bits = a.v.bit_decompose(a.k, maybe_mixed=True)
|
|
lower = sint.bit_compose(bits[:a.f])
|
|
higher_bits = bits[a.f:n_bits]
|
|
s = a.bit_type.conv(bits[-1])
|
|
bits_to_check = bits[n_bits:-1]
|
|
if not as19:
|
|
c = a._new(lower, k=a.k, f=a.f)
|
|
assert(len(higher_bits) == n_bits - a.f)
|
|
pow2_bits = [sint.conv(x) for x in higher_bits]
|
|
d = floatingpoint.Pow2_from_bits(pow2_bits)
|
|
g = exp_from_parts(d, c)
|
|
small_result = a._new(g.v.round(a.f + 2 ** n_int_bits + 1,
|
|
2 ** n_int_bits, signed=False,
|
|
nearest=a.round_nearest),
|
|
k=a.k, f=a.f)
|
|
if zero_output:
|
|
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
|
|
bits_to_check))
|
|
small_result = t.if_else(small_result, 0)
|
|
return s.if_else(small_result, a._new(g.v, k=a.k, f=a.f))
|
|
else:
|
|
assert not zero_output
|
|
# obtain absolute value of a
|
|
s = a < 0
|
|
a = s.if_else(-a, a)
|
|
# isolates fractional part of number
|
|
b = trunc(a)
|
|
c = a - b
|
|
# squares integer part of a
|
|
d = b.pow2(a.k - a.f)
|
|
g = exp_from_parts(d, c)
|
|
return s.if_else(1 / g, g)
|
|
|
|
|
|
def mux_exp(x, y, block_size=8):
|
|
assert util.is_constant_float(x)
|
|
from Compiler.GC.types import sbitvec, sbits
|
|
bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v
|
|
sign = bits[-1]
|
|
m = math.log(2 ** (y.k - y.f - 1), x)
|
|
del bits[int(math.ceil(math.log(m, 2))) + y.f:]
|
|
parts = []
|
|
for i in range(0, len(bits), block_size):
|
|
one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v
|
|
exp = []
|
|
try:
|
|
for j in range(len(one_hot)):
|
|
exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f))
|
|
except OverflowError:
|
|
pass
|
|
exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp))
|
|
bin_part = [0] * max(x.bit_length() for x in exp)
|
|
for j in range(len(bin_part)):
|
|
for k, (a, b) in enumerate(zip(one_hot, exp)):
|
|
bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \
|
|
else 0
|
|
if util.is_zero(bin_part[j]):
|
|
bin_part[j] = sbits.get_type(y.size)(0)
|
|
if i == 0:
|
|
bin_part[j] = sign.if_else(0, bin_part[j])
|
|
parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part))))
|
|
return util.tree_reduce(operator.mul, parts)
|
|
|
|
|
|
@types.vectorize
|
|
@instructions_base.sfix_cisc
|
|
def log2_fx(x, use_division=True):
|
|
r"""
|
|
Returns the result of :math:`\log_2(x)` for any unbounded
|
|
number. This is achieved by changing :py:obj:`x` into
|
|
:math:`f \cdot 2^n` where f is bounded by :math:`[0.5, 1]`. Then the
|
|
polynomials are used to calculate :math:`\log_2(f)`, which is then
|
|
just added to :math:`n`.
|
|
|
|
:param x: input for :math:`\log_2` (sfix, sint).
|
|
|
|
:return: (sfix) the value of :math:`\log_2(x)`
|
|
|
|
"""
|
|
library.get_program().reading('logarithm', 'AS19', 'Section 5')
|
|
if isinstance(x, types._fix):
|
|
# transforms sfix to f*2^n, where f is [o.5,1] bounded
|
|
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
|
|
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f)
|
|
p -= x.f
|
|
vlen = x.f
|
|
v = x._new(v, k=x.k, f=x.f)
|
|
elif isinstance(x, (types._register, types.cfix)):
|
|
return log2_fx(types.sfix(x), use_division)
|
|
else:
|
|
d = types.sfloat(x)
|
|
v, p, vlen = d.v, d.p, d.vlen
|
|
w = x.coerce(1.0 / (2 ** (vlen)))
|
|
v *= w
|
|
# isolates mantisa of d, now the n can be also substituted by the
|
|
# secret shared p from d in the expresion above.
|
|
# polynomials for the log_2 evaluation of f are calculated
|
|
if use_division:
|
|
P = p_eval(p_2524, v)
|
|
Q = p_eval(q_2524, v)
|
|
approx = P / Q
|
|
else:
|
|
approx = p_eval(p_2508, v)
|
|
# the log is returned by adding the result of the division plus p.
|
|
a = approx + (vlen + p)
|
|
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
|
|
|
|
|
|
def pow_fx(x, y, zero_output=False):
|
|
r"""
|
|
Returns the value of the expression :math:`x^y` where both inputs
|
|
are secret shared. It uses :py:func:`log2_fx` together with
|
|
:py:func:`exp2_fx` to calculate the expression :math:`2^{y \log_2(x)}`.
|
|
|
|
:param x: (sfix) secret shared base.
|
|
|
|
:param y: (sfix, clear types) secret shared exponent.
|
|
|
|
:return: :math:`x^y` (sfix) if positive and in range
|
|
"""
|
|
log2_x =0
|
|
# obtains log2(x)
|
|
if (type(x) == int or type(x) == float):
|
|
log2_x = math.log(x,2)
|
|
else:
|
|
log2_x = log2_fx(x)
|
|
# obtains y * log2(x)
|
|
exp = y * log2_x
|
|
# returns 2^(y*log2(x))
|
|
return exp2_fx(exp, zero_output)
|
|
|
|
|
|
def log_fx(x, b):
|
|
r"""
|
|
Returns the value of the expression :math:`\log_b(x)` where
|
|
:py:obj:`x` is secret shared. It uses :py:func:`log2_fx` to
|
|
calculate the expression :math:`\log_b(2) \cdot \log_2(x)`.
|
|
|
|
:param x: (sfix, sint) secret shared coefficient for log.
|
|
|
|
:param b: (float) base for log operation.
|
|
|
|
:return: (sfix) the value of :math:`log_b(x)`.
|
|
|
|
"""
|
|
# calculates logb(2)
|
|
logb_2 = math.log(2, b)
|
|
# returns logb(2) * log2(x)
|
|
return logb_2 * log2_fx(x)
|
|
|
|
|
|
##
|
|
# Returns the absolute value of a fix point number.
|
|
# The method is also applicable to sfloat,
|
|
# however, more efficient mechanisms can be devised.
|
|
#
|
|
# @param x: (sfix)
|
|
#
|
|
# @return (sfix) unsigned
|
|
def abs_fx(x):
|
|
s = x < 0
|
|
return (1 - 2 * s) * x
|
|
|
|
|
|
##
|
|
# Floors the input and stores the value into a sflix register
|
|
# @param x: coefficient to be floored.
|
|
#
|
|
# @return floored sint value of x
|
|
def floor_fx(x):
|
|
return type(x)(x.v.right_shift(x.f, bit_length=x.k, signed=True),
|
|
k=x.k, f=x.f)
|
|
|
|
|
|
### sqrt methods
|
|
|
|
|
|
##
|
|
# obtains the most significative bit (MSB)
|
|
# of a given input. The size of the vector
|
|
# is tuned to the needs of sqrt.
|
|
# @param b: number from which you obtain the
|
|
# most significative bit.
|
|
# @param k: number of bits for which
|
|
# an output of size (k+1) if even
|
|
# is going to be produced.
|
|
# @return z: index array for MSB of size
|
|
# k or K+1 if even.
|
|
def MSB(b, k):
|
|
# calculation of z
|
|
# x in order 0 - k
|
|
|
|
x_order = b.bit_decompose(k)
|
|
x = [0] * k
|
|
# x i now inverted
|
|
for i in range(k - 1, -1, -1):
|
|
x[k - 1 - i] = x_order[i]
|
|
# y is inverted for PReOR and then restored
|
|
y_order = floatingpoint.PreOR(x)
|
|
|
|
# y in order (restored in orginal order
|
|
y = [0] * k
|
|
for i in range(k - 1, -1, -1):
|
|
y[k - 1 - i] = y_order[i]
|
|
|
|
# obtain z
|
|
z = [0] * (k + 1 - k % 2)
|
|
for i in range(k - 1):
|
|
z[i] = y[i] - y[i + 1]
|
|
z[k - 1] = y[k - 1]
|
|
|
|
return z
|
|
|
|
|
|
##
|
|
# Similar to norm_SQ, saves rounds by not
|
|
# calculating v and c.
|
|
#
|
|
# @param b: sint input to be normalized.
|
|
# @param k: bitsize of the input, by definition
|
|
# its value is either sfix.k or program.bit_lengthh
|
|
# @return m_odd: the parity of most signficative bit index m
|
|
# @return m: index of most significative bit
|
|
# @return w: 2^m/2 or 2^ (m-1) /2
|
|
def norm_simplified_SQ(b, k):
|
|
z = MSB(b, k)
|
|
# construct m
|
|
#m = types.sint(0)
|
|
m_odd = 0
|
|
for i in range(k):
|
|
#m = m + (i + 1) * z[i]
|
|
# determine the parity of the input
|
|
if (i % 2 == 0):
|
|
m_odd = m_odd + z[i]
|
|
|
|
# construct w,
|
|
k_over_2 = k // 2 + 1
|
|
w_array = [0] * (k_over_2)
|
|
w_array[0] = z[0]
|
|
for i in range(1, k_over_2):
|
|
w_array[i] = z[2 * i - 1] + z[2 * i]
|
|
|
|
# w aggregation
|
|
w = b.bit_compose(w_array)
|
|
|
|
# return computed values
|
|
#return m_odd, m, w
|
|
return m_odd, None, w
|
|
|
|
|
|
##
|
|
# Obtains the sqrt using our custom mechanism
|
|
# for any sfix input value.
|
|
# no restrictions on the size of f.
|
|
#
|
|
# @param x: secret shared input from which the sqrt
|
|
# is calucalted,
|
|
#
|
|
# @return g: approximated sqrt
|
|
def sqrt_simplified_fx(x):
|
|
# adapt parameters to fit the algorithm
|
|
f = x.f
|
|
k = x.k
|
|
my_f = max(f, k - f + 1)
|
|
shift = my_f - f
|
|
my_k = k + shift
|
|
assert my_k < 2 * my_f
|
|
x = type(x)._new(x.v << shift, f=my_f, k=my_k)
|
|
|
|
# fix theta (number of iterations)
|
|
theta = max(int(math.ceil(math.log(x.k))), 6)
|
|
|
|
# process to use 2^(m/2) approximation
|
|
m_odd, m, w = norm_simplified_SQ(x.v, x.k)
|
|
# process to set up the precision and allocate correct 2**f
|
|
if x.f % 2 == 1:
|
|
m_odd = (1 - 2 * m_odd) + m_odd
|
|
w = m_odd.if_else(w, 2 * w)
|
|
# map number to use sfix format and instantiate the number
|
|
w = x._new(w << ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f)
|
|
# obtains correct 2 ** (m/2)
|
|
w = (w * (2 ** (1/2.0)) - w) * m_odd + w
|
|
# produce x/ 2^(m/2)
|
|
y_0 = 1 / w
|
|
|
|
# from this point on it sufices to work sfix-wise
|
|
g_0 = (y_0 * x)
|
|
h_0 = y_0 * 0.5
|
|
gh_0 = g_0 * h_0
|
|
|
|
## initialization
|
|
g = g_0
|
|
h = h_0
|
|
gh = gh_0
|
|
|
|
for i in range(1, theta - 2):
|
|
r = (3 / 2.0) - gh
|
|
g = g * r
|
|
h = h * r
|
|
gh = g * h
|
|
|
|
# newton
|
|
r = (3 / 2.0) - gh
|
|
h = h * r
|
|
H = 4 * (h * h)
|
|
|
|
if not x.round_nearest or (2 * x.f < x.k - 1):
|
|
H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H)
|
|
|
|
H = H * x
|
|
H = (3) - H
|
|
H = h * H
|
|
g = H * x
|
|
g = g
|
|
|
|
return type(x)._new((g * 2 ** -shift).v, f=f, k=k)
|
|
|
|
|
|
##
|
|
# Calculates the normSQ of a number
|
|
# @param x: number from which the norm is going to be extracted
|
|
# @param k: bitsize of x
|
|
#
|
|
# @return c: where c = x*v where c is bounded by 2^{k-1} and 2^k
|
|
# @return v: where v = 2^k-m
|
|
# @return m: where m = MSB
|
|
# @return w: where w = 2^{m/2} if m is oeven and 2^{m-1 / 2} otherwise
|
|
def norm_SQ(b, k):
|
|
# calculation of z
|
|
# x in order 0 - k
|
|
z = MSB(b,k)
|
|
# now reverse bits of z[i] to generate v
|
|
v = types.sint(0)
|
|
for i in range(k):
|
|
v += (2**(k - i - 1)) * z[i]
|
|
c = b * v
|
|
|
|
# construct m
|
|
m = types.sint(0)
|
|
for i in range(k):
|
|
m = m + (i+1) * z[i]
|
|
|
|
# construct w, changes from what is on the paper
|
|
# and the documentation
|
|
k_over_2= k/2+1#int(math.ceil((k/2.0)))+1
|
|
w_array = [0]*(k_over_2 )
|
|
w_array[0] = z[0]
|
|
for i in range(1, k_over_2):
|
|
w_array[i] = z[2 * i - 1] + z[2 * i]
|
|
|
|
w = types.sint(0)
|
|
for i in range(k_over_2):
|
|
w += (2 ** i) * w_array[i]
|
|
|
|
# return computed values
|
|
return c, v, m, w
|
|
|
|
|
|
##
|
|
# Given f and k, returns a linear approximation of 1/x^{1/2}
|
|
# escalated by s^f.
|
|
# Method only works for sfix inputs. It uses the normSQ.
|
|
# the method is an implementation of [Liedel2012]
|
|
# @param x: number from which the approximation is caluclated
|
|
# @param k: bitsize of x
|
|
# @param f: precision of the input f
|
|
#
|
|
# @return c: Some approximation of (1/x^{1/2} * 2^f) *K
|
|
# where K is close to 1
|
|
def lin_app_SQ(b, k, f):
|
|
|
|
alpha = types.cfix((-0.8099868542) * 2 ** (k))
|
|
beta = types.cfix(1.787727479 * 2 ** (2 * k))
|
|
|
|
# obtain normSQ parameters
|
|
c, v, m, W = norm_SQ(types.sint(b), k)
|
|
|
|
# c is now escalated
|
|
w = alpha * c + beta # equation before b and reduction by order of k
|
|
|
|
|
|
# m even or odd determination
|
|
m_bit = types.sint()
|
|
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), signed=False)
|
|
m = m_bit
|
|
|
|
# w times v this way both terms have 2^3k and can be symplified
|
|
w = w * v
|
|
factor = 1.0 / (2 ** (3.0 * k - 2 * f))
|
|
w = w * factor # w escalated to 3k -2 * f
|
|
# normalization factor W* 1/2 ^{f/2}
|
|
w = w * W * types.cfix(1.0 / (2 ** (f / 2.0)))
|
|
# now we need to elminate an additional root of 2 in case m was odd
|
|
sqr_2 = types.cfix((2 ** (1 / 2.0)))
|
|
w = (1 - m) * w + sqr_2 * w * m
|
|
|
|
return w
|
|
|
|
|
|
##
|
|
# Given bitsize k and precision f, it calulates the square root of x.
|
|
# @param x: number from which the norm is going to be extracted
|
|
# @param k: bitsize of x.
|
|
# @param f: precision of x.
|
|
#
|
|
# @return g: square root of de-scaled input x
|
|
def sqrt_fx(x_l, k, f):
|
|
factor = 1.0 / (2.0 ** f)
|
|
|
|
x = x_l * factor
|
|
|
|
theta = int(math.ceil(math.log(k/5.4)))
|
|
|
|
y_0 = lin_app_SQ(x_l,k,f) #cfix(1.0/ (cx ** (1/2.0))) # lin_app_SQ(x_l,5,2)
|
|
|
|
y_0 = y_0 * factor #*((1.0/(2.0 ** f)))
|
|
g_0 = y_0 * x
|
|
|
|
|
|
#g = mpc_math.load_sint(mpc_math.trunc(g_0),types.sfix)
|
|
h_0 = y_0 *(0.5)
|
|
gh_0 = g_0 * h_0
|
|
|
|
##initialization
|
|
g= g_0
|
|
h= h_0
|
|
gh =gh_0
|
|
|
|
for i in range(1,theta-2): #to implement \in [1,\theta-2]
|
|
r = (3/2.0) - gh
|
|
g = g * r
|
|
h = h * r
|
|
gh = g * h
|
|
|
|
# newton
|
|
r = (3/2.0) - gh
|
|
h = h * r
|
|
H = 4 * (h * h)
|
|
H = H * x
|
|
H = (3) - H
|
|
H = h * H
|
|
g = H * x
|
|
g = g #* (0.5)
|
|
|
|
return g
|
|
|
|
|
|
@types.vectorize
|
|
@instructions_base.sfix_cisc
|
|
def sqrt(x, k=None, f=None):
|
|
"""
|
|
Square root.
|
|
|
|
:param x: fractional input (sfix).
|
|
|
|
:return: square root of :py:obj:`x` (sfix).
|
|
"""
|
|
library.get_program().reading('square root', 'AS19', 'Section 3')
|
|
if k is None:
|
|
k = x.k
|
|
if f is None:
|
|
f = x.f
|
|
if (3 *k -2 * f >= f):
|
|
return sqrt_simplified_fx(x)
|
|
# raise OverflowError("bound for precision violated: 3 * k - 2 * f < x.f ")
|
|
else:
|
|
param = trunc(x *(2 ** (f)))
|
|
return sqrt_fx(param ,k ,f)
|
|
|
|
|
|
@instructions_base.sfix_cisc
|
|
def atan(x):
|
|
"""
|
|
Returns the arctangent (sfix) of any given fractional value.
|
|
|
|
:param x: fractional input (sfix).
|
|
|
|
:return: arctan of :py:obj:`x` (sfix).
|
|
"""
|
|
library.get_program().reading('inverse trigonometric functions', 'AS19',
|
|
'Protocol 5')
|
|
# obtain absolute value of x
|
|
s = x < 0
|
|
x_abs = s.if_else(-x, x)
|
|
# angle isolation
|
|
b = x_abs > 1
|
|
v = 1 / x_abs
|
|
v = b.if_else(v, x_abs)
|
|
v_2 =v*v
|
|
|
|
# range of polynomial coefficients
|
|
m = max(sum(p_5102), sum(q_5102))
|
|
scale = m / (2 ** (x.k - x.f - 1) - 1)
|
|
P = p_eval([c / scale for c in p_5102], v_2)
|
|
Q = p_eval([c / scale for c in q_5102], v_2)
|
|
|
|
# padding
|
|
y = v * (P / Q)
|
|
y_pi_over_two = pi_over_2 - y
|
|
|
|
# sign correction
|
|
y = b.if_else(y_pi_over_two, y)
|
|
y = s.if_else(-y, y)
|
|
|
|
return y
|
|
|
|
|
|
def asin(x):
|
|
r"""
|
|
Returns the arcsine (sfix) of any given fractional value.
|
|
|
|
:param x: fractional input (sfix). valid interval is :math:`-1 \le x \le 1`
|
|
|
|
:return: arcsin of :py:obj:`x` (sfix).
|
|
"""
|
|
# Square x
|
|
x_2 = x*x
|
|
# trignometric identities
|
|
sqrt_l = sqrt(1- (x_2))
|
|
x_sqrt_l =x / sqrt_l
|
|
return atan(x_sqrt_l)
|
|
|
|
|
|
def acos(x):
|
|
r"""
|
|
Returns the arccosine (sfix) of any given fractional value.
|
|
|
|
:param x: fractional input (sfix). :math:`-1 \le x \le 1`
|
|
|
|
:return: arccos of :py:obj:`x` (sfix).
|
|
"""
|
|
y = asin(x)
|
|
return pi_over_2 - y
|
|
|
|
|
|
def tanh(x):
|
|
r"""
|
|
Hyperbolic tangent. For efficiency, accuracy is diminished
|
|
around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and
|
|
:math:`f` denote the fixed-point parameters.
|
|
"""
|
|
limit = math.log(2 ** (x.k - x.f - 2)) / 2
|
|
s = x < -limit
|
|
t = x > limit
|
|
y = pow_fx(math.e, 2 * x)
|
|
return s.if_else(-1, t.if_else(1, (y - 1) / (y + 1)))
|
|
|
|
|
|
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
|
|
|
|
def Sep(x, sfix=types.sfix):
|
|
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
|
|
bb = b[:]
|
|
while len(bb) < 2 * x.f - 1:
|
|
bb.insert(0, type(b[0])(types.cint(0)))
|
|
t = x.v * (1 + x.v.bit_compose(b_i.bit_not()
|
|
for b_i in bb[-2 * x.f + 1:]))
|
|
u = sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
|
|
b += [b[0].long_one()]
|
|
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
|
|
|
|
def SqrtComp(z, old=False, sfix=types.sfix):
|
|
f = sfix.f
|
|
k = len(z)
|
|
if isinstance(z[0], types.sint):
|
|
return sfix._new(sum(z[i] * types.cfix(
|
|
2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k)))
|
|
k_prime = k // 2
|
|
f_prime = f // 2
|
|
c1 = sfix(2 ** ((f + 1) / 2 + 1))
|
|
c0 = sfix(2 ** (f / 2 + 1))
|
|
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
|
|
tmp = sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
|
|
if old:
|
|
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
|
|
else:
|
|
b = util.tree_reduce(lambda x, y: x.bit_xor(y), z[::2])
|
|
return types.sint.conv(b).if_else(c1, c0) * tmp
|
|
|
|
@types.vectorize
|
|
@instructions_base.sfix_cisc
|
|
def InvertSqrt(x, old=False):
|
|
"""
|
|
Reciprocal square root approximation by `Lu et al.
|
|
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
|
|
"""
|
|
library.get_program().reading('approximate inverse square root', 'LFHH+20')
|
|
class my_sfix(types.sfix):
|
|
f = x.f
|
|
k = x.k
|
|
u, z = Sep(x, sfix=my_sfix)
|
|
c = 3.14736 + u * (4.63887 * u - 5.77789)
|
|
return c * SqrtComp(z, old=old, sfix=my_sfix)
|