Files
MP-SPDZ/Compiler/non_linear.py
2025-12-24 13:47:42 +11:00

195 lines
5.9 KiB
Python

from .comparison import *
from .floatingpoint import *
from .types import *
from . import comparison, program
class NonLinear:
def mod2m(self, a, k, m, signed):
"""
a_prime = a % 2^m
k: bit length of a
m: compile-time integer
signed: True/False, describes a
"""
if not util.is_constant(m):
raise CompilerError('m must be a public constant')
if m >= k:
return a
else:
return self._mod2m(a, k, m, signed)
def trunc_pr(self, a, k, m, signed=True):
if isinstance(a, types.cint):
return shift_two(a, m)
prog = program.Program.prog
if prog.use_trunc_pr and m and (
not prog.options.ring or \
prog.use_trunc_pr <= (int(prog.options.ring) - k)):
prog.reading('probabilistic truncation', 'DEK20', 'Section 3.2.2')
if prog.options.ring:
comparison.require_ring_size(k, 'truncation')
else:
prog.curr_tape.require_bit_length(k + prog.security)
if not signed:
a -= (1 << (k - 1))
res = sint()
trunc_pr(res, a, k, m)
if not signed:
res += (1 << (k - m - 1))
return res
return self._trunc_pr(a, k, m, signed)
def trunc_round_nearest(self, a, k, m, signed):
res = sint()
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, signed)
return res
def trunc(self, a, k, m, signed):
if m == 0:
return a
return self._trunc(a, k, m, signed)
def ltz(self, a, k):
return -self.trunc(a, k, k - 1, True)
class Masking(NonLinear):
def eqz(self, a, k):
c, r = self._mask(a, k)
d = [None]*k
for i,b in enumerate(r[0].bit_decompose_clear(c, k)):
d[i] = r[i].bit_xor(b)
return 1 - types.sintbit.conv(self.kor(d))
class Prime(Masking):
""" Non-linear functionality modulo a prime with statistical masking. """
def _mod2m(self, a, k, m, signed):
res = sint()
if m == 1:
Mod2(res, a, k, signed)
else:
Mod2mField(res, a, k, m, signed)
return res
def _mask(self, a, k):
return maskField(a, k)
def _trunc_pr(self, a, k, m, signed=None):
return TruncPrField(a, k, m)
def _trunc(self, a, k, m, signed=None):
a_prime = self.mod2m(a, k, m, signed)
tmp = cint()
inv2m(tmp, m)
return (a - a_prime) * tmp
def bit_dec(self, a, k, m, maybe_mixed=False):
if maybe_mixed:
return BitDecFieldRaw(a, k, m)
else:
return BitDecField(a, k, m)
def kor(self, d):
return KOR(d)
def require_bit_length(self, bit_length, op, slack=0):
prog = program.Program.prog
bit_length += slack * (not prog.allow_tight_parameters)
if bit_length > 32:
prog.curr_tape.require_bit_length(bit_length - 1, reason=op)
class KnownPrime(NonLinear):
""" Non-linear functionality modulo a prime known at compile time. """
def __init__(self, prime):
self.prime = prime
def _mod2m(self, a, k, m, signed):
if signed:
a += cint(1) << (k - 1)
prog = program.Program.prog
return sint.bit_compose(self.bit_dec(a, k, m, prog.use_edabit()))
def _trunc_pr(self, a, k, m, signed):
# nearest truncation
return self.trunc_round_nearest(a, k, m, signed)
def _trunc(self, a, k, m, signed=None):
return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed)
def trunc_round_nearest(self, a, k, m, signed):
a += cint(1) << (m - 1)
if signed:
a += cint(1) << (k - 1)
k += 1
res = self._trunc(a, k, m, False)
if signed:
res -= cint(1) << (k - m - 2)
return res
def bit_dec(self, a, k, m, maybe_mixed=False):
assert k <= self.prime.bit_length()
bits = BitDecFull(a, m, maybe_mixed=maybe_mixed)
assert len(bits) == m
return bits
def eqz(self, a, k):
# always signed
a += two_power(k)
prog = program.Program.prog
return 1 - types.sintbit.conv(KORL(
self.bit_dec(a, k, k, prog.use_edabit())))
def ltz(self, a, k):
if k + 1 < self.prime.bit_length():
# https://dl.acm.org/doi/10.1145/3474123.3486757
# "negative" values wrap around when doubling, thus becoming odd
return self.mod2m(2 * a, k + 1, 1, False)
else:
return super(KnownPrime, self).ltz(a, k)
def require_bit_length(self, *args, **kwargs):
pass
class Ring(Masking):
""" Non-linear functionality modulo a power of two known at compile time.
"""
def __init__(self, ring_size):
self.ring_size = ring_size
def _mod2m(self, a, k, m, signed):
res = sint()
Mod2mRing(res, a, k, m, signed)
return res
def _mask(self, a, k):
return maskRing(a, k)
def _trunc_pr(self, a, k, m, signed):
return TruncPrRing(a, k, m, signed=signed)
def _trunc(self, a, k, m, signed=None):
return comparison.TruncRing(None, a, k, m, signed=signed)
def bit_dec(self, a, k, m, maybe_mixed=False):
if maybe_mixed:
return BitDecRingRaw(a, k, m)
else:
return BitDecRing(a, k, m)
def kor(self, d):
return KORL(d)
def trunc_round_nearest(self, a, k, m, signed):
if k == self.ring_size:
# cannot work with bit length k+1
tmp = TruncRing(None, a, k, m - 1, signed)
return TruncRing(None, tmp + 1, k - m + 1, 1, signed)
else:
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
def ltz(self, a, k):
return LtzRing(a, k)
def require_bit_length(self, *args, **kwargs):
comparison.require_ring_size(*args, **kwargs)