# Files: MP-SPDZ/Compiler/non_linear.py
# Timestamp: 2021-07-02 15:50:34 +10:00
# 152 lines, 4.3 KiB, Python

from .comparison import *
from .floatingpoint import *
from .types import *
from . import comparison
class NonLinear:
    """ Base class for non-linear functionality: reduction modulo a power
    of two, truncation, and truncation with rounding to nearest.

    Subclasses supply ``_mod2m`` and ``_trunc_pr`` and may override the
    security hooks, which are no-ops here. The security parameter
    ``kappa`` defaults to ``None`` at class level. """

    kappa = None

    def set_security(self, kappa):
        """ Record the security parameter (no-op in the base class). """
        pass

    def check_security(self, kappa):
        """ Validate the security parameter (no-op in the base class). """
        pass

    def mod2m(self, a, k, m, signed):
        """ Reduce ``a`` modulo ``2^m``.

        :param a: value of bit length ``k``
        :param k: bit length of ``a``
        :param m: compile-time integer exponent
        :param signed: True/False, whether ``a`` is signed
        :returns: ``a`` unchanged when ``m >= k``, otherwise the result
            of the subclass's ``_mod2m``
        :raises CompilerError: if ``m`` is not a public constant
        """
        if not util.is_constant(m):
            raise CompilerError('m must be a public constant')
        if m >= k:
            # a has at most m bits already, so there is nothing to reduce
            return a
        return self._mod2m(a, k, m, signed)

    def trunc_pr(self, a, k, m, signed=True):
        """ Truncate ``a`` (bit length ``k``) by ``m`` bits.

        Clear values (``cint``) are shifted directly with ``shift_two``.
        Anything else goes to the subclass's ``_trunc_pr`` (the rounding
        behaviour is subclass-specific; "pr" presumably means
        probabilistic rounding). """
        if not isinstance(a, types.cint):
            return self._trunc_pr(a, k, m, signed)
        return shift_two(a, m)

    def trunc_round_nearest(self, a, k, m, signed):
        """ Truncate ``a`` by ``m`` bits, rounding to nearest.

        ``2^(m-1)`` is added before truncating. The bit length handed to
        ``comparison.Trunc`` is ``k + 1`` because that addition can carry
        one bit past ``k``. """
        # allocate the output first, keeping the original register order
        result = sint()
        half = 1 << (m - 1)
        comparison.Trunc(result, a + half, k + 1, m, self.kappa, signed)
        return result
class Masking(NonLinear):
    """ Intermediate base class implementing ``eqz`` through masking.

    Subclasses must provide ``_mask(a, k)``, which returns a clear value
    and a list of secret random bits, and ``kor(bits)``, which ORs a list
    of secret bits. """
    def eqz(self, a, k):
        """ Return a secret bit that is 1 when ``a`` (bit length ``k``)
        is zero, as the name indicates.

        The low ``k`` bits of the clear value from ``_mask`` are XORed
        with the random bits, the results are ORed with ``kor``, and the
        negation is returned. This assumes ``_mask`` masks ``a`` with
        those random bits so that all bits agree exactly when ``a`` is
        zero. NOTE(review): confirm against the ``_mask``
        implementations. """
        masked, rand_bits = self._mask(a, k)
        clear_bits = rand_bits[0].bit_decompose_clear(masked, k)
        # pre-sized to k entries, filled by position
        diff = [None] * k
        for idx, bit in enumerate(clear_bits):
            diff[idx] = rand_bits[idx].bit_xor(bit)
        any_diff = self.kor(diff)
        return 1 - types.sintbit.conv(any_diff)
class Prime(Masking):
    """ Non-linear functionality modulo a prime with statistical masking.

    The security parameter ``kappa`` is stored on the instance and
    passed to every field protocol used below. """
    def __init__(self, kappa):
        self.set_security(kappa)

    def set_security(self, kappa):
        """ Store the security parameter used by the field protocols. """
        self.kappa = kappa

    def check_security(self, kappa):
        """ Assert that ``kappa`` equals the stored parameter; ``None``
        is always accepted. """
        assert kappa is None or self.kappa == kappa

    def _mod2m(self, a, k, m, signed):
        """ ``a`` modulo ``2^m`` via the field protocols. """
        result = sint()
        if m != 1:
            Mod2mField(result, a, k, m, self.kappa, signed)
        else:
            # m == 1 has its own protocol (no m argument)
            Mod2(result, a, k, self.kappa, signed)
        return result

    def _mask(self, a, k):
        """ Masked clear value and random bits for ``eqz``. """
        return maskField(a, k, self.kappa)

    def _trunc_pr(self, a, k, m, signed=None):
        """ Truncation by ``m`` bits; ``signed`` is accepted but unused. """
        return TruncPrField(a, k, m, self.kappa)

    def bit_dec(self, a, k, m, maybe_mixed=False):
        """ Bit decomposition of ``a`` (bit length ``k``) into ``m`` bits;
        the raw variant is used when ``maybe_mixed`` is truthy. """
        decompose = BitDecFieldRaw if maybe_mixed else BitDecField
        return decompose(a, k, m, self.kappa)

    def kor(self, d):
        """ OR of the secret bits in ``d``. """
        return KOR(d, self.kappa)
class KnownPrime(NonLinear):
    """ Non-linear functionality modulo a prime known at compile time.

    All operations go through full bit decomposition (``BitDecFull``)
    followed by recomposition of the relevant bits. """
    def __init__(self, prime):
        self.prime = prime

    def _mod2m(self, a, k, m, signed):
        """ ``a`` modulo ``2^m`` by recomposing the low ``m`` bits. """
        if signed:
            # Shift the signed range up by 2^(k-1). This leaves the low m
            # bits unchanged because mod2m only delegates here with m < k.
            a += cint(1) << (k - 1)
        return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])

    def _trunc_pr(self, a, k, m, signed):
        # nearest truncation
        return self.trunc_round_nearest(a, k, m, signed)

    def trunc_round_nearest(self, a, k, m, signed):
        """ Truncate ``a`` (bit length ``k``) by ``m`` bits, rounding to
        nearest.

        ``2^(m-1)`` is added for rounding. Signed inputs are additionally
        offset by ``2^(k-1)`` to make them non-negative, and that offset
        is removed again after the shift. """
        a += cint(1) << (m - 1)
        if signed:
            a += cint(1) << (k - 1)
        # The additions above can carry into bit k in both the signed and
        # the unsigned case, so decompose k + 1 bits (as the base class
        # does by passing k + 1 to Trunc).
        n_bits = k + 1
        res = sint.bit_compose(self.bit_dec(a, n_bits, n_bits, True)[m:])
        if signed:
            # undo the 2^(k-1) offset, which became 2^(k-1-m) after the shift
            res -= cint(1) << (k - m - 1)
        return res

    def bit_dec(self, a, k, m, maybe_mixed=False):
        """ Return the lowest ``m`` bits of ``a`` (bit length ``k``).

        :raises CompilerError: if ``k`` is not smaller than the bit length
            of the prime, or the decomposition yields fewer than ``m``
            bits
        """
        # explicit check rather than assert, so it survives python -O
        if k >= self.prime.bit_length():
            raise CompilerError(
                'bit length %d too large for %d-bit prime %d'
                % (k, self.prime.bit_length(), self.prime))
        bits = BitDecFull(a, maybe_mixed=maybe_mixed)
        if len(bits) < m:
            raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
        return bits[:m]

    def eqz(self, a, k):
        """ Secret bit that is 1 when ``a`` is zero; ``a`` is treated as
        a signed ``k``-bit value.

        Adding ``2^k`` leaves the low ``k`` bits equal to ``a`` mod
        ``2^k``, which is zero only for ``a == 0`` in the signed k-bit
        range. Presumably this also gives ``BitDecFull`` a non-negative
        representative. """
        # always signed
        a += two_power(k)
        return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
class Ring(Masking):
    """ Non-linear functionality modulo a power of two known at compile
    time (presumably ``2^ring_size``). """
    def __init__(self, ring_size):
        self.ring_size = ring_size

    def _mod2m(self, a, k, m, signed):
        """ ``a`` modulo ``2^m`` via ``Mod2mRing``. """
        result = sint()
        Mod2mRing(result, a, k, m, signed)
        return result

    def _mask(self, a, k):
        """ Masked clear value and random bits for ``eqz``. """
        return maskRing(a, k)

    def _trunc_pr(self, a, k, m, signed):
        """ Truncation by ``m`` bits via ``TruncPrRing``. """
        return TruncPrRing(a, k, m, signed=signed)

    def bit_dec(self, a, k, m, maybe_mixed=False):
        """ Bit decomposition; the raw variant is used when
        ``maybe_mixed`` is truthy. """
        decompose = BitDecRingRaw if maybe_mixed else BitDecRing
        return decompose(a, k, m)

    def kor(self, d):
        """ OR of the secret bits in ``d`` via ``KORL``. """
        return KORL(d)

    def trunc_round_nearest(self, a, k, m, signed):
        """ Truncate ``a`` by ``m`` bits, rounding to nearest. """
        if k != self.ring_size:
            return super(Ring, self).trunc_round_nearest(a, k, m, signed)
        # The generic version needs bit length k + 1, which does not fit
        # the ring. Instead truncate by m - 1 bits, add one, and drop the
        # final bit with a second truncation.
        partial = TruncRing(None, a, k, m - 1, signed)
        return TruncRing(None, partial + 1, k - m + 1, 1, signed)