mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Maintenance.
This commit is contained in:
@@ -15,7 +15,7 @@ int AndJob::run()
|
||||
#endif
|
||||
__m128i* prf_output = new __m128i[PAD_TO_8(ProgramParty::s().get_n_parties())];
|
||||
auto gate = gates.begin();
|
||||
vector< GC::Secret<EvalRegister> >& S = *this->S;
|
||||
auto& S = *this->S;
|
||||
const vector<int>& args = *this->args;
|
||||
int i_gate = 0;
|
||||
for (size_t i = start; i < end; i += 4)
|
||||
|
||||
@@ -15,7 +15,7 @@ using namespace std;
|
||||
|
||||
class AndJob
|
||||
{
|
||||
vector< GC::Secret<EvalRegister> >* S;
|
||||
StackedVector< GC::Secret<EvalRegister> >* S;
|
||||
const vector<int>* args;
|
||||
|
||||
public:
|
||||
@@ -25,7 +25,7 @@ public:
|
||||
|
||||
AndJob() : S(0), args(0), start(0), end(0), gate_id(0) {}
|
||||
|
||||
void reset(vector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
|
||||
void reset(StackedVector<GC::Secret<EvalRegister> >& S, const vector<int>& args,
|
||||
size_t start, gate_id_t gate_id, size_t n_gates, int n_parties)
|
||||
{
|
||||
this->S = &S;
|
||||
|
||||
17
CHANGELOG.md
17
CHANGELOG.md
@@ -1,5 +1,22 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.3.9 (July 9, 2024)
|
||||
|
||||
- Inference with non-sequential PyTorch networks
|
||||
- SHA-3 for any input length (@hiddely)
|
||||
- Improved client facilities
|
||||
- Shuffling with malicious security for SPDZ-wise protocols by [Asharov et al.](https://ia.cr/2022/1595)
|
||||
- More reusable bytecode via in-thread calling facility
|
||||
- Recursive functions without return values
|
||||
- Fewer rounds for parallel matrix multiplications (@vincent-ehrmanntraut)
|
||||
- Optimized usage of SoftSpokenOT in semi-honest protocols
|
||||
- More integrity checks on storage in MAC-based protocols
|
||||
- Use C++17
|
||||
- Use glibc 2.18 for the binaries
|
||||
- Fixed security bugs: remotely caused buffer overflows (#1382)
|
||||
- Fixed security bug: Missing randomization before revealing to client
|
||||
- Fixed security bug: Bias in Rep3 secure shuffling
|
||||
|
||||
## 0.3.8 (December 14, 2023)
|
||||
|
||||
- Functionality for multiple nodes per party
|
||||
|
||||
2
CONFIG
2
CONFIG
@@ -106,7 +106,7 @@ else
|
||||
BOOST = -lboost_thread $(MY_BOOST)
|
||||
endif
|
||||
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++17 -Werror
|
||||
CFLAGS += $(BREW_CFLAGS)
|
||||
CPPFLAGS = $(CFLAGS)
|
||||
LD = $(CXX)
|
||||
|
||||
@@ -203,8 +203,10 @@ class andrsvec(base.VarArgsInstruction, base.Mergeable,
|
||||
def add_usage(self, req_node):
|
||||
for i, n in self.bases(iter(self.args)):
|
||||
size = self.args[i + 1]
|
||||
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
|
||||
req_node.increment(('bit', 'mixed'), size)
|
||||
n = (n - 3) // 2
|
||||
req_node.increment(('bit', 'triple'), size * n)
|
||||
if n > 1:
|
||||
req_node.increment(('bit', 'mixed'), size * ((n + 63) // 64))
|
||||
|
||||
def copy(self, size, subs):
|
||||
return type(self)(*self.get_new_args(size, subs))
|
||||
|
||||
@@ -13,7 +13,7 @@ from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint,
|
||||
from Compiler.types import vectorized_classmethod
|
||||
from Compiler.program import Tape, Program
|
||||
from Compiler.exceptions import *
|
||||
from Compiler import util, oram, floatingpoint, library
|
||||
from Compiler import util, oram, floatingpoint, library, comparison
|
||||
from Compiler import instructions_base
|
||||
import Compiler.GC.instructions as inst
|
||||
import operator
|
||||
@@ -21,6 +21,11 @@ import math
|
||||
import itertools
|
||||
from functools import reduce
|
||||
|
||||
class _binary:
|
||||
def reveal_to(self, *args, **kwargs):
|
||||
raise CompilerError(
|
||||
'%s does not support revealing to indivual players' % type(self))
|
||||
|
||||
class bits(Tape.Register, _structure, _bit):
|
||||
n = 40
|
||||
unit = 64
|
||||
@@ -149,6 +154,12 @@ class bits(Tape.Register, _structure, _bit):
|
||||
self.n = n
|
||||
def set_size(self, size):
|
||||
pass
|
||||
def load_int(self, value):
|
||||
n_limbs = math.ceil(self.n / self.unit)
|
||||
for i in range(n_limbs):
|
||||
self.conv_regint(min(self.unit, self.n - i * self.unit),
|
||||
self[i], regint(value % 2 ** self.unit))
|
||||
value >>= self.unit
|
||||
def load_other(self, other):
|
||||
if isinstance(other, cint):
|
||||
assert(self.n == other.size)
|
||||
@@ -236,12 +247,14 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return res
|
||||
def if_else(self, x, y):
|
||||
"""
|
||||
Vectorized oblivious selection::
|
||||
Bit-wise oblivious selection::
|
||||
|
||||
sb32 = sbits.get_type(32)
|
||||
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
|
||||
|
||||
This will output 1.
|
||||
This will output 1 because it selects the two least
|
||||
significant bits from 5 and the rest of the bits from 2.
|
||||
|
||||
"""
|
||||
return result_conv(x, y)(self & (x ^ y) ^ y)
|
||||
def zero_if_not(self, condition):
|
||||
@@ -268,6 +281,9 @@ class bits(Tape.Register, _structure, _bit):
|
||||
self.bit_compose(source.bit_decompose()[base:base + size]))
|
||||
def vector_size(self):
|
||||
return self.n
|
||||
@staticmethod
|
||||
def size_for_mem():
|
||||
return 1
|
||||
|
||||
class cbits(bits):
|
||||
""" Clear bits register. Helper type with limited functionality. """
|
||||
@@ -302,13 +318,6 @@ class cbits(bits):
|
||||
else:
|
||||
return super(cbits, cls).conv(other)
|
||||
types = {}
|
||||
def load_int(self, value):
|
||||
n_limbs = math.ceil(self.n / self.unit)
|
||||
tmp = regint(size=n_limbs)
|
||||
for i in range(n_limbs):
|
||||
tmp[i].load_int(value % 2 ** self.unit)
|
||||
value >>= self.unit
|
||||
self.load_other(tmp)
|
||||
def store_in_dynamic_mem(self, address):
|
||||
inst.stmsdci(self, cbits.conv(address))
|
||||
def clear_op(self, other, c_inst, ci_inst, op):
|
||||
@@ -502,11 +511,7 @@ class sbits(bits):
|
||||
if self.n <= 32:
|
||||
inst.ldbits(self, self.n, value)
|
||||
else:
|
||||
size = math.ceil(self.n / self.unit)
|
||||
tmp = regint(size=size)
|
||||
for i in range(size):
|
||||
tmp[i].load_int((value >> (i * 64)) % 2**64)
|
||||
self.load_other(tmp)
|
||||
bits.load_int(self, value)
|
||||
def load_other(self, other):
|
||||
if isinstance(other, cbits) and self.n == other.n:
|
||||
inst.convcbit2s(self.n, self, other)
|
||||
@@ -675,7 +680,7 @@ class sbits(bits):
|
||||
def ripple_carry_adder(*args, **kwargs):
|
||||
return sbitint.ripple_carry_adder(*args, **kwargs)
|
||||
|
||||
class sbitvec(_vec, _bit):
|
||||
class sbitvec(_vec, _bit, _binary):
|
||||
""" Vector of registers of secret bits, effectively a matrix of secret bits.
|
||||
This facilitates parallel arithmetic operations in binary circuits.
|
||||
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
|
||||
@@ -732,15 +737,16 @@ class sbitvec(_vec, _bit):
|
||||
:py:obj:`v` and the columns by calling :py:obj:`elements`.
|
||||
"""
|
||||
class sbitvecn(cls, _structure):
|
||||
@staticmethod
|
||||
def malloc(size, creator_tape=None):
|
||||
return sbit.malloc(size * n, creator_tape=creator_tape)
|
||||
@classmethod
|
||||
def malloc(cls, size, creator_tape=None):
|
||||
return sbit.malloc(
|
||||
size * cls.mem_size(), creator_tape=creator_tape)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return 1
|
||||
@staticmethod
|
||||
def mem_size():
|
||||
return n
|
||||
return sbits.get_type(n).mem_size()
|
||||
@classmethod
|
||||
def get_input_from(cls, player, size=1, f=0):
|
||||
""" Secret input from :py:obj:`player`. The input is decomposed
|
||||
@@ -780,38 +786,28 @@ class sbitvec(_vec, _bit):
|
||||
self.v = sbits.get_type(n)(other).bit_decompose()
|
||||
assert len(self.v) == n
|
||||
assert size is None or size == self.v[0].n
|
||||
@vectorized_classmethod
|
||||
def load_mem(cls, address):
|
||||
size = instructions_base.get_global_vector_size()
|
||||
if size not in (None, 1):
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
sb = sbits.get_type(size)
|
||||
return cls.from_vec(sb.bit_compose(
|
||||
sbit.load_mem(address + i + j * n) for j in range(size))
|
||||
for i in range(n))
|
||||
if not isinstance(address, int):
|
||||
v = [sbit.load_mem(x, size=n).v[0] for x in address]
|
||||
return cls(v)
|
||||
@classmethod
|
||||
def load_mem(cls, address, size=None):
|
||||
if isinstance(address, int) or len(address) == 1:
|
||||
address = [address + i for i in range(size or 1)]
|
||||
else:
|
||||
return cls.from_vec(sbit.load_mem(address + i)
|
||||
for i in range(n))
|
||||
assert size == None
|
||||
return cls(
|
||||
[sbits.get_type(n).load_mem(x) for x in address])
|
||||
def store_in_mem(self, address):
|
||||
size = 1
|
||||
for x in self.v:
|
||||
if not util.is_constant(x):
|
||||
size = max(size, x.n)
|
||||
v = [sbits.get_type(size).conv(x) for x in self.v]
|
||||
if not isinstance(address, int) and len(address) != 1:
|
||||
v = self.elements()
|
||||
assert len(v) == len(address)
|
||||
for x, y in zip(v, address):
|
||||
for i, xx in enumerate(x.bit_decompose(n)):
|
||||
xx.store_in_mem(y + i)
|
||||
if isinstance(address, int):
|
||||
address = range(address, address + size)
|
||||
elif len(address) == 1:
|
||||
address = [address + i * self.mem_size()
|
||||
for i in range(size)]
|
||||
else:
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
for i in range(n):
|
||||
for j, x in enumerate(v[i].bit_decompose()):
|
||||
x.store_in_mem(address + i + j * n)
|
||||
assert size == len(address)
|
||||
for x, dest in zip(self.elements(), address):
|
||||
x.store_in_mem(dest)
|
||||
@classmethod
|
||||
def two_power(cls, nn, size=1):
|
||||
return cls.from_vec(
|
||||
@@ -864,7 +860,7 @@ class sbitvec(_vec, _bit):
|
||||
assert isinstance(elements, sint)
|
||||
if Program.prog.use_split():
|
||||
x = elements.split_to_two_summands(length)
|
||||
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
|
||||
v = sbitint.bit_adder(x[0], x[1])
|
||||
else:
|
||||
prog = Program.prog
|
||||
if not prog.options.ring:
|
||||
@@ -877,6 +873,7 @@ class sbitvec(_vec, _bit):
|
||||
length, prog.security)
|
||||
prog.use_edabit(backup)
|
||||
return
|
||||
comparison.require_ring_size(length, 'A2B conversion')
|
||||
l = int(Program.prog.options.ring)
|
||||
r, r_bits = sint.get_edabit(length, size=elements.size)
|
||||
c = ((elements - r) << (l - length)).reveal()
|
||||
@@ -885,6 +882,8 @@ class sbitvec(_vec, _bit):
|
||||
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
|
||||
v = x.v
|
||||
self.v = v[:length]
|
||||
elif isinstance(elements, sbitvec):
|
||||
self.v = elements.v
|
||||
elif elements is not None and not (util.is_constant(elements) and \
|
||||
elements == 0):
|
||||
self.v = sbits.trans(elements)
|
||||
@@ -1347,13 +1346,19 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
def __add__(self, other):
|
||||
if util.is_zero(other):
|
||||
return self
|
||||
a, b = self.expand(other)
|
||||
try:
|
||||
a, b = self.expand(other)
|
||||
except:
|
||||
return NotImplemented
|
||||
v = sbitint.bit_adder(a, b)
|
||||
return self.get_type(len(v)).from_vec(v)
|
||||
__radd__ = __add__
|
||||
__sub__ = _bitint.__sub__
|
||||
def __rsub__(self, other):
|
||||
a, b = self.expand(other)
|
||||
try:
|
||||
a, b = self.expand(other)
|
||||
except:
|
||||
return NotImplemented
|
||||
return self.from_vec(b) - self.from_vec(a)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, sbits):
|
||||
@@ -1447,7 +1452,7 @@ class cbitfix(object):
|
||||
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
|
||||
cbits(0), cbits(0))
|
||||
|
||||
class sbitfix(_fix):
|
||||
class sbitfix(_fix, _binary):
|
||||
""" Secret signed fixed-point number in one binary register.
|
||||
Use :py:obj:`set_precision()` to change the precision.
|
||||
|
||||
@@ -1515,7 +1520,7 @@ class sbitfix(_fix):
|
||||
cls.set_precision(f, k)
|
||||
return cls._new(cls.int_type(other), k, f)
|
||||
|
||||
class sbitfixvec(_fix, _vec):
|
||||
class sbitfixvec(_fix, _vec, _binary):
|
||||
""" Vector of fixed-point numbers for parallel binary computation.
|
||||
|
||||
Use :py:obj:`set_precision()` to change the precision.
|
||||
|
||||
@@ -76,7 +76,7 @@ class AllocRange:
|
||||
self.top += size
|
||||
self.limit = max(self.limit, self.top)
|
||||
if res >= REG_MAX:
|
||||
raise RegisterOverflowError()
|
||||
raise RegisterOverflowError(size)
|
||||
return res
|
||||
|
||||
def free(self, base, size):
|
||||
@@ -209,7 +209,8 @@ class StraightlineAllocator:
|
||||
for x in itertools.chain(dup.duplicates, base.duplicates):
|
||||
to_check.add(x)
|
||||
|
||||
if reg not in self.program.base_addresses:
|
||||
if reg not in self.program.base_addresses \
|
||||
and not isinstance(inst, call_arg):
|
||||
free.free(base)
|
||||
if inst.is_vec() and base.vector:
|
||||
self.defined[base] = inst
|
||||
@@ -608,7 +609,8 @@ class Merger:
|
||||
# so this threshold should lead to acceptable compile times even on slower processors.
|
||||
first_factor_total_number_of_values = instr.args[12 * matmul_idx + 3] * instr.args[12 * matmul_idx + 4]
|
||||
second_factor_total_number_of_values = instr.args[12 * matmul_idx + 4] * instr.args[12 * matmul_idx + 5]
|
||||
max_dependencies_per_matrix = 1500**2
|
||||
max_dependencies_per_matrix = \
|
||||
self.block.parent.program.budget
|
||||
if first_factor_total_number_of_values > max_dependencies_per_matrix or second_factor_total_number_of_values > max_dependencies_per_matrix:
|
||||
if block.warn_about_mem and not block.parent.warned_about_mem:
|
||||
print('WARNING: Order of memory instructions not preserved due to long vector, errors possible')
|
||||
|
||||
@@ -5,7 +5,7 @@ the ones used below into ``Programs/Circuits`` as follows::
|
||||
|
||||
make Programs/Circuits
|
||||
|
||||
.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC
|
||||
.. _`Bristol Fashion`: https://nigelsmart.github.io/MPC-Circuits
|
||||
|
||||
"""
|
||||
import math
|
||||
@@ -15,6 +15,7 @@ from Compiler.library import function_block, get_tape
|
||||
from Compiler import util
|
||||
import itertools
|
||||
import struct
|
||||
import os
|
||||
|
||||
class Circuit:
|
||||
"""
|
||||
@@ -47,7 +48,12 @@ class Circuit:
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.filename = 'Programs/Circuits/%s.txt' % name
|
||||
if not os.path.exists(self.filename):
|
||||
if os.system('make Programs/Circuits'):
|
||||
raise CompilerError('Cannot download circuit descriptions. '
|
||||
'Make sure make and git are installed.')
|
||||
f = open(self.filename)
|
||||
self.functions = {}
|
||||
|
||||
@@ -57,8 +63,9 @@ class Circuit:
|
||||
def run(self, *inputs):
|
||||
n = inputs[0][0].n, get_tape()
|
||||
if n not in self.functions:
|
||||
self.functions[n] = function_block(lambda *args:
|
||||
self.compile(*args))
|
||||
self.functions[n] = function_block(
|
||||
lambda *args: self.compile(*args))
|
||||
self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n)
|
||||
flat_res = self.functions[n](*itertools.chain(*inputs))
|
||||
res = []
|
||||
i = 0
|
||||
@@ -124,7 +131,7 @@ Keccak_f = None
|
||||
|
||||
def sha3_256(x):
|
||||
"""
|
||||
This function implements SHA3-256 for inputs of up to 1080 bits::
|
||||
This function implements SHA3-256 for inputs of any length::
|
||||
|
||||
from circuit import sha3_256
|
||||
a = sbitvec.from_vec([])
|
||||
@@ -138,7 +145,8 @@ def sha3_256(x):
|
||||
for x in a, b, c, d, e, f, g, h:
|
||||
sha3_256(x).reveal_print_hex()
|
||||
|
||||
This should output the `test vectors
|
||||
This should output the hashes of the above inputs, beginning with
|
||||
the `test vectors
|
||||
<https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/ShortMsgKAT_SHA3-256.txt>`_
|
||||
of SHA3-256 for 0, 8, 16, and 24 bits as well as the hash of the
|
||||
0 byte::
|
||||
|
||||
@@ -76,13 +76,13 @@ def require_ring_size(k, op):
|
||||
program.curr_tape.require_bit_length(k)
|
||||
|
||||
@instructions_base.cisc
|
||||
def LTZ(s, a, k, kappa):
|
||||
def LTZ(s, a, k):
|
||||
"""
|
||||
s = (a ?< 0)
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
movs(s, program.non_linear.ltz(a, k, kappa))
|
||||
movs(s, program.non_linear.ltz(a, k))
|
||||
|
||||
def LtzRing(a, k):
|
||||
from .types import sint, _bitint
|
||||
@@ -105,14 +105,14 @@ def LtzRing(a, k):
|
||||
u = CarryOutRaw(a[::-1], b[::-1])
|
||||
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
|
||||
|
||||
def LessThanZero(a, k, kappa):
|
||||
def LessThanZero(a, k):
|
||||
from . import types
|
||||
res = types.sint()
|
||||
LTZ(res, a, k, kappa)
|
||||
LTZ(res, a, k)
|
||||
return res
|
||||
|
||||
@instructions_base.cisc
|
||||
def Trunc(d, a, k, m, kappa, signed):
|
||||
def Trunc(d, a, k, m, signed):
|
||||
"""
|
||||
d = a >> m
|
||||
|
||||
@@ -124,7 +124,7 @@ def Trunc(d, a, k, m, kappa, signed):
|
||||
movs(d, a)
|
||||
return
|
||||
else:
|
||||
movs(d, program.non_linear.trunc(a, k, m, kappa, signed))
|
||||
movs(d, program.non_linear.trunc(a, k, m, signed=signed))
|
||||
|
||||
def TruncRing(d, a, k, m, signed):
|
||||
program.curr_tape.require_bit_length(1)
|
||||
@@ -197,13 +197,13 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False)
|
||||
masked = shifted >> n_shift
|
||||
u = sint()
|
||||
BitLTL(u, masked, r_bits[:n_bits], 0)
|
||||
BitLTL(u, masked, r_bits[:n_bits])
|
||||
res = (u << n_bits) + masked - r
|
||||
if signed:
|
||||
res -= (1 << (n_bits - 1))
|
||||
return res
|
||||
|
||||
def TruncRoundNearest(a, k, m, kappa, signed=False):
|
||||
def TruncRoundNearest(a, k, m, signed=False):
|
||||
"""
|
||||
Returns a / 2^m, rounded to the nearest integer.
|
||||
|
||||
@@ -212,12 +212,10 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
|
||||
"""
|
||||
if m == 0:
|
||||
return a
|
||||
nl = program.non_linear
|
||||
nl.check_security(kappa)
|
||||
return program.non_linear.trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
@instructions_base.cisc
|
||||
def Mod2m(a_prime, a, k, m, kappa, signed):
|
||||
def Mod2m(a_prime, a, k, m, signed):
|
||||
"""
|
||||
a_prime = a % 2^m
|
||||
|
||||
@@ -225,8 +223,6 @@ def Mod2m(a_prime, a, k, m, kappa, signed):
|
||||
m: compile-time integer
|
||||
signed: True/False, describes a
|
||||
"""
|
||||
nl = program.non_linear
|
||||
nl.check_security(kappa)
|
||||
movs(a_prime, program.non_linear.mod2m(a, k, m, signed))
|
||||
|
||||
def Mod2mRing(a_prime, a, k, m, signed):
|
||||
@@ -237,13 +233,13 @@ def Mod2mRing(a_prime, a, k, m, signed):
|
||||
tmp = a + r_prime
|
||||
c_prime = (tmp << shift).reveal(False) >> shift
|
||||
u = sint()
|
||||
BitLTL(u, c_prime, r_bin[:m], 0)
|
||||
BitLTL(u, c_prime, r_bin[:m])
|
||||
res = (u << m) + c_prime - r_prime
|
||||
if a_prime is not None:
|
||||
movs(a_prime, res)
|
||||
return res
|
||||
|
||||
def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
def Mod2mField(a_prime, a, k, m, signed):
|
||||
from .types import sint
|
||||
r_dprime = program.curr_block.new_reg('s')
|
||||
r_prime = program.curr_block.new_reg('s')
|
||||
@@ -255,7 +251,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
t = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c2m = program.curr_block.new_reg('c')
|
||||
c2k1 = program.curr_block.new_reg('c')
|
||||
PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
PRandM(r_dprime, r_prime, r, k, m)
|
||||
ld2i(c2m, m)
|
||||
mulm(t[0], r_dprime, c2m)
|
||||
if signed:
|
||||
@@ -268,9 +264,9 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
asm_open(True, c, t[3])
|
||||
modc(c_prime, c, c2m)
|
||||
if const_rounds:
|
||||
BitLTC1(u, c_prime, r, kappa)
|
||||
BitLTC1(u, c_prime, r)
|
||||
else:
|
||||
BitLTL(u, c_prime, r, kappa)
|
||||
BitLTL(u, c_prime, r)
|
||||
mulm(t[4], u, c2m)
|
||||
submr(t[5], c_prime, r_prime)
|
||||
adds(a_prime, t[5], t[4])
|
||||
@@ -288,13 +284,15 @@ def MaskingBitsInRing(m, strict=False):
|
||||
r_bin = r
|
||||
return sint.bit_compose(r), r_bin
|
||||
|
||||
def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
|
||||
def PRandM(r_dprime, r_prime, b, k, m, use_dabit=True):
|
||||
"""
|
||||
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
|
||||
r_prime = random secret integer in range [0, 2^m - 1]
|
||||
b = array containing bits of r_prime
|
||||
"""
|
||||
program.curr_tape.require_bit_length(k + kappa)
|
||||
assert k >= m
|
||||
kappa = program.security
|
||||
program.curr_tape.require_bit_length(k + kappa, reason='statistical masking as in https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation')
|
||||
from .types import sint
|
||||
if program.use_edabit() and not const_rounds:
|
||||
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
|
||||
@@ -329,7 +327,7 @@ def PRandInt(r, k):
|
||||
bit(t[1][i])
|
||||
adds(t[2][i], t[0][i], t[1][i])
|
||||
|
||||
def BitLTC1(u, a, b, kappa):
|
||||
def BitLTC1(u, a, b):
|
||||
"""
|
||||
u = a <? b
|
||||
|
||||
@@ -395,7 +393,7 @@ def BitLTC1(u, a, b, kappa):
|
||||
subcfi(c[3][i], a_bits[i], 1)
|
||||
mulm(t[3][i], s[i], c[3][i])
|
||||
adds(t[4][i], t[4][i-1], t[3][i])
|
||||
Mod2(u, t[4][k-1], k, kappa, False)
|
||||
Mod2(u, t[4][k-1], k, False)
|
||||
return p, a_bits, d, s, t, c, b, pre_input
|
||||
|
||||
def carry(b, a, compute_p=True):
|
||||
@@ -414,7 +412,7 @@ def carry(b, a, compute_p=True):
|
||||
|
||||
# from WP9 report
|
||||
# length of a is even
|
||||
def CarryOutAux(a, kappa):
|
||||
def CarryOutAux(a):
|
||||
k = len(a)
|
||||
if k > 1 and k % 2 == 1:
|
||||
a.append(None)
|
||||
@@ -424,12 +422,12 @@ def CarryOutAux(a, kappa):
|
||||
if k > 1:
|
||||
for i in range(k//2):
|
||||
u[i] = carry(a[2*i+1], a[2*i], i != k//2-1)
|
||||
return CarryOutAux(u[:k//2][::-1], kappa)
|
||||
return CarryOutAux(u[:k//2][::-1])
|
||||
else:
|
||||
return a[0][1]
|
||||
|
||||
# carry out with carry-in bit c
|
||||
def CarryOut(res, a, b, c=0, kappa=None):
|
||||
def CarryOut(res, a, b, c=0):
|
||||
"""
|
||||
res = last carry bit in addition of a and b
|
||||
|
||||
@@ -456,7 +454,7 @@ def CarryOutRaw(a, b, c=0):
|
||||
s[0] = d[-1][0].bit_and(c)
|
||||
s[1] = d[-1][1] + s[0]
|
||||
d[-1][1] = s[1]
|
||||
return CarryOutAux(d[::-1], None)
|
||||
return CarryOutAux(d[::-1])
|
||||
|
||||
def CarryOutRawLE(a, b, c=0):
|
||||
""" Little-endian version """
|
||||
@@ -469,7 +467,7 @@ def CarryOutLE(a, b, c=0):
|
||||
CarryOut(res, a[::-1], b[::-1], c)
|
||||
return res
|
||||
|
||||
def BitLTL(res, a, b, kappa):
|
||||
def BitLTL(res, a, b):
|
||||
"""
|
||||
res = a <? b (logarithmic rounds version)
|
||||
|
||||
@@ -624,7 +622,7 @@ def KMulC(a):
|
||||
PreMulC_without_inverses(p, a)
|
||||
return p
|
||||
|
||||
def Mod2(a_0, a, k, kappa, signed):
|
||||
def Mod2(a_0, a, k, signed):
|
||||
"""
|
||||
a_0 = a % 2
|
||||
|
||||
@@ -641,7 +639,7 @@ def Mod2(a_0, a, k, kappa, signed):
|
||||
tc = program.curr_block.new_reg('c')
|
||||
t = [program.curr_block.new_reg('s') for i in range(6)]
|
||||
c2k1 = program.curr_block.new_reg('c')
|
||||
PRandM(r_dprime, r_prime, [r_0], k, 1, kappa)
|
||||
PRandM(r_dprime, r_prime, [r_0], k, 1)
|
||||
r_0 = r_prime
|
||||
mulsi(t[0], r_dprime, 2)
|
||||
if signed:
|
||||
|
||||
@@ -165,7 +165,8 @@ class Compiler:
|
||||
dest="prime",
|
||||
default=defaults.prime,
|
||||
help="use bit decomposition with a specifed prime modulus "
|
||||
"for non-linear computation (default: use the masking approach)",
|
||||
"for non-linear computation (default: use the masking approach). "
|
||||
"Don't use this unless you're certain that you need it.",
|
||||
)
|
||||
parser.add_option(
|
||||
"-I",
|
||||
@@ -263,6 +264,14 @@ class Compiler:
|
||||
|
||||
def parse_args(self):
|
||||
self.options, self.args = self.parser.parse_args(self.custom_args)
|
||||
if self.options.verbose:
|
||||
self.runtime_args += ["--verbose"]
|
||||
if self.options.execute:
|
||||
self.options.execute = re.sub("-party.x$", "", self.options.execute)
|
||||
for s, l in self.match.items():
|
||||
if self.options.execute == l:
|
||||
self.options.execute = s
|
||||
break
|
||||
if self.execute:
|
||||
if not self.options.execute:
|
||||
if len(self.args) > 1:
|
||||
@@ -313,6 +322,8 @@ class Compiler:
|
||||
self.prog.use_split(int(os.getenv("PLAYERS", 2)))
|
||||
if self.options.execute in ("rep4-ring",):
|
||||
self.prog.use_split(4)
|
||||
if self.options.execute.find("dealer") >= 0:
|
||||
self.prog.use_edabit(True)
|
||||
|
||||
def build_vars(self):
|
||||
from . import comparison, floatingpoint, instructions, library, types
|
||||
@@ -368,7 +379,14 @@ class Compiler:
|
||||
"cfloat",
|
||||
"squant",
|
||||
]:
|
||||
del self.VARS[i]
|
||||
class dummy:
|
||||
def __init__(self, *args):
|
||||
raise CompilerError(self.error)
|
||||
dummy.error = i + " not availabe with binary circuits"
|
||||
if i in ("cint", "cfix"):
|
||||
dummy.error += ". See https://mp-spdz.readthedocs.io/en/" \
|
||||
"latest/Compiler.html#Compiler.types." + i
|
||||
self.VARS[i] = dummy
|
||||
else:
|
||||
self.sint = types.sint
|
||||
self.sfix = types.sfix
|
||||
@@ -503,13 +521,15 @@ class Compiler:
|
||||
|
||||
return self.prog
|
||||
|
||||
@staticmethod
|
||||
def executable_from_protocol(protocol):
|
||||
match = {
|
||||
"ring": "replicated-ring",
|
||||
"rep-field": "replicated-field",
|
||||
"replicated": "replicated-bin"
|
||||
}
|
||||
match = {
|
||||
"ring": "replicated-ring",
|
||||
"rep-field": "replicated-field",
|
||||
"replicated": "replicated-bin"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def executable_from_protocol(cls, protocol):
|
||||
match = cls.match
|
||||
if protocol in match:
|
||||
protocol = match[protocol]
|
||||
if protocol.find("bmr") == -1:
|
||||
@@ -588,11 +608,19 @@ class Compiler:
|
||||
for filename in glob.glob("Player-Data/*.0"):
|
||||
connection.put(filename, dest + "Player-Data")
|
||||
|
||||
def run_with_error(i):
|
||||
try:
|
||||
run(i)
|
||||
except IOError:
|
||||
print('IO error when copying files, does %s have enough space?' %
|
||||
hostnames[i])
|
||||
raise
|
||||
|
||||
import threading
|
||||
import random
|
||||
threads = []
|
||||
for i in range(len(hosts)):
|
||||
threads.append(threading.Thread(target=run, args=(i,)))
|
||||
threads.append(threading.Thread(target=run_with_error, args=(i,)))
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections import defaultdict
|
||||
|
||||
REG_MAX = 2 ** 32
|
||||
USER_MEM = 8192
|
||||
MEM_MAX = 2 ** 64
|
||||
|
||||
P_VALUES = { 32: 2147565569, \
|
||||
64: 9223372036855103489, \
|
||||
|
||||
@@ -558,7 +558,7 @@ def test_stupid_dijkstra_on_cycle(n, n_loops=None):
|
||||
@for_range(n)
|
||||
def f(i):
|
||||
M[i][(i+1)%n] = ExtInt(1)
|
||||
M[i][(i-1)%n] = ExtInt(1)
|
||||
M[i][(i-1+n)%n] = ExtInt(1)
|
||||
if n_loops is not None:
|
||||
stop_timer(1)
|
||||
start_timer()
|
||||
|
||||
@@ -39,26 +39,25 @@ def maskRing(a, k):
|
||||
c = ((a + r_prime) << shift).reveal(False) >> shift
|
||||
return c, r
|
||||
|
||||
def maskField(a, k, kappa):
|
||||
def maskField(a, k):
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
c = types.cint()
|
||||
r = [types.sint() for i in range(k)]
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, k)
|
||||
# always signed due to usage in equality testing
|
||||
a += two_power(k)
|
||||
asm_open(True, c, a + two_power(k) * r_dprime + r_prime)
|
||||
return c, r
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def EQZ(a, k, kappa):
|
||||
def EQZ(a, k):
|
||||
prog = program.Program.prog
|
||||
if prog.use_split():
|
||||
from GC.types import sbitvec
|
||||
v = sbitvec(a, k).v
|
||||
bit = util.tree_reduce(operator.and_, (~b for b in v))
|
||||
return types.sintbit.conv(bit)
|
||||
prog.non_linear.check_security(kappa)
|
||||
return prog.non_linear.eqz(a, k)
|
||||
|
||||
def bits(a,m):
|
||||
@@ -99,12 +98,12 @@ def or_op(a, b, void=None):
|
||||
def mul_op(a, b, void=None):
|
||||
return a * b
|
||||
|
||||
def PreORC(a, kappa=None, m=None, raw=False):
|
||||
def PreORC(a, m=None, raw=False):
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return [a[0]]
|
||||
prog = program.Program.prog
|
||||
kappa = kappa or prog.security
|
||||
kappa = prog.security
|
||||
m = m or k
|
||||
if isinstance(a[0], types.sgf2n):
|
||||
max_k = program.Program.prog.galois_length - 1
|
||||
@@ -128,13 +127,13 @@ def PreORC(a, kappa=None, m=None, raw=False):
|
||||
t = [types.sint() for i in range(m)]
|
||||
b = comparison.PreMulC([a[i] + 1 for i in range(k)])
|
||||
for i in range(m):
|
||||
comparison.Mod2(t[i], b[k-1-i], k, kappa, False)
|
||||
comparison.Mod2(t[i], b[k-1-i], k, False)
|
||||
p[m-1-i] = 1 - t[i]
|
||||
return p
|
||||
else:
|
||||
# not constant-round anymore
|
||||
s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
|
||||
t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
|
||||
s = [PreORC(a[i:i+max_k], raw=raw) for i in range(0,k,max_k)]
|
||||
t = PreORC([si[-1] for si in s[:-1]], raw=raw)
|
||||
return sum(([or_op(x, y) for x in si]
|
||||
for si,y in zip(s[1:],t)), s[0])[-m:]
|
||||
|
||||
@@ -175,6 +174,41 @@ def PreOpL2(op, items):
|
||||
output[2 * i] = op(v[i - 1], items[2 * i])
|
||||
return output
|
||||
|
||||
def PreOpL2_vec(op, *items):
|
||||
""" Vectorized version of :py:func:`PreOpL2` """
|
||||
k = len(items[0])
|
||||
for x in items:
|
||||
assert len(x) == k
|
||||
if k == 1:
|
||||
return items
|
||||
half = k // 2
|
||||
other_half = (k + 1) // 2 - 1
|
||||
u = op([x.get_vector(base=0, size=half, skip=2) for x in items],
|
||||
[x.get_vector(base=1, size=half, skip=2) for x in items])
|
||||
assert len(u) == len(items)
|
||||
assert len(u[0]) == half
|
||||
v = PreOpL2_vec(op, *u)
|
||||
if other_half:
|
||||
w = op([x.get_vector(base=0, size=other_half) for x in v],
|
||||
[x.get_vector(base=2, size=other_half, skip=2) for x in items])
|
||||
if half == other_half:
|
||||
res = [type(x).zip(x, y) for x, y in zip(v, w)]
|
||||
for i in range(len(res)):
|
||||
res[i] = type(res[i]).concat((items[i].get_vector(base=0, size=1),
|
||||
res[i]))
|
||||
else:
|
||||
if other_half:
|
||||
for i in range(len(w)):
|
||||
w[i] = type(w[i]).concat((items[i].get_vector(base=0, size=1),
|
||||
w[i]))
|
||||
else:
|
||||
w = [x.get_vector(base=0, size=1) for x in items]
|
||||
res = [type(x).zip(x, y) for x, y in zip(w, v)]
|
||||
assert len(res) == len(items)
|
||||
for x in res:
|
||||
assert len(x) == k
|
||||
return res
|
||||
|
||||
def PreOpN(op, items):
|
||||
""" Naive PreOp algorithm """
|
||||
k = len(items)
|
||||
@@ -184,9 +218,9 @@ def PreOpN(op, items):
|
||||
output[i] = op(output[i-1], items[i])
|
||||
return output
|
||||
|
||||
def PreOR(a, kappa=None, raw=False):
|
||||
def PreOR(a=None, raw=False):
|
||||
if comparison.const_rounds:
|
||||
return PreORC(a, kappa, raw=raw)
|
||||
return PreORC(a, raw=raw)
|
||||
else:
|
||||
return PreOpL(or_op, a)
|
||||
|
||||
@@ -199,24 +233,24 @@ def KOpL(op, a):
|
||||
t2 = KOpL(op, a[k//2:])
|
||||
return op(t1, t2)
|
||||
|
||||
def KORL(a, kappa=None):
|
||||
def KORL(a):
|
||||
""" log rounds k-ary OR """
|
||||
k = len(a)
|
||||
if k == 1:
|
||||
return a[0]
|
||||
else:
|
||||
t1 = KORL(a[:k//2], kappa)
|
||||
t2 = KORL(a[k//2:], kappa)
|
||||
t1 = KORL(a[:k//2])
|
||||
t2 = KORL(a[k//2:])
|
||||
return t1 + t2 - t1.bit_and(t2)
|
||||
|
||||
def KORC(a, kappa):
|
||||
return PreORC(a, kappa, 1)[0]
|
||||
def KORC(a):
|
||||
return PreORC(a, 1)[0]
|
||||
|
||||
def KOR(a, kappa):
|
||||
def KOR(a):
|
||||
if comparison.const_rounds:
|
||||
return KORC(a, kappa)
|
||||
return KORC(a)
|
||||
else:
|
||||
return KORL(a, None)
|
||||
return KORL(a)
|
||||
|
||||
def KMul(a):
|
||||
if comparison.const_rounds:
|
||||
@@ -262,7 +296,7 @@ def BitAdd(a, b, bits_to_compute=None):
|
||||
s[k] = c[k-1]
|
||||
return s
|
||||
|
||||
def BitDec(a, k, m, kappa, bits_to_compute=None):
|
||||
def BitDec(a, k, m, bits_to_compute=None):
|
||||
return program.Program.prog.non_linear.bit_dec(a, k, m)
|
||||
|
||||
def BitDecRingRaw(a, k, m):
|
||||
@@ -270,7 +304,7 @@ def BitDecRingRaw(a, k, m):
|
||||
n_shift = int(program.Program.prog.options.ring) - m
|
||||
if program.Program.prog.use_split():
|
||||
x = a.split_to_two_summands(m)
|
||||
bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False)
|
||||
bits = types._bitint.bit_adder(x[0], x[1])
|
||||
return bits[:m]
|
||||
else:
|
||||
if program.Program.prog.use_edabit():
|
||||
@@ -292,13 +326,14 @@ def BitDecRing(a, k, m):
|
||||
# reversing to reduce number of rounds
|
||||
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
|
||||
|
||||
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
def BitDecFieldRaw(a, k, m, bits_to_compute=None):
|
||||
instructions_base.set_global_vector_size(a.size)
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
c = types.cint()
|
||||
r = [types.sint() for i in range(m)]
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
|
||||
comparison.PRandM(r_dprime, r_prime, r, k, m)
|
||||
kappa = program.Program.prog.security
|
||||
pow2 = two_power(k + kappa)
|
||||
asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
|
||||
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
|
||||
@@ -306,16 +341,16 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
return res
|
||||
|
||||
@instructions_base.bit_cisc
|
||||
def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
|
||||
def BitDecField(a, k, m, bits_to_compute=None):
|
||||
res = BitDecFieldRaw(a, k, m, bits_to_compute)
|
||||
return [types.sintbit.conv(bit) for bit in res]
|
||||
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def Pow2(a, l, kappa):
|
||||
def Pow2(a, l):
|
||||
comparison.program.curr_tape.require_bit_length(l - 1)
|
||||
m = int(ceil(log(l, 2)))
|
||||
t = BitDec(a, m, m, kappa)
|
||||
t = BitDec(a, m, m)
|
||||
return Pow2_from_bits(t)
|
||||
|
||||
def Pow2_from_bits(bits):
|
||||
@@ -327,11 +362,12 @@ def Pow2_from_bits(bits):
|
||||
t[i] = t[i]*pow2k[i] + 1 - t[i]
|
||||
return KMul(t)
|
||||
|
||||
def B2U(a, l, kappa):
|
||||
pow2a = Pow2(a, l, kappa)
|
||||
return B2U_from_Pow2(pow2a, l, kappa), pow2a
|
||||
def B2U(a, l):
|
||||
pow2a = Pow2(a, l)
|
||||
return B2U_from_Pow2(pow2a, l), pow2a
|
||||
|
||||
def B2U_from_Pow2(pow2a, l, kappa):
|
||||
def B2U_from_Pow2(pow2a, l):
|
||||
kappa = program.Program.prog.security
|
||||
r = [types.sint() for i in range(l)]
|
||||
t = types.sint()
|
||||
c = types.cint()
|
||||
@@ -353,17 +389,17 @@ def B2U_from_Pow2(pow2a, l, kappa):
|
||||
c = list(r_bits[0].bit_decompose_clear(c, l))
|
||||
x = [r_bits[i].bit_xor(c[i]) for i in range(l)]
|
||||
#print ' '.join(str(b.value) for b in x)
|
||||
y = PreOR(x, kappa)
|
||||
y = PreOR(x)
|
||||
#print ' '.join(str(b.value) for b in y)
|
||||
return [types.sint.conv(1 - y[i]) for i in range(l)]
|
||||
|
||||
def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
def Trunc(a, l, m, compute_modulo=False, signed=False):
|
||||
""" Oblivious truncation by secret m """
|
||||
prog = program.Program.prog
|
||||
if util.is_constant(m) and not compute_modulo:
|
||||
# cheaper
|
||||
res = type(a)(size=a.size)
|
||||
comparison.Trunc(res, a, l, m, kappa, signed=signed)
|
||||
comparison.Trunc(res, a, l, m, signed=signed)
|
||||
return res
|
||||
if l == 1:
|
||||
if compute_modulo:
|
||||
@@ -371,9 +407,9 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
else:
|
||||
return a * (1 - m)
|
||||
if program.Program.prog.options.ring and not compute_modulo:
|
||||
return TruncInRing(a, l, Pow2(m, l, kappa))
|
||||
return TruncInRing(a, l, Pow2(m, l))
|
||||
else:
|
||||
kappa = kappa or program.Program.prog.security
|
||||
kappa = program.Program.prog.security
|
||||
r = [types.sint() for i in range(l)]
|
||||
r_dprime = types.sint(0)
|
||||
r_prime = types.sint(0)
|
||||
@@ -381,7 +417,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
c = types.cint()
|
||||
ci = [types.cint() for i in range(l)]
|
||||
d = types.sint()
|
||||
x, pow2m = B2U(m, l, kappa)
|
||||
x, pow2m = B2U(m, l)
|
||||
for i in range(l):
|
||||
bit(r[i])
|
||||
t1 = two_power(i) * r[i]
|
||||
@@ -398,7 +434,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
for i in range(1,l):
|
||||
ci[i] = c % two_power(i)
|
||||
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
|
||||
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
|
||||
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l)
|
||||
if compute_modulo:
|
||||
b = c_dprime - r_prime + pow2m * d
|
||||
return b, pow2m
|
||||
@@ -429,33 +465,33 @@ def TruncInRing(to_shift, l, pow2m):
|
||||
def SplitInRing(a, l, m):
|
||||
if l == 1:
|
||||
return m.if_else(a, 0), m.if_else(0, a), 1
|
||||
pow2m = Pow2(m, l, None)
|
||||
pow2m = Pow2(m, l)
|
||||
upper = TruncInRing(a, l, pow2m)
|
||||
lower = a - upper * pow2m
|
||||
return lower, upper, pow2m
|
||||
|
||||
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
|
||||
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
|
||||
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
|
||||
def TruncRoundNearestAdjustOverflow(a, length, target_length):
|
||||
t = comparison.TruncRoundNearest(a, length, length - target_length)
|
||||
overflow = t.greater_equal(two_power(target_length), target_length + 1)
|
||||
s = (1 - overflow) * t + overflow * t.trunc_zeros(1, length, False)
|
||||
return s, overflow
|
||||
|
||||
def Int2FL(a, gamma, l, kappa=None):
|
||||
def Int2FL(a, gamma, l):
|
||||
lam = gamma - 1
|
||||
s = a.less_than(0, gamma, security=kappa)
|
||||
z = a.equal(0, gamma, security=kappa)
|
||||
s = a.less_than(0, gamma)
|
||||
z = a.equal(0, gamma)
|
||||
a = s.if_else(-a, a)
|
||||
a_bits = a.bit_decompose(lam, security=kappa)
|
||||
a_bits = a.bit_decompose(lam)
|
||||
a_bits.reverse()
|
||||
b = PreOR(a_bits, kappa)
|
||||
b = PreOR(a_bits)
|
||||
t = a * (1 + a.bit_compose(1 - b_i for b_i in b))
|
||||
p = a.popcnt_bits(b) - lam
|
||||
if gamma - 1 > l:
|
||||
if types.sfloat.round_nearest:
|
||||
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l, kappa)
|
||||
v, overflow = TruncRoundNearestAdjustOverflow(t, gamma - 1, l)
|
||||
p = p + overflow
|
||||
else:
|
||||
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
|
||||
v = t.right_shift(gamma - l - 1, gamma - 1, signed=False)
|
||||
else:
|
||||
v = 2**(l-gamma+1) * t
|
||||
p = (p + gamma - 1 - l) * z.bit_not()
|
||||
@@ -466,32 +502,31 @@ def FLRound(x, mode):
|
||||
*mode*: 0 -> floor, 1 -> ceil, -1 > trunc """
|
||||
v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
|
||||
a = types.sint()
|
||||
comparison.LTZ(a, p1, k, x.kappa)
|
||||
b = p1.less_than(-l + 1, k, x.kappa)
|
||||
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
|
||||
c = EQZ(v2, l, x.kappa)
|
||||
comparison.LTZ(a, p1, k)
|
||||
b = p1.less_than(-l + 1, k)
|
||||
v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, compute_modulo=True)
|
||||
c = EQZ(v2, l)
|
||||
if mode == -1:
|
||||
away_from_zero = 0
|
||||
mode = x.s
|
||||
else:
|
||||
away_from_zero = mode + s1 - 2 * mode * s1
|
||||
v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
|
||||
d = v.equal(two_power(l), l + 1, x.kappa)
|
||||
d = v.equal(two_power(l), l + 1)
|
||||
v = d * two_power(l-1) + (1 - d) * v
|
||||
v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1
|
||||
s = (1 - b * mode) * s1
|
||||
z = or_op(EQZ(v, l, x.kappa), z1)
|
||||
z = or_op(EQZ(v, l), z1)
|
||||
v = v * (1 - z)
|
||||
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
|
||||
return v, p, z, s
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def TruncPr(a, k, m, kappa=None, signed=True):
|
||||
def TruncPr(a, k, m, signed=True):
|
||||
""" Probabilistic truncation [a/2^m + u]
|
||||
where Pr[u = 1] = (a % 2^m) / 2^m
|
||||
"""
|
||||
nl = program.Program.prog.non_linear
|
||||
nl.check_security(kappa)
|
||||
return nl.trunc_pr(a, k, m, signed)
|
||||
|
||||
def TruncPrRing(a, k, m, signed=True):
|
||||
@@ -540,16 +575,14 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
res -= (1 << (k - m - 1))
|
||||
return res
|
||||
|
||||
def TruncPrField(a, k, m, kappa=None):
|
||||
def TruncPrField(a, k, m):
|
||||
if m == 0:
|
||||
return a
|
||||
if kappa is None:
|
||||
kappa = 40
|
||||
|
||||
b = two_power(k-1) + a
|
||||
r_prime, r_dprime = types.sint(), types.sint()
|
||||
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],
|
||||
k, m, kappa, use_dabit=False)
|
||||
k, m, use_dabit=False)
|
||||
two_to_m = two_power(m)
|
||||
r = two_to_m * r_dprime + r_prime
|
||||
c = (b + r).reveal(False)
|
||||
@@ -559,49 +592,49 @@ def TruncPrField(a, k, m, kappa=None):
|
||||
return d
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def SDiv(a, b, l, kappa, round_nearest=False):
|
||||
def SDiv(a, b, l, round_nearest=False):
|
||||
theta = int(ceil(log(l / 3.5) / log(2)))
|
||||
alpha = two_power(2*l)
|
||||
w = types.cint(int(2.9142 * 2 ** l)) - 2 * b
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
|
||||
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True)
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, signed=True)
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
for i in range(theta-1):
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest,
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l,
|
||||
nearest=round_nearest,
|
||||
signed=False)
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
|
||||
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
|
||||
y = y.round(2 * l + 1, l, nearest=round_nearest, signed=False)
|
||||
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, nearest=round_nearest,
|
||||
signed=False)
|
||||
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
|
||||
x = x1 * x1 + x.round(2 * l + 1, l - 1, nearest=round_nearest,
|
||||
signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
|
||||
comparison.Mod2m(x2, x, 2 * l, l, signed=False)
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest, signed=False)
|
||||
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest,
|
||||
signed=False)
|
||||
y = y.round(2 * l + 1, l + 1, nearest=round_nearest)
|
||||
return y
|
||||
|
||||
def SDiv_mono(a, b, l, kappa):
|
||||
def SDiv_mono(a, b, l):
|
||||
theta = int(ceil(log(l / 3.5) / log(2)))
|
||||
alpha = two_power(2*l)
|
||||
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
|
||||
x = alpha - b * w
|
||||
y = a * w
|
||||
y = TruncPr(y, 2 * l + 1, l + 1, kappa)
|
||||
y = TruncPr(y, 2 * l + 1, l + 1)
|
||||
for i in range(theta-1):
|
||||
y = y * (alpha + x)
|
||||
# keep y with l bits
|
||||
y = TruncPr(y, 3 * l, 2 * l, kappa)
|
||||
y = TruncPr(y, 3 * l, 2 * l)
|
||||
x = x**2
|
||||
# keep x with 2l bits
|
||||
x = TruncPr(x, 4 * l, 2 * l, kappa)
|
||||
x = TruncPr(x, 4 * l, 2 * l)
|
||||
y = y * (alpha + x)
|
||||
y = TruncPr(y, 3 * l, 2 * l, kappa)
|
||||
y = TruncPr(y, 3 * l, 2 * l)
|
||||
return y
|
||||
|
||||
# LT bit comparison on shared bit values
|
||||
|
||||
@@ -488,6 +488,70 @@ class join_tape(base.Instruction):
|
||||
code = base.opcodes['JOIN_TAPE']
|
||||
arg_format = ['int']
|
||||
|
||||
class call_tape(base.DoNotEliminateInstruction):
|
||||
""" Start tape/bytecode file in same thread. Arguments/return values
|
||||
starting from :py:obj:`direction` are optional.
|
||||
|
||||
:param: tape number (int)
|
||||
:param: arg (regint)
|
||||
:param: direction (0 for argument, 1 for return value)
|
||||
:param: register type (see :py:obj:`vm_types`)
|
||||
:param: register size (int)
|
||||
:param: destination register
|
||||
:param: source register
|
||||
:param: (repeat from direction)
|
||||
|
||||
"""
|
||||
code = base.opcodes['CALL_TAPE']
|
||||
arg_format = tools.chain(['int', 'ci'],
|
||||
tools.cycle(['int','int','int','*w','*']))
|
||||
|
||||
@staticmethod
|
||||
def type_check(reg, type_id):
|
||||
assert base.vm_types[reg.reg_type] == type_id
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(call_tape, self).__init__(*args, **kwargs)
|
||||
for i in range(2, len(args), 5):
|
||||
for reg in args[i + 3:i + 5]:
|
||||
self.type_check(reg, args[i + 1])
|
||||
assert reg.size == args[i + 2]
|
||||
assert args[i] in (0, 1)
|
||||
assert args[i + 4 - args[i]].program == program.curr_tape
|
||||
assert args[i + 3 + args[i]].program == program.tapes[args[0]]
|
||||
|
||||
def get_def(self):
|
||||
# hide registers from called tape
|
||||
for i in range(2, len(self.args), 5):
|
||||
if self.args[i]:
|
||||
yield self.args[i + 3]
|
||||
|
||||
def get_used(self):
|
||||
# hide registers from called tape
|
||||
yield self.args[1]
|
||||
for i in range(2, len(self.args), 5):
|
||||
if not self.args[i]:
|
||||
yield self.args[i + 4]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.num += program.tapes[self.args[0]].req_tree.aggregate()
|
||||
|
||||
class call_arg(base.DoNotEliminateInstruction, base.VectorInstruction):
|
||||
""" Pseudo instruction for arguments in connection with
|
||||
:py:class:`call_tape`.
|
||||
|
||||
:param: destination (register)
|
||||
:param: register type (see :py:obj:`vm_types`)
|
||||
|
||||
"""
|
||||
code = base.opcodes['CALL_ARG']
|
||||
arg_format = ['*w','int']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(call_arg, self).__init__(*args, **kwargs)
|
||||
for i in range(0, len(args), 2):
|
||||
call_tape.type_check(args[i], args[i + 1])
|
||||
|
||||
class crash(base.IOInstruction):
|
||||
""" Crash runtime if the value in the register is not zero.
|
||||
|
||||
@@ -687,6 +751,27 @@ class concats(base.VectorInstruction):
|
||||
for i in range(1, len(args), 2):
|
||||
assert args[i] == len(args[i + 1])
|
||||
|
||||
class zips(base.Instruction):
|
||||
""" Zip vectors.
|
||||
|
||||
:param: result (sint)
|
||||
:param: operand (sint)
|
||||
:param: operand (sint)
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['ZIPS']
|
||||
arg_format = ['sw','s','s']
|
||||
is_vec = lambda self: True
|
||||
|
||||
def __init__(self, *args):
|
||||
super(zips, self).__init__(*args)
|
||||
assert len(args[0]) == len(args[1]) + len(args[2])
|
||||
assert len(args[1]) == len(args[2])
|
||||
|
||||
def get_code(self):
|
||||
return super(zips, self).get_code(len(self.args[1]))
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class mulc(base.MulBase):
|
||||
|
||||
@@ -68,6 +68,8 @@ opcodes = dict(
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
CMDLINEARG = 0xEB,
|
||||
CALL_TAPE = 0xEC,
|
||||
CALL_ARG = 0xED,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -85,6 +87,7 @@ opcodes = dict(
|
||||
PREFIXSUMS = 0x2D,
|
||||
PICKS = 0x2E,
|
||||
CONCATS = 0x2F,
|
||||
ZIPS = 0x3F,
|
||||
# Multiplication/division
|
||||
MULC = 0x30,
|
||||
MULM = 0x31,
|
||||
@@ -223,6 +226,17 @@ opcodes = dict(
|
||||
)
|
||||
|
||||
|
||||
vm_types = dict(
|
||||
ci = 0,
|
||||
sb = 1,
|
||||
cb = 2,
|
||||
s = 4,
|
||||
c = 5,
|
||||
sg = 6,
|
||||
cg = 7,
|
||||
)
|
||||
|
||||
|
||||
def int_to_bytes(x):
|
||||
""" 32 bit int to big-endian 4 byte conversion. """
|
||||
assert(x < 2**32 and x >= -2**32)
|
||||
@@ -491,7 +505,8 @@ def cisc(function, n_outputs=1):
|
||||
reset_global_vector_size()
|
||||
program.curr_tape = old_tape
|
||||
for x, bl in tape.req_bit_length.items():
|
||||
old_tape.require_bit_length(bl - 1, x)
|
||||
old_tape.require_bit_length(
|
||||
bl - 1, x, tape.bit_length_reason if x == 'p' else '')
|
||||
from Compiler.allocator import Merger
|
||||
merger = Merger(block, program.options,
|
||||
tuple(program.to_merge))
|
||||
@@ -516,40 +531,48 @@ def cisc(function, n_outputs=1):
|
||||
inst.copy(size, subs)
|
||||
reset_global_vector_size()
|
||||
|
||||
def expand_to_function(self, size, new_regs):
|
||||
key = size, program.curr_tape, \
|
||||
tuple(arg for arg, reg in zip(self.args, new_regs) if reg is None), \
|
||||
tuple(type(reg) for reg in new_regs)
|
||||
if key not in self.functions:
|
||||
from Compiler import library, types
|
||||
class Arg:
|
||||
def __init__(self, reg):
|
||||
from Compiler.GC.types import bits
|
||||
class Arg:
|
||||
def __init__(self, reg):
|
||||
self.type = type(reg)
|
||||
self.binary = isinstance(reg, bits)
|
||||
self.reg = reg
|
||||
# if reg is not None:
|
||||
# program.base_addresses[reg] = None
|
||||
def new(self):
|
||||
if self.binary:
|
||||
return self.type()
|
||||
else:
|
||||
return self.type(size=size)
|
||||
def load(self):
|
||||
return self.reg
|
||||
def store(self, reg):
|
||||
if self.type != type(None):
|
||||
self.reg.update(reg)
|
||||
args = [Arg(x) for x in new_regs]
|
||||
self.type = type(reg)
|
||||
self.binary = isinstance(reg, bits)
|
||||
self.reg = reg
|
||||
def new(self, size):
|
||||
if self.binary:
|
||||
return self.type()
|
||||
else:
|
||||
return self.type(size=size)
|
||||
def load(self):
|
||||
return self.reg
|
||||
def store(self, reg):
|
||||
if self.type != type(None):
|
||||
self.reg.update(reg)
|
||||
def is_real(self):
|
||||
return self.reg is not None
|
||||
|
||||
def base_key(self, size, new_regs):
|
||||
return size, tuple(
|
||||
arg for arg, reg in zip(self.args, new_regs) if reg is None), \
|
||||
tuple(type(reg) for reg in new_regs)
|
||||
|
||||
@staticmethod
|
||||
def get_name(key):
|
||||
return '_'.join(['%s(%d)' % (function.__name__, key[0])] +
|
||||
[str(x) for x in key[1]])
|
||||
|
||||
def expand_to_function(self, size, new_regs):
|
||||
key = self.base_key(size, new_regs) + (program.curr_tape,)
|
||||
if key not in self.functions:
|
||||
args = [self.Arg(x) for x in new_regs]
|
||||
from Compiler import library, types
|
||||
@library.function_block
|
||||
def f():
|
||||
res = [arg.new() for arg in args[:n_outputs]]
|
||||
self.new_instructions(size,
|
||||
res + [arg.load() for arg in args[n_outputs:]])
|
||||
res = [arg.new(size) for arg in args[:n_outputs]]
|
||||
self.new_instructions(
|
||||
size, res + [arg.load() for arg in args[n_outputs:]])
|
||||
for reg, arg in zip(res, args):
|
||||
arg.store(reg)
|
||||
f.name = '_'.join(['%s(%d)' % (function.__name__, size)] +
|
||||
[str(x) for x in key[2]])
|
||||
f.name = self.get_name(key)
|
||||
self.functions[key] = f, args
|
||||
f, args = self.functions[key]
|
||||
for i in range(len(new_regs) - n_outputs):
|
||||
@@ -558,6 +581,31 @@ def cisc(function, n_outputs=1):
|
||||
for i in range(n_outputs):
|
||||
new_regs[i].link(args[i].load())
|
||||
|
||||
def expand_to_tape(self, size, new_regs):
|
||||
key = self.base_key(size, new_regs)
|
||||
args = [self.Arg(x) for x in new_regs]
|
||||
if key not in self.functions:
|
||||
from Compiler import library, types
|
||||
@library.function_call_tape
|
||||
def f(*in_args):
|
||||
res = [arg.new(size) for arg in args[:n_outputs]]
|
||||
in_args = list(in_args)
|
||||
my_args = list(res)
|
||||
for arg in args[n_outputs:]:
|
||||
if arg.is_real():
|
||||
my_args.append(in_args.pop(0))
|
||||
else:
|
||||
my_args.append(arg.reg)
|
||||
self.new_instructions(size, my_args)
|
||||
return res
|
||||
f.name = self.get_name(key)
|
||||
self.functions[key] = f
|
||||
f = self.functions[key]
|
||||
in_args = filter(lambda arg: arg.is_real(), args[n_outputs:])
|
||||
res = util.tuplify(f(*(arg.load() for arg in in_args)))
|
||||
for i in range(n_outputs):
|
||||
new_regs[i].link(res[i])
|
||||
|
||||
def expand_merged(self, skip):
|
||||
if function.__name__ in skip:
|
||||
good = True
|
||||
@@ -595,7 +643,11 @@ def cisc(function, n_outputs=1):
|
||||
raise
|
||||
if program.cisc_to_function and \
|
||||
(program.curr_tape.singular or program.n_running_threads):
|
||||
self.expand_to_function(size, new_regs)
|
||||
if (program.options.garbled or program.options.binary or \
|
||||
not program.use_tape_calls) and not program.force_cisc_tape:
|
||||
self.expand_to_function(size, new_regs)
|
||||
else:
|
||||
self.expand_to_tape(size, new_regs)
|
||||
else:
|
||||
self.new_instructions(size, new_regs)
|
||||
program.curr_block.n_rounds += self.n_rounds - 1
|
||||
@@ -795,6 +847,12 @@ class ClearIntAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearInt
|
||||
name = 'regint'
|
||||
|
||||
class AnyRegAF(RegisterArgFormat):
|
||||
reg_type = '*'
|
||||
@staticmethod
|
||||
def check(arg):
|
||||
assert isinstance(arg, program.curr_tape.Register)
|
||||
|
||||
class IntArgFormat(ArgFormat):
|
||||
n_bits = 32
|
||||
|
||||
@@ -898,6 +956,8 @@ ArgFormats = {
|
||||
'sgw': SecretGF2NAF,
|
||||
'ci': ClearIntAF,
|
||||
'ciw': ClearIntAF,
|
||||
'*': AnyRegAF,
|
||||
'*w': AnyRegAF,
|
||||
'i': ImmediateModpAF,
|
||||
'ig': ImmediateGF2NAF,
|
||||
'int': IntArgFormat,
|
||||
@@ -938,6 +998,7 @@ class Instruction(object):
|
||||
Instruction.count += 1
|
||||
if Instruction.count % 100000 == 0:
|
||||
print("Compiled %d lines at" % self.__class__.count, time.asctime())
|
||||
sys.stdout.flush()
|
||||
if Instruction.count > 10 ** 7:
|
||||
print("Compilation produced more that 10 million instructions. "
|
||||
"Consider using './compile.py -l' or replacing for loops "
|
||||
|
||||
@@ -107,7 +107,7 @@ def print_str(s, *args, print_secrets=False):
|
||||
elif isinstance(val, cfloat):
|
||||
val.print_float_plain()
|
||||
elif isinstance(val, (list, tuple, Array, SubMultiArray)):
|
||||
print_str(*_expand_to_print(val))
|
||||
print_str(*_expand_to_print(val), print_secrets=print_secrets)
|
||||
else:
|
||||
try:
|
||||
val.output()
|
||||
@@ -314,7 +314,7 @@ def get_cmdline_arg(idx):
|
||||
return localint(res)
|
||||
|
||||
def make_array(l, t=None):
|
||||
if isinstance(l, Tape.Register):
|
||||
if isinstance(l, types._structure):
|
||||
res = Array(len(l), t or type(l))
|
||||
res[:] = l
|
||||
else:
|
||||
@@ -334,13 +334,12 @@ class FunctionTapeCall:
|
||||
return self
|
||||
def join(self):
|
||||
self.thread.join()
|
||||
instructions.program.free(self.base, 'ci')
|
||||
for reg_type,addr in self.bases.items():
|
||||
get_program().free(addr, reg_type.reg_type)
|
||||
if self.base is not None:
|
||||
instructions.program.free(self.base, 'ci')
|
||||
|
||||
class Function:
|
||||
def __init__(self, function, name=None, compile_args=[]):
|
||||
self.type_args = {}
|
||||
self.last_key = None
|
||||
self.function = function
|
||||
self.name = name
|
||||
if name is None:
|
||||
@@ -348,46 +347,40 @@ class Function:
|
||||
self.compile_args = compile_args
|
||||
def __call__(self, *args):
|
||||
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
|
||||
from .types import _types
|
||||
get_reg_type = lambda x: \
|
||||
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
|
||||
key = len(args), get_tape()
|
||||
if key not in self.type_args:
|
||||
runtime_args = []
|
||||
reg_args = []
|
||||
key = self.base_key(),
|
||||
for i,arg in enumerate(args):
|
||||
if isinstance(arg, types._vectorizable):
|
||||
key += (arg.shape, arg.value_type)
|
||||
else:
|
||||
arg = MemValue(arg)
|
||||
reg_args.append(arg)
|
||||
t = arg.value_type
|
||||
key += (arg.size, t)
|
||||
runtime_args.append(arg)
|
||||
if key != self.last_key:
|
||||
# first call
|
||||
type_args = collections.defaultdict(list)
|
||||
for i,arg in enumerate(args):
|
||||
if not isinstance(arg, types._vectorizable):
|
||||
type_args[get_reg_type(arg)].append(i)
|
||||
outer_runtime_args = runtime_args
|
||||
def wrapped_function(*compile_args):
|
||||
base = get_arg()
|
||||
bases = dict((t, regint.load_mem(base + i)) \
|
||||
for i,t in enumerate(sorted(type_args,
|
||||
key=lambda x:
|
||||
x.reg_type)))
|
||||
runtime_args = list(args)
|
||||
for t in sorted(type_args, key=lambda x: x.reg_type):
|
||||
i = 0
|
||||
for i_arg in type_args[t]:
|
||||
runtime_args[i_arg] = t.load_mem(bases[t] + i)
|
||||
i += util.mem_size(t)
|
||||
return self.function(*(list(compile_args) + runtime_args))
|
||||
addresses = regint.Array(len(outer_runtime_args),
|
||||
address=get_arg())
|
||||
runtime_args = []
|
||||
for i, arg in enumerate(outer_runtime_args):
|
||||
if isinstance(arg, MemValue):
|
||||
arg = arg.value_type.load_mem(
|
||||
address=addresses[i], size=arg.size)
|
||||
runtime_args.append(arg)
|
||||
self.result = self.function(
|
||||
*(list(compile_args) + runtime_args))
|
||||
return self.result
|
||||
self.on_first_call(wrapped_function)
|
||||
self.type_args[key] = type_args
|
||||
type_args = self.type_args[key]
|
||||
base = instructions.program.malloc(len(type_args), 'ci')
|
||||
bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
|
||||
for t in type_args)
|
||||
for i,reg_type in enumerate(sorted(type_args,
|
||||
key=lambda x: x.reg_type)):
|
||||
store_in_mem(bases[reg_type], base + i)
|
||||
j = 0
|
||||
for i_arg in type_args[reg_type]:
|
||||
if get_reg_type(args[i_arg]) != reg_type:
|
||||
raise CompilerError('type mismatch: "%s" not of type "%s"' %
|
||||
(args[i_arg], reg_type))
|
||||
store_in_mem(args[i_arg], bases[reg_type] + j)
|
||||
j += util.mem_size(reg_type)
|
||||
return self.on_call(base, bases)
|
||||
self.last_key = key
|
||||
addresses = regint.Array(len(runtime_args))
|
||||
for i, arg in enumerate(reg_args):
|
||||
addresses[i] = arg.address
|
||||
return self.on_call(addresses._address,
|
||||
[(arg.value_type, arg.address) for arg in reg_args])
|
||||
|
||||
class FunctionTape(Function):
|
||||
# not thread-safe
|
||||
@@ -401,6 +394,113 @@ class FunctionTape(Function):
|
||||
single_thread=self.single_thread)
|
||||
def on_call(self, base, bases):
|
||||
return FunctionTapeCall(self.thread, base, bases)
|
||||
@staticmethod
|
||||
def base_key():
|
||||
pass
|
||||
|
||||
class FunctionCallTape(FunctionTape):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FunctionTape, self).__init__(*args, **kwargs)
|
||||
self.instances = {}
|
||||
def __call__(self, *args, **kwargs):
|
||||
key = ()
|
||||
def process_for_key(arg):
|
||||
nonlocal key
|
||||
if isinstance(arg, types._vectorizable):
|
||||
key += (arg.value_type, tuple(arg.shape))
|
||||
elif isinstance(arg, Tape.Register):
|
||||
key += (type(arg), arg.size)
|
||||
elif isinstance(arg, list):
|
||||
key += (tuple(arg), 'l')
|
||||
else:
|
||||
key += (arg,)
|
||||
for arg in args:
|
||||
process_for_key(arg)
|
||||
for name, arg in sorted(kwargs.items()):
|
||||
key += (name, 'kw')
|
||||
process_for_key(arg)
|
||||
if key not in self.instances:
|
||||
my_args = []
|
||||
def wrapped_function():
|
||||
actual_call_args = []
|
||||
def process_for_call(arg):
|
||||
if isinstance(arg, Tape.Register):
|
||||
my_arg = arg.same_type()
|
||||
call_arg(my_arg, base.vm_types[my_arg.reg_type])
|
||||
my_args.append(my_arg)
|
||||
return my_arg
|
||||
elif isinstance(arg, types._vectorizable):
|
||||
my_arg = arg.same_shape(address=regint())
|
||||
call_arg(my_arg.address, base.vm_types['ci'])
|
||||
my_args.append(my_arg)
|
||||
my_arg = arg.same_shape(
|
||||
address=MemValue(my_arg.address))
|
||||
return my_arg
|
||||
actual_call_args.append(my_arg)
|
||||
else:
|
||||
my_args.append(arg)
|
||||
return arg
|
||||
for arg in args:
|
||||
actual_call_args.append(process_for_call(arg))
|
||||
actual_call_kwargs = {}
|
||||
for name, arg in sorted(kwargs.items()):
|
||||
actual_call_kwargs[name] = process_for_call(arg)
|
||||
self.result = self.function(*actual_call_args,
|
||||
**actual_call_kwargs)
|
||||
if self.result is not None:
|
||||
self.result = list(tuplify(self.result))
|
||||
for i, res in enumerate(self.result):
|
||||
if util.is_constant(res):
|
||||
self.result[i] = regint(res)
|
||||
self.on_first_call(wrapped_function, key, my_args)
|
||||
for name, arg in sorted(kwargs.items()):
|
||||
args += arg,
|
||||
return self.on_call(*self.instances[key], args)
|
||||
def on_first_call(self, wrapped_function, key, inside_args):
|
||||
program = get_program()
|
||||
program.curr_tape
|
||||
tape_handle = len(program.tapes)
|
||||
# entry for recursion
|
||||
self.instances[key] = tape_handle, None, inside_args
|
||||
assert tape_handle == program.new_tape(
|
||||
wrapped_function, name=self.name, args=self.compile_args,
|
||||
single_thread=get_tape().singular, finalize=False,
|
||||
thread_pool=get_tape().free_threads)
|
||||
tape = program.tapes[tape_handle]
|
||||
if self.result is not None:
|
||||
self.result = list(tuplify(self.result))
|
||||
for reg in self.result:
|
||||
reg.can_eliminate = False
|
||||
tape.return_values.append(reg)
|
||||
assert not tape.purged
|
||||
get_program().finalize_tape(tape)
|
||||
self.instances[key] = tape_handle, self.result, inside_args
|
||||
def on_call(self, tape_handle, result, inside_args, args):
|
||||
tape = get_program().tapes[tape_handle]
|
||||
if tape.ran_threads and tape.free_threads != get_tape().free_threads:
|
||||
raise CompilerError(
|
||||
'cannot call thread-running tape from another thread')
|
||||
assert len(inside_args) == len(args)
|
||||
out_result = []
|
||||
call_args = []
|
||||
if result is not None:
|
||||
out_result = [reg.same_type() for reg in result]
|
||||
for x, y in zip(out_result, result):
|
||||
call_args += [
|
||||
1, instructions_base.vm_types[x.reg_type],
|
||||
x.size_for_mem(), x, y]
|
||||
for x, y in zip(inside_args, args):
|
||||
if isinstance(x, Tape.Register):
|
||||
call_args += [
|
||||
0, instructions_base.vm_types[x.reg_type],
|
||||
x.size_for_mem(), x, y]
|
||||
elif isinstance(x, types._vectorizable):
|
||||
call_args += [0, base.vm_types['ci'], 1,
|
||||
x.address, regint.conv(y.address)]
|
||||
call_tape(tape_handle, regint(0),
|
||||
*call_args)
|
||||
break_point('call-%s' % self.name)
|
||||
return untuplify(tuple(out_result))
|
||||
|
||||
def function_tape(function):
|
||||
return FunctionTape(function)
|
||||
@@ -413,18 +513,108 @@ def function_tape_with_compile_args(*args):
|
||||
def single_thread_function_tape(function):
|
||||
return FunctionTape(function, single_thread=True)
|
||||
|
||||
def memorize(x):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(memorize(i) for i in x)
|
||||
def function_call_tape(function):
|
||||
if get_program().use_tape_calls:
|
||||
return FunctionCallTape(function)
|
||||
else:
|
||||
return MemValue(x)
|
||||
return function
|
||||
|
||||
def method_call_tape(function):
|
||||
tapes = {}
|
||||
def wrapper(self, *args, **kwargs):
|
||||
def use(name):
|
||||
x = self.__dict__[name]
|
||||
return not isinstance(x, types.MultiArray) or \
|
||||
x.array._address is not None
|
||||
key = (type(self),) + tuple(filter(use, sorted(self.__dict__)))
|
||||
member_key = key[1:]
|
||||
if key not in tapes:
|
||||
def f(*args, **kwargs):
|
||||
class Dummy(type(self)):
|
||||
__init__ = lambda self: None
|
||||
dummy = Dummy()
|
||||
members = args[:len(member_key)]
|
||||
real_args = args[len(member_key):]
|
||||
addresses = {}
|
||||
for name, member in zip(member_key, members):
|
||||
dummy.__dict__[name] = member
|
||||
if isinstance(member, types._vectorizable):
|
||||
addresses[name] = member.address
|
||||
res = function(dummy, *real_args, **kwargs)
|
||||
for name, member in zip(member_key, members):
|
||||
new_member = dummy.__dict__[name]
|
||||
desc = '%s in %s.%s' % (name, type(self).__name__,
|
||||
function.__name__)
|
||||
if id(new_member) != id(member):
|
||||
raise CompilerError('cannot change members '
|
||||
'in method tape (%s)' % desc)
|
||||
if isinstance(member, types._vectorizable) and \
|
||||
id(new_member.address) != id(addresses[name]):
|
||||
raise CompilerError('cannot change memory address '
|
||||
'in method tape (%s)' % desc)
|
||||
if set(member_key) != set(dummy.__dict__):
|
||||
raise CompilerError('cannot add members '
|
||||
'in method tape (%s)' % desc)
|
||||
return res
|
||||
f.__name__ = '%s-%s' % (type(self).__name__, function.__name__)
|
||||
tapes[key] = function_call_tape(f)
|
||||
members = tuple(self.__dict__[x] for x in member_key)
|
||||
res = tapes[key](*(members + args), **kwargs)
|
||||
return res
|
||||
return wrapper
|
||||
|
||||
def function(function):
|
||||
""" Create a run-time function. The arguments can be memory or basic
|
||||
types, and return values can be basic types::
|
||||
|
||||
@function
|
||||
def f(x, y, z):
|
||||
y.write(1)
|
||||
z[0] = 2
|
||||
return x + 3
|
||||
|
||||
a = MemValue(sint(0))
|
||||
b = sint.Array(10)
|
||||
c = f(sint(4), a, b)
|
||||
|
||||
print_ln('%s %s %s', a.reveal(), b[0].reveal(), c.reveal())
|
||||
|
||||
This should output::
|
||||
|
||||
1 2 7
|
||||
|
||||
You can use run-time functions recursively but without return
|
||||
values in this case.
|
||||
|
||||
"""
|
||||
return FunctionCallTape(function)
|
||||
|
||||
def memorize(x, write=True):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(memorize(i, write=write) for i in x)
|
||||
elif x is None:
|
||||
return
|
||||
else:
|
||||
return MemValue(x, write=write)
|
||||
|
||||
def unmemorize(x):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(unmemorize(i) for i in x)
|
||||
elif x is None:
|
||||
return
|
||||
else:
|
||||
return x.read()
|
||||
|
||||
def write_mem(dest, source):
|
||||
if isinstance(dest, (tuple, list)):
|
||||
assert len(dest) == len(source)
|
||||
for x, y in zip(dest, source):
|
||||
write_mem(x, y)
|
||||
elif dest is None:
|
||||
return
|
||||
else:
|
||||
dest.write(source)
|
||||
|
||||
class FunctionBlock(Function):
|
||||
def on_first_call(self, wrapped_function):
|
||||
p_return_address = get_tape().program.malloc(1, 'ci')
|
||||
@@ -470,6 +660,10 @@ class FunctionBlock(Function):
|
||||
if self.result is not None:
|
||||
return unmemorize(self.result)
|
||||
|
||||
@staticmethod
|
||||
def base_key():
|
||||
return get_tape()
|
||||
|
||||
def function_block(function):
|
||||
return FunctionBlock(function)
|
||||
|
||||
@@ -907,11 +1101,15 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
r = reducer(tuplify(loop_body(j)), mem_state)
|
||||
write_state_to_memory(r)
|
||||
state = mem_state
|
||||
for i,x in enumerate(state):
|
||||
if use_array:
|
||||
mem_state[i] = x
|
||||
else:
|
||||
mem_state[i].write(x)
|
||||
if use_array and len(state) and \
|
||||
isinstance(types._register, types._vectorizable):
|
||||
mem_state[:] = state.get_vector()
|
||||
else:
|
||||
for i,x in enumerate(state):
|
||||
if use_array:
|
||||
mem_state[i] = x
|
||||
else:
|
||||
mem_state[i].write(x)
|
||||
def returner():
|
||||
return untuplify(tuple(state))
|
||||
return returner
|
||||
@@ -987,7 +1185,7 @@ def multithread(n_threads, n_items=None, max_size=None):
|
||||
|
||||
.. code::
|
||||
|
||||
@multithread(8, 25)
|
||||
@multithread(3, 25)
|
||||
def f(base, size):
|
||||
...
|
||||
"""
|
||||
@@ -1077,7 +1275,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
return loop_body(base + i)
|
||||
prog = get_program()
|
||||
thread_args = []
|
||||
if prog.curr_tape == prog.tapes[0]:
|
||||
if prog.curr_tape.singular:
|
||||
prog.n_running_threads = n_threads
|
||||
if not util.is_zero(thread_rounds):
|
||||
prog.prevent_breaks = False
|
||||
@@ -1288,9 +1486,21 @@ def while_do(condition, *args):
|
||||
return loop_body
|
||||
return decorator
|
||||
|
||||
def _run_and_link(function, g=None):
|
||||
def _run_and_link(function, g=None, lock_lists=True):
|
||||
if g is None:
|
||||
g = function.__globals__
|
||||
if lock_lists:
|
||||
class A(list):
|
||||
def __init_(self, l):
|
||||
self[:] = l
|
||||
def __setitem__(*args):
|
||||
raise Exception('you cannot change lists in branches, '
|
||||
'use Array or MultiArray instead')
|
||||
__delitem__ = append = clear = extend = insert = __setitem__
|
||||
pop = remove = reverse = sort = __setitem__
|
||||
for x in g:
|
||||
if isinstance(g[x], list):
|
||||
g[x] = A(g[x])
|
||||
pre = copy.copy(g)
|
||||
res = function()
|
||||
_link(pre, g)
|
||||
@@ -1570,14 +1780,25 @@ def listen_for_clients(port):
|
||||
"""
|
||||
instructions.listen(regint.conv(port))
|
||||
|
||||
def accept_client_connection(port):
|
||||
def accept_client_connection(port, players=None):
|
||||
""" Accept client connection on specific port base.
|
||||
|
||||
:param port: port base (int/regint/cint)
|
||||
:param players: subset of players (default: all)
|
||||
:returns: client id
|
||||
|
||||
"""
|
||||
res = regint()
|
||||
instructions.acceptclientconnection(res, regint.conv(port))
|
||||
if players is None:
|
||||
instructions.acceptclientconnection(res, regint.conv(port))
|
||||
else:
|
||||
@if_e(sum(regint(players) ==
|
||||
get_player_id()._v.expand_to_vector(len(players))))
|
||||
def _():
|
||||
res.update(accept_client_connection(port))
|
||||
@else_
|
||||
def _():
|
||||
res.update(-1)
|
||||
return res
|
||||
|
||||
def init_client_connection(host, port, my_id, relative_port=True):
|
||||
@@ -1697,14 +1918,14 @@ def cint_cint_division(a, b, k, f):
|
||||
from Compiler.program import Program
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def sint_cint_division(a, b, k, f, kappa, nearest=False):
|
||||
def sint_cint_division(a, b, k, f, nearest=False):
|
||||
"""
|
||||
type(a) = sint, type(b) = cint
|
||||
"""
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
two = cint(2) * two_power(f)
|
||||
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
|
||||
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
|
||||
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k)
|
||||
absolute_b = b * sign_b
|
||||
absolute_a = a * sign_a
|
||||
w0 = approximate_reciprocal(absolute_b, k, f, theta)
|
||||
@@ -1714,20 +1935,20 @@ def sint_cint_division(a, b, k, f, kappa, nearest=False):
|
||||
W = w0
|
||||
|
||||
for i in range(1, theta):
|
||||
A = (A * W).round(2 * k, f, kappa=kappa, nearest=nearest, signed=True)
|
||||
A = (A * W).round(2 * k, f, nearest=nearest, signed=True)
|
||||
temp = (B * W + 2 * (f - 1)) >> f
|
||||
W = two - temp
|
||||
B = temp
|
||||
return (sign_a * sign_b) * A
|
||||
|
||||
def IntDiv(a, b, k, kappa=None):
|
||||
def IntDiv(a, b, k):
|
||||
l = 2 * k + 1
|
||||
b = a.conv(b)
|
||||
return FPDiv(a.extend(l) << k, b.extend(l) << k, l, k,
|
||||
kappa, nearest=True)
|
||||
nearest=True)
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
def FPDiv(a, b, k, f, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
Goldschmidt method as presented in Catrina10,
|
||||
"""
|
||||
@@ -1750,40 +1971,40 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
|
||||
base.set_global_vector_size(b.size)
|
||||
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
|
||||
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
|
||||
w = AppRcr(b, k, f, simplex_flag, nearest).extend(2 * k)
|
||||
x = alpha - b.extend(2 * k) * w
|
||||
base.reset_global_vector_size()
|
||||
|
||||
y = a.extend(l_y) * w
|
||||
y = y.round(l_y, f, kappa, nearest, signed=True)
|
||||
y = y.round(l_y, f, nearest, signed=True)
|
||||
|
||||
for i in range(theta - 1):
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(l_y) * (alpha + x).extend(l_y)
|
||||
x = x * x
|
||||
y = y.round(l_y, 2*f, kappa, nearest, signed=True)
|
||||
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
|
||||
y = y.round(l_y, 2*f, nearest, signed=True)
|
||||
x = x.round(2*k, 2*f, nearest, signed=True)
|
||||
|
||||
x = x.extend(2 * k)
|
||||
y = y.extend(l_y) * (alpha + x).extend(l_y)
|
||||
y = y.round(l_y, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
y = y.round(l_y, 3 * f - res_f, nearest, signed=True)
|
||||
return y
|
||||
|
||||
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
|
||||
def AppRcr(b, k, f, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
Approximate reciprocal of [b]:
|
||||
Given [b], compute [1/b]
|
||||
"""
|
||||
alpha = b.get_type(2 * k)(int(2.9142 * 2**k))
|
||||
c, v = b.Norm(k, f, kappa, simplex_flag)
|
||||
c, v = b.Norm(k, f, simplex_flag)
|
||||
#v should be 2**{k - m} where m is the length of the bitwise repr of [b]
|
||||
d = alpha - 2 * c
|
||||
w = d * v
|
||||
w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
|
||||
w = w.round(2 * k + 1, 2 * (k - f), nearest, signed=True)
|
||||
# now w * 2 ^ {-f} should be an initial approximation of 1/b
|
||||
return w
|
||||
|
||||
def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
def Norm(b, k, f, simplex_flag=False):
|
||||
"""
|
||||
Computes secret integer values [c] and [v_prime] st.
|
||||
2^{k-1} <= c < 2^k and c = b*v_prime
|
||||
@@ -1799,8 +2020,8 @@ def Norm(b, k, f, kappa, simplex_flag=False):
|
||||
absolute_val = sign * b
|
||||
|
||||
#next 2 lines actually compute the SufOR for little indian encoding
|
||||
bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
|
||||
suffixes = PreOR(bits, kappa)[::-1]
|
||||
bits = absolute_val.bit_decompose(k, maybe_mixed=True)[::-1]
|
||||
suffixes = PreOR(bits)[::-1]
|
||||
|
||||
z = [0] * k
|
||||
for i in range(k - 1):
|
||||
|
||||
228
Compiler/ml.py
228
Compiler/ml.py
@@ -203,6 +203,20 @@ def _no_mem_warnings(function):
|
||||
copy_doc(wrapper, function)
|
||||
return wrapper
|
||||
|
||||
def _layer_method_call_tape(function):
|
||||
function = method_call_tape(function)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._Y.alloc()
|
||||
if self.inputs and len(self.inputs) == 1:
|
||||
backup = self.inputs
|
||||
del self.inputs
|
||||
res = function(self, *args, **kwargs)
|
||||
self.inputs = backup
|
||||
return res
|
||||
else:
|
||||
return function(self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class Tensor(MultiArray):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['alloc'] = False
|
||||
@@ -259,6 +273,7 @@ class Layer:
|
||||
def Y(self, value):
|
||||
self._Y = value
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch=None, training=None):
|
||||
if batch is None:
|
||||
batch = Array.create_from(regint(0))
|
||||
@@ -1045,7 +1060,8 @@ class ElementWiseLayer(NoVariableLayer):
|
||||
|
||||
def _forward(self, batch=[0]):
|
||||
n_per_item = reduce(operator.mul, self.X.sizes[1:])
|
||||
@multithread(self.n_threads, len(batch) * n_per_item)
|
||||
@multithread(self.n_threads, len(batch) * n_per_item,
|
||||
max_size=program.budget)
|
||||
def _(base, size):
|
||||
self.Y.assign_vector(self.f_part(base, size), base)
|
||||
|
||||
@@ -1195,6 +1211,7 @@ class MaxPool(PoolBase):
|
||||
list/tuple of integers
|
||||
|
||||
"""
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch=None, training=False):
|
||||
if batch is None:
|
||||
batch = Array.create_from(regint(0))
|
||||
@@ -1252,26 +1269,38 @@ class Concat(NoVariableLayer):
|
||||
self.dimension = dimension
|
||||
shapes = [inp.shape for inp in inputs]
|
||||
assert dimension == 3
|
||||
assert len(shapes) == 2
|
||||
assert len(shapes[0]) == len(shapes[1])
|
||||
assert len(shapes[0]) == 4
|
||||
for shape in shapes:
|
||||
assert len(shape) == len(shapes[0])
|
||||
shape = []
|
||||
for i in range(len(shapes[0])):
|
||||
if i == dimension:
|
||||
shape.append(shapes[0][i] + shapes[1][i])
|
||||
shape.append(sum(x[i] for x in shapes))
|
||||
else:
|
||||
assert shapes[0][i] == shapes[1][i]
|
||||
shape.append(shapes[0][i])
|
||||
self.Y = Tensor(shape, sfix)
|
||||
self.bases = [sum(x[dimension] for x in shapes[:k])
|
||||
for k in range(len(shapes))]
|
||||
self.addresses = Array.create_from(regint(list(
|
||||
x.Y.address for x in inputs)))
|
||||
|
||||
def _forward(self, batch=[0]):
|
||||
assert len(batch) == 1
|
||||
@for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3])
|
||||
def _(i, j):
|
||||
X = [x.Y[batch[0]] for x in self.inputs]
|
||||
self.Y[batch[0]][i][j].assign_vector(X[0][i][j].get_vector())
|
||||
self.Y[batch[0]][i][j].assign_part_vector(
|
||||
X[1][i][j].get_vector(),
|
||||
len(X[0][i][j]))
|
||||
if len(set(self.bases)) == 1:
|
||||
@for_range(len(self.inputs))
|
||||
def _(k):
|
||||
self.Y[batch[0]][i][j].assign_part_vector(
|
||||
MultiArray(
|
||||
self.inputs[0].shape,
|
||||
address=self.addresses[k])[i][j].get_vector(),
|
||||
k * self.bases[1])
|
||||
else:
|
||||
X = [x.Y[batch[0]] for x in self.inputs]
|
||||
for k in range(len(self.inputs)):
|
||||
self.Y[batch[0]][i][j].assign_part_vector(
|
||||
X[k][i][j].get_vector(), self.bases[k])
|
||||
|
||||
class Add(NoVariableLayer):
|
||||
""" Fixed-point addition layer.
|
||||
@@ -1334,12 +1363,14 @@ class BatchNorm(Layer):
|
||||
|
||||
def __init__(self, shape, approx=True, args=None):
|
||||
assert len(shape) in (2, 3, 4)
|
||||
self.Y = sfix.Tensor(shape)
|
||||
if len(shape) == 4:
|
||||
shape = [shape[0], shape[1] * shape[2], shape[3]]
|
||||
elif len(shape) == 2:
|
||||
shape = [shape[0], 1, shape[1]]
|
||||
tensors = (Tensor(shape, sfix) for i in range(4))
|
||||
self.X, self.Y, self.nabla_X, self.nabla_Y = tensors
|
||||
self.my_Y = sfix.Tensor(shape, address=self.Y.address)
|
||||
tensors = (Tensor(shape, sfix) for i in range(3))
|
||||
self.X, self.nabla_X, self.nabla_Y = tensors
|
||||
arrays = (sfix.Array(shape[2]) for i in range(4))
|
||||
self.var, self.mu, self.weights, self.bias = arrays
|
||||
arrays = (sfix.Array(shape[2]) for i in range(4))
|
||||
@@ -1374,11 +1405,11 @@ class BatchNorm(Layer):
|
||||
[len(batch), self.X.sizes[1]])
|
||||
def _(i, j):
|
||||
tmp = self.weights[:] * (self.X[i][j][:] - mu[:]) * factor[:]
|
||||
self.Y[i][j][:] = self.bias[:] + tmp
|
||||
self.my_Y[i][j][:] = self.bias[:] + tmp
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch, training=False):
|
||||
if training or not self.is_trained:
|
||||
self.is_trained = True
|
||||
d = self.X.sizes[1]
|
||||
d_in = self.X.sizes[2]
|
||||
s = sfix.Array(d_in)
|
||||
@@ -1411,7 +1442,7 @@ class BatchNorm(Layer):
|
||||
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
|
||||
print_ln('%s at (%s, %s, %s): in=%s out=%s',
|
||||
str(self.Y), i, j, k, self.X[i][j][k].reveal(),
|
||||
self.Y[i][j][k].reveal())
|
||||
self.my_Y[i][j][k].reveal())
|
||||
else:
|
||||
self._output(batch, self.mu_hat, self.var_hat)
|
||||
|
||||
@@ -1471,6 +1502,10 @@ class BatchNorm(Layer):
|
||||
self.nabla_Y[i][j][k].reveal(),
|
||||
self.nabla_X[i][j][k].reveal())
|
||||
|
||||
def reveal_parameters_to_binary(self):
|
||||
for param in self.thetas() + (self.mu_hat, self.var_hat):
|
||||
param.reveal().binary_output()
|
||||
|
||||
class QuantBase(object):
|
||||
bias_before_reduction = True
|
||||
|
||||
@@ -1498,11 +1533,12 @@ class QuantBase(object):
|
||||
class FixBase:
|
||||
bias_before_reduction = False
|
||||
|
||||
@staticmethod
|
||||
def new_squant():
|
||||
class _(sfix):
|
||||
params = None
|
||||
return _
|
||||
class my_squant(sfix):
|
||||
params = None
|
||||
|
||||
@classmethod
|
||||
def new_squant(cls):
|
||||
return cls.my_squant
|
||||
|
||||
def input_params_from(self, player):
|
||||
pass
|
||||
@@ -1553,7 +1589,8 @@ class ConvBase(BaseLayer):
|
||||
cls.temp_inputs = sfix.Array(size)
|
||||
|
||||
def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride,
|
||||
padding='SAME', tf_weight_format=False, inputs=None):
|
||||
padding='SAME', tf_weight_format=False, inputs=None,
|
||||
weight_type=None):
|
||||
super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs)
|
||||
|
||||
self.weight_shape = weight_shape
|
||||
@@ -1582,7 +1619,11 @@ class ConvBase(BaseLayer):
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.weight_squant = self.new_squant()
|
||||
if weight_type:
|
||||
self.weight_squant = weight_type
|
||||
else:
|
||||
self.weight_squant = self.new_squant()
|
||||
|
||||
self.bias_squant = self.new_squant()
|
||||
|
||||
self.weights = Tensor(weight_shape, self.weight_squant)
|
||||
@@ -1645,8 +1686,7 @@ class ConvBase(BaseLayer):
|
||||
n_summands = self.n_summands()
|
||||
#start_timer(2)
|
||||
n_outputs = batch_length * reduce(operator.mul, self.output_shape[1:])
|
||||
@multithread(self.n_threads, n_outputs,
|
||||
1000 if sfix.round_nearest else 10 ** 6)
|
||||
@multithread(self.n_threads, n_outputs, max_size=program.budget)
|
||||
def _(base, n_per_thread):
|
||||
res = self.input_squant().unreduced(
|
||||
sint.load_mem(unreduced.address + base,
|
||||
@@ -1709,7 +1749,7 @@ class Conv2d(ConvBase):
|
||||
res += self.bias.expand_to_vector(j, res.size).v
|
||||
else:
|
||||
res += self.bias.expand_to_vector(j, res.size).v << \
|
||||
self.input_squant.f
|
||||
self.weight_squant.f
|
||||
addresses = regint.inc(res.size,
|
||||
self.unreduced[i * part_size].address + j,
|
||||
n_channels_out)
|
||||
@@ -2029,7 +2069,7 @@ class AveragePool2d(BaseLayer):
|
||||
self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
|
||||
|
||||
def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
|
||||
padding=0):
|
||||
padding=0, **kwargs):
|
||||
""" More convenient interface to :py:class:`FixConv2d`.
|
||||
|
||||
:param input_shape: input shape (tuple/list of four int)
|
||||
@@ -2050,7 +2090,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1,
|
||||
padding = padding.upper() if isinstance(padding, str) \
|
||||
else padding
|
||||
return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape,
|
||||
stride, padding)
|
||||
stride, padding, **kwargs)
|
||||
|
||||
def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
|
||||
""" More convenient interface to :py:class:`MaxPool`.
|
||||
@@ -2102,6 +2142,7 @@ class FixAveragePool2d(PoolBase, FixBase):
|
||||
PoolBase.__init__(self, input_shape, [1] + list(strides) + [1],
|
||||
[1] + list(filter_size) + [1], padding)
|
||||
self.pool_size = reduce(operator.mul, filter_size)
|
||||
self.d_out = self.Y.shape[-1]
|
||||
if output_shape:
|
||||
assert self.Y.shape == list(output_shape)
|
||||
|
||||
@@ -2322,6 +2363,10 @@ class Optimizer:
|
||||
self.forward(batch=batch, run_last=False)
|
||||
part = self.layers[-1].eval(batch_size, top=top)
|
||||
res.assign_part_vector(part.get_vector(), start)
|
||||
if self.output_stats:
|
||||
for layer in self.layers[:-1]:
|
||||
print_ln(layer)
|
||||
self.stat(' Y', layer.Y)
|
||||
self.run_in_batches(f, data, batch_size or len(self.layers[1].X))
|
||||
return res
|
||||
|
||||
@@ -2567,9 +2612,9 @@ class Optimizer:
|
||||
|
||||
def run_in_batches(self, f, data, batch_size, truth=None):
|
||||
batch_size = min(batch_size, data.sizes[0])
|
||||
training_data = self.layers[0].X.address
|
||||
training_data = self.layers[0]._X.array._address
|
||||
training_truth = self.layers[-1].Y.address
|
||||
self.layers[0].X.address = data.address
|
||||
self.layers[0]._X.address = data.address
|
||||
if truth:
|
||||
self.layers[-1].Y.address = truth.address
|
||||
N = data.sizes[0]
|
||||
@@ -2748,13 +2793,15 @@ class Optimizer:
|
||||
return list(self.thetas)
|
||||
|
||||
def reveal_model_to_binary(self):
|
||||
input_shape = self.layers[0].X.shape
|
||||
""" Reveal model and store it in the binary output file, see
|
||||
:ref:`reveal-model` for details. """
|
||||
input_shape = self.layers[0]._X.shape
|
||||
for layer in self.layers:
|
||||
if len(input_shape) == 4 and isinstance(layer, DenseBase):
|
||||
layer.reveal_parameters_to_binary(reshape=input_shape[1:])
|
||||
else:
|
||||
layer.reveal_parameters_to_binary()
|
||||
input_shape = layer.Y.shape
|
||||
input_shape = layer._Y.shape
|
||||
|
||||
class Adam(Optimizer):
|
||||
""" Adam/AMSgrad optimizer.
|
||||
@@ -3046,7 +3093,7 @@ class keras:
|
||||
def summary(self):
|
||||
self.opt.summary()
|
||||
|
||||
def build(self, input_shape, batch_size=128):
|
||||
def build(self, input_shape, batch_size=128, program=None):
|
||||
data_input_shape = input_shape
|
||||
if self.opt != None and \
|
||||
input_shape == self.opt.layers[0]._X.sizes and \
|
||||
@@ -3120,8 +3167,11 @@ class keras:
|
||||
if layers[-1].d_out == 1:
|
||||
layers.append(Output(data_input_shape[0]))
|
||||
else:
|
||||
layers.append(
|
||||
MultiOutput(data_input_shape[0], layers[-1].d_out))
|
||||
shape = data_input_shape[0], layers[-1].d_out
|
||||
if program:
|
||||
layers.append(MultiOutput.from_args(program, *shape))
|
||||
else:
|
||||
layers.append(MultiOutput(*shape))
|
||||
if self.optimizer[1]:
|
||||
raise Exception('use keyword arguments for optimizer')
|
||||
opt = self.optimizer[0]
|
||||
@@ -3186,11 +3236,11 @@ class keras:
|
||||
batch_size = min(batch_size, self.batch_size)
|
||||
return self.opt.eval(x, batch_size=batch_size)
|
||||
|
||||
def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
regression=False):
|
||||
""" Convert a PyTorch Sequential object to MP-SPDZ layers.
|
||||
def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
|
||||
regression=False, layer_args={}, program=None):
|
||||
""" Convert a PyTorch Module object to MP-SPDZ layers.
|
||||
|
||||
:param sequence: PyTorch Sequential object
|
||||
:param model: PyTorch Module object
|
||||
:param data_input_shape: input shape (list of four int)
|
||||
:param batch_size: batch size (int)
|
||||
:param input_via: player to input model data via (default: don't)
|
||||
@@ -3198,17 +3248,29 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
|
||||
"""
|
||||
layers = []
|
||||
named_layers = {}
|
||||
|
||||
def mul(x):
|
||||
return reduce(operator.mul, x)
|
||||
|
||||
def process(item):
|
||||
nonlocal input_shape
|
||||
import torch
|
||||
|
||||
def process(item, inputs, input_shape, args):
|
||||
if item == torch.cat:
|
||||
if len(inputs) > 1:
|
||||
layers.append(
|
||||
Concat(inputs, dimension=len(inputs[0].shape) - 1))
|
||||
return
|
||||
elif item == operator.add:
|
||||
layers.append(Add(inputs))
|
||||
return
|
||||
elif item == torch.flatten:
|
||||
return
|
||||
# single-input layers from here
|
||||
if inputs and len(inputs) > 1:
|
||||
raise CompilerError('multi-input layer %s not supported' % item)
|
||||
name = type(item).__name__
|
||||
if name == 'Sequential':
|
||||
for x in item:
|
||||
process(x)
|
||||
elif name == 'Linear':
|
||||
if name == 'Linear':
|
||||
assert mul(input_shape[1:]) == item.in_features
|
||||
assert item.bias is not None
|
||||
layers.append(Dense(input_shape[0], item.in_features,
|
||||
@@ -3239,7 +3301,7 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
elif name == 'Conv2d':
|
||||
layers.append(easyConv2d(input_shape, batch_size, item.out_channels,
|
||||
item.kernel_size, item.stride,
|
||||
item.padding))
|
||||
item.padding, **layer_args.get(item, {})))
|
||||
input_shape = layers[-1].Y.shape
|
||||
if input_via is not None:
|
||||
shapes = [x.shape for x in
|
||||
@@ -3247,11 +3309,14 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
import numpy
|
||||
swapped = numpy.moveaxis(
|
||||
numpy.array(item.weight.detach()), 1, -1)
|
||||
layers[-1].weights = sfix.input_tensor_via(input_via, swapped)
|
||||
layers[-1].bias = sfix.input_tensor_via(
|
||||
input_via, item.bias.detach())
|
||||
layers[-1].weights = \
|
||||
layers[-1].weights.value_type.input_tensor_via(
|
||||
input_via, swapped)
|
||||
assert layers[-1].weights.shape == shapes[0]
|
||||
assert layers[-1].bias.shape == shapes[1]
|
||||
if isinstance(item.bias, torch.Tensor):
|
||||
layers[-1].bias = sfix.input_tensor_via(
|
||||
input_via, item.bias.detach())
|
||||
assert layers[-1].bias.shape == shapes[1]
|
||||
elif name == 'MaxPool2d':
|
||||
layers.append(easyMaxPool(input_shape, item.kernel_size,
|
||||
item.stride, item.padding))
|
||||
@@ -3260,10 +3325,24 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
layers.append(FixAveragePool2d(input_shape, None, item.kernel_size,
|
||||
item.stride, item.padding))
|
||||
input_shape = layers[-1].Y.shape
|
||||
elif name == 'ReLU':
|
||||
elif name == 'AdaptiveAvgPool2d' or \
|
||||
item == torch.nn.functional.adaptive_avg_pool2d:
|
||||
if name == 'AdaptiveAvgPool2d':
|
||||
output = item.output_size
|
||||
else:
|
||||
output = args[1]
|
||||
for i in (0, 1):
|
||||
assert input_shape[1 + i] % output[i] == 0
|
||||
stride = [input_shape[1 + i] // output[i] for i in (0, 1)]
|
||||
kernel_size = [input_shape[1 + i] - (output[i] - 1) * stride[i]
|
||||
for i in (0, 1)]
|
||||
layers.append(FixAveragePool2d(input_shape, None, kernel_size,
|
||||
stride, padding=0))
|
||||
input_shape = layers[-1].Y.shape
|
||||
elif name == 'ReLU' or item == torch.nn.functional.relu:
|
||||
layers.append(Relu(input_shape))
|
||||
elif name == 'Flatten':
|
||||
pass
|
||||
return
|
||||
elif name == 'BatchNorm2d':
|
||||
layers.append(BatchNorm(layers[-1].Y.sizes))
|
||||
if input_via is not None:
|
||||
@@ -3277,16 +3356,52 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None,
|
||||
alpha=item.p))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
else:
|
||||
raise CompilerError('unknown PyTorch module: ' + name)
|
||||
raise CompilerError('unknown PyTorch module: %s' % item)
|
||||
layers[-1].inputs = inputs
|
||||
|
||||
input_shape = data_input_shape + [1] * (4 - len(data_input_shape))
|
||||
process(sequence)
|
||||
|
||||
torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes)
|
||||
for i, layer in enumerate(torch_layers[1:-1]):
|
||||
if layer.op == 'call_module':
|
||||
target = model
|
||||
for attr in layer.target.split('.'):
|
||||
target = getattr(target, attr)
|
||||
else:
|
||||
target = layer.target
|
||||
if not layers:
|
||||
assert layer.args == (torch_layers[i],)
|
||||
inputs = None
|
||||
else:
|
||||
if len(layer.args) < 2 or (layer.args[1] != 1 and
|
||||
layer.args[1] != (1, 1)):
|
||||
args = layer.args
|
||||
elif isinstance(layer.args[0], list):
|
||||
args = layer.args[0]
|
||||
else:
|
||||
args = layer.args[0],
|
||||
inputs = [named_layers[x] for x in args]
|
||||
if len(inputs) == 1:
|
||||
if isinstance(inputs[0], (Dropout, BatchNorm)):
|
||||
input_shape = inputs[0].inputs[0].Y.shape
|
||||
else:
|
||||
input_shape = inputs[0]._Y.shape
|
||||
else:
|
||||
input_shape = None
|
||||
process(target, inputs, input_shape, layer.args)
|
||||
if layers:
|
||||
named_layers[layer] = layers[-1]
|
||||
|
||||
if regression:
|
||||
layers.append(LinearOutput(data_input_shape[0], layers[-1].d_out))
|
||||
elif layers[-1].d_out == 1:
|
||||
layers.append(Output(data_input_shape[0]))
|
||||
else:
|
||||
layers.append(MultiOutput(data_input_shape[0], layers[-1].d_out))
|
||||
shape = data_input_shape[0], layers[-1].d_out
|
||||
if program:
|
||||
layers.append(MultiOutput.from_args(program, *shape))
|
||||
else:
|
||||
layers.append(MultiOutput(*shape))
|
||||
return layers
|
||||
|
||||
class OneLayerSGD:
|
||||
@@ -3335,6 +3450,9 @@ class OneLayerSGD:
|
||||
|
||||
class SGDLogistic(OneLayerSGD):
|
||||
""" Logistic regression using SGD.
|
||||
The member :py:obj:`opt` refers to the internal instance of
|
||||
:py:class:`Optimizer`, which allows to use the funcionality
|
||||
therein.
|
||||
|
||||
:param n_epochs: number of epochs
|
||||
:param batch_size: batch size
|
||||
@@ -3460,6 +3578,8 @@ def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False,
|
||||
|
||||
def mr(A, n_iterations, stop=False):
|
||||
""" Iterative matrix inverse approximation.
|
||||
This is based on the conjugate gradients algorithm in Section
|
||||
10.2.4 of `these lecture notes <https://graphics.stanford.edu/courses/cs205a-13-fall/assets/notes/cs205a_notes.pdf>`_.
|
||||
|
||||
:param A: matrix to invert
|
||||
:param n_iterations: maximum number of iterations
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
Module for math operations.
|
||||
|
||||
Implements trigonometric and logarithmic functions.
|
||||
Most of the functionality is due to `Aly and Smart
|
||||
<https://eprint.iacr.org/2019/354>`_ with some optimizations by
|
||||
`Keller and Sun <https://eprint.iacr.org/2022/933>`_.
|
||||
|
||||
This has to imported explicitly.
|
||||
"""
|
||||
@@ -98,7 +100,7 @@ pi_over_2 = math.radians(90)
|
||||
# @return truncated sint value of x
|
||||
def trunc(x):
|
||||
if isinstance(x, types._fix):
|
||||
return x.v.right_shift(x.f, x.k, security=x.kappa, signed=True)
|
||||
return x.v.right_shift(x.f, x.k, signed=True)
|
||||
elif type(x) is types.sfloat:
|
||||
v, p, z, s = floatingpoint.FLRound(x, 0)
|
||||
#return types.sfloat(v, p, z, s, x.err)
|
||||
@@ -106,19 +108,6 @@ def trunc(x):
|
||||
return x
|
||||
|
||||
|
||||
##
|
||||
# loads integer to fractional type (sint)
|
||||
# @param x: coefficient to be truncated.
|
||||
#
|
||||
# @return returns sfix, sfloat loaded value
|
||||
def load_sint(x, l_type):
|
||||
if l_type is types.sfix:
|
||||
return types.sfix.from_sint(x)
|
||||
elif l_type is types.sfloat:
|
||||
return x
|
||||
return x
|
||||
|
||||
|
||||
##
|
||||
# evaluates a Polynomial to a given x in a privacy preserving manner.
|
||||
# Inputs can be of any kind of register, secret or otherwise.
|
||||
@@ -448,7 +437,7 @@ def log2_fx(x, use_division=True):
|
||||
if isinstance(x, types._fix):
|
||||
# transforms sfix to f*2^n, where f is [o.5,1] bounded
|
||||
# obtain number bounded by [0,5 and 1] by transforming input to sfloat
|
||||
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f, x.kappa)
|
||||
v, p, z, s = floatingpoint.Int2FL(x.v, x.k, x.f)
|
||||
p -= x.f
|
||||
vlen = x.f
|
||||
v = x._new(v, k=x.k, f=x.f)
|
||||
@@ -473,7 +462,7 @@ def log2_fx(x, use_division=True):
|
||||
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
|
||||
|
||||
|
||||
def pow_fx(x, y):
|
||||
def pow_fx(x, y, zero_output=False):
|
||||
"""
|
||||
Returns the value of the expression :math:`x^y` where both inputs
|
||||
are secret shared. It uses :py:func:`log2_fx` together with
|
||||
@@ -494,7 +483,7 @@ def pow_fx(x, y):
|
||||
# obtains y * log2(x)
|
||||
exp = y * log2_x
|
||||
# returns 2^(y*log2(x))
|
||||
return exp2_fx(exp)
|
||||
return exp2_fx(exp, zero_output)
|
||||
|
||||
|
||||
def log_fx(x, b):
|
||||
@@ -535,8 +524,8 @@ def abs_fx(x):
|
||||
#
|
||||
# @return floored sint value of x
|
||||
def floor_fx(x):
|
||||
return load_sint(x.v.right_shift(x.f, bit_length=x.k, security=x.kappa,
|
||||
signed=True), type(x))
|
||||
return type(x)(x.v.right_shift(x.f, bit_length=x.k, signed=True),
|
||||
k=x.k, f=x.f)
|
||||
|
||||
|
||||
### sqrt methods
|
||||
@@ -743,13 +732,13 @@ def lin_app_SQ(b, k, f):
|
||||
c, v, m, W = norm_SQ(types.sint(b), k)
|
||||
|
||||
# c is now escalated
|
||||
w = alpha * load_sint(c,types.sfix) + beta # equation before b and reduction by order of k
|
||||
w = alpha * c + beta # equation before b and reduction by order of k
|
||||
|
||||
|
||||
# m even or odd determination
|
||||
m_bit = types.sint()
|
||||
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), w.kappa, False)
|
||||
m = load_sint(m_bit, types.sfix)
|
||||
comparison.Mod2(m_bit, m, int(math.ceil(math.log(k, 2))), signed=False)
|
||||
m = m_bit
|
||||
|
||||
# w times v this way both terms have 2^3k and can be symplified
|
||||
w = w * v
|
||||
@@ -774,7 +763,7 @@ def lin_app_SQ(b, k, f):
|
||||
def sqrt_fx(x_l, k, f):
|
||||
factor = 1.0 / (2.0 ** f)
|
||||
|
||||
x = load_sint(x_l, types.sfix) * factor
|
||||
x = x_l * factor
|
||||
|
||||
theta = int(math.ceil(math.log(k/5.4)))
|
||||
|
||||
@@ -912,29 +901,29 @@ def tanh(x):
|
||||
|
||||
# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427
|
||||
|
||||
def Sep(x):
|
||||
def Sep(x, sfix=types.sfix):
|
||||
b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True))))
|
||||
bb = b[:]
|
||||
while len(bb) < 2 * x.f - 1:
|
||||
bb.insert(0, type(b[0])(0))
|
||||
t = x.v * (1 + x.v.bit_compose(b_i.bit_not()
|
||||
for b_i in bb[-2 * x.f + 1:]))
|
||||
u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
|
||||
u = sfix._new(t.right_shift(x.f, 2 * x.k, signed=False))
|
||||
b += [b[0].long_one()]
|
||||
return u, [b[i + 1] - b[i] for i in reversed(range(x.k))]
|
||||
|
||||
def SqrtComp(z, old=False):
|
||||
f = types.sfix.f
|
||||
def SqrtComp(z, old=False, sfix=types.sfix):
|
||||
f = sfix.f
|
||||
k = len(z)
|
||||
if isinstance(z[0], types.sint):
|
||||
return types.sfix._new(sum(z[i] * types.cfix(
|
||||
return sfix._new(sum(z[i] * types.cfix(
|
||||
2 ** (-(i - f + 1) / 2), k=k, f=f).v for i in range(k)))
|
||||
k_prime = k // 2
|
||||
f_prime = f // 2
|
||||
c1 = types.sfix(2 ** ((f + 1) / 2 + 1))
|
||||
c0 = types.sfix(2 ** (f / 2 + 1))
|
||||
c1 = sfix(2 ** ((f + 1) / 2 + 1))
|
||||
c0 = sfix(2 ** (f / 2 + 1))
|
||||
a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)]
|
||||
tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
|
||||
tmp = sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime])))
|
||||
if old:
|
||||
b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2
|
||||
else:
|
||||
@@ -942,11 +931,15 @@ def SqrtComp(z, old=False):
|
||||
return types.sint.conv(b).if_else(c1, c0) * tmp
|
||||
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def InvertSqrt(x, old=False):
|
||||
"""
|
||||
Reciprocal square root approximation by `Lu et al.
|
||||
<https://dl.acm.org/doi/10.1145/3411501.3419427>`_
|
||||
"""
|
||||
u, z = Sep(x)
|
||||
class my_sfix(types.sfix):
|
||||
f = x.f
|
||||
k = x.k
|
||||
u, z = Sep(x, sfix=my_sfix)
|
||||
c = 3.14736 + u * (4.63887 * u - 5.77789)
|
||||
return c * SqrtComp(z, old=old)
|
||||
return c * SqrtComp(z, old=old, sfix=my_sfix)
|
||||
|
||||
@@ -4,14 +4,6 @@ from .types import *
|
||||
from . import comparison, program
|
||||
|
||||
class NonLinear:
|
||||
kappa = None
|
||||
|
||||
def set_security(self, kappa):
|
||||
pass
|
||||
|
||||
def check_security(self, kappa):
|
||||
pass
|
||||
|
||||
def mod2m(self, a, k, m, signed):
|
||||
"""
|
||||
a_prime = a % 2^m
|
||||
@@ -45,18 +37,16 @@ class NonLinear:
|
||||
|
||||
def trunc_round_nearest(self, a, k, m, signed):
|
||||
res = sint()
|
||||
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, self.kappa,
|
||||
signed)
|
||||
comparison.Trunc(res, a + (1 << (m - 1)), k + 1, m, signed)
|
||||
return res
|
||||
|
||||
def trunc(self, a, k, m, kappa, signed):
|
||||
self.check_security(kappa)
|
||||
def trunc(self, a, k, m, signed):
|
||||
if m == 0:
|
||||
return a
|
||||
return self._trunc(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
return -self.trunc(a, k, k - 1, kappa, True)
|
||||
def ltz(self, a, k):
|
||||
return -self.trunc(a, k, k - 1, True)
|
||||
|
||||
class Masking(NonLinear):
|
||||
def eqz(self, a, k):
|
||||
@@ -68,28 +58,19 @@ class Masking(NonLinear):
|
||||
|
||||
class Prime(Masking):
|
||||
""" Non-linear functionality modulo a prime with statistical masking. """
|
||||
def __init__(self, kappa):
|
||||
self.set_security(kappa)
|
||||
|
||||
def set_security(self, kappa):
|
||||
self.kappa = kappa
|
||||
|
||||
def check_security(self, kappa):
|
||||
assert self.kappa == kappa or kappa is None
|
||||
|
||||
def _mod2m(self, a, k, m, signed):
|
||||
res = sint()
|
||||
if m == 1:
|
||||
Mod2(res, a, k, self.kappa, signed)
|
||||
Mod2(res, a, k, signed)
|
||||
else:
|
||||
Mod2mField(res, a, k, m, self.kappa, signed)
|
||||
Mod2mField(res, a, k, m, signed)
|
||||
return res
|
||||
|
||||
def _mask(self, a, k):
|
||||
return maskField(a, k, self.kappa)
|
||||
return maskField(a, k)
|
||||
|
||||
def _trunc_pr(self, a, k, m, signed=None):
|
||||
return TruncPrField(a, k, m, self.kappa)
|
||||
return TruncPrField(a, k, m)
|
||||
|
||||
def _trunc(self, a, k, m, signed=None):
|
||||
a_prime = self.mod2m(a, k, m, signed)
|
||||
@@ -99,12 +80,12 @@ class Prime(Masking):
|
||||
|
||||
def bit_dec(self, a, k, m, maybe_mixed=False):
|
||||
if maybe_mixed:
|
||||
return BitDecFieldRaw(a, k, m, self.kappa)
|
||||
return BitDecFieldRaw(a, k, m)
|
||||
else:
|
||||
return BitDecField(a, k, m, self.kappa)
|
||||
return BitDecField(a, k, m)
|
||||
|
||||
def kor(self, d):
|
||||
return KOR(d, self.kappa)
|
||||
return KOR(d)
|
||||
|
||||
class KnownPrime(NonLinear):
|
||||
""" Non-linear functionality modulo a prime known at compile time. """
|
||||
@@ -144,13 +125,13 @@ class KnownPrime(NonLinear):
|
||||
a += two_power(k)
|
||||
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
def ltz(self, a, k):
|
||||
if k + 1 < self.prime.bit_length():
|
||||
# https://dl.acm.org/doi/10.1145/3474123.3486757
|
||||
# "negative" values wrap around when doubling, thus becoming odd
|
||||
return self.mod2m(2 * a, k + 1, 1, False)
|
||||
else:
|
||||
return super(KnownPrime, self).ltz(a, k, kappa)
|
||||
return super(KnownPrime, self).ltz(a, k)
|
||||
|
||||
class Ring(Masking):
|
||||
""" Non-linear functionality modulo a power of two known at compile time.
|
||||
@@ -189,5 +170,5 @@ class Ring(Masking):
|
||||
else:
|
||||
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
def ltz(self, a, k):
|
||||
return LtzRing(a, k)
|
||||
|
||||
@@ -877,7 +877,7 @@ class LinearORAM(TrivialORAM):
|
||||
demux_array(bit_decompose(index, self.index_size), \
|
||||
self.index_vector)
|
||||
t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size))
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
@map_sum(get_n_threads(self.size), None, self.size, \
|
||||
self.value_length + 1, t)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
@@ -897,7 +897,7 @@ class LinearORAM(TrivialORAM):
|
||||
new_value = make_array(
|
||||
new_value, self.value_type.get_type(
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
|
||||
@for_range_multithread(get_n_threads(self.size), None, self.size)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
access_here = self.index_vector[i]
|
||||
@@ -917,7 +917,7 @@ class LinearORAM(TrivialORAM):
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
new_empty = MemValue(new_empty)
|
||||
write = MemValue(write)
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
@map_sum(get_n_threads(self.size), None, self.size, \
|
||||
self.value_length + 1, [self.value_type.bit_type] + \
|
||||
[self.value_type] * self.value_length)
|
||||
def f(i):
|
||||
@@ -1340,8 +1340,8 @@ class TreeORAM(AbstractORAM):
|
||||
half = (empty_positions[i]+1 - parity) // 2
|
||||
half_max = self.bucket_size // 2
|
||||
|
||||
bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0]
|
||||
bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0]
|
||||
bits = floatingpoint.B2U(half, half_max)[0]
|
||||
bits2 = floatingpoint.B2U(half+parity, half_max)[0]
|
||||
# (doesn't work)
|
||||
#bits2 = [0] * half_max
|
||||
## second half with parity bit
|
||||
@@ -1350,7 +1350,8 @@ class TreeORAM(AbstractORAM):
|
||||
#bits2[0] = (1 - bits[0]) * parity
|
||||
bucket_bits = [b for sl in zip(bits2,bits) for b in sl]
|
||||
else:
|
||||
bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0]
|
||||
bucket_bits = floatingpoint.B2U(empty_positions[i]+1,
|
||||
self.bucket_size)[0]
|
||||
assert len(bucket_bits) == self.bucket_size
|
||||
for j, b in enumerate(bucket_bits):
|
||||
pos_bits[i * self.bucket_size + j] = [b, leaf]
|
||||
@@ -1376,8 +1377,7 @@ class TreeORAM(AbstractORAM):
|
||||
Program.prog.curr_tape.start_new_basicblock()
|
||||
|
||||
bucket_sizes = Array(2**self.D, regint)
|
||||
for i in range(2**self.D):
|
||||
bucket_sizes[i] = 0
|
||||
bucket_sizes.assign_all(0)
|
||||
|
||||
@for_range_opt(len(entries))
|
||||
def _(k):
|
||||
@@ -1697,7 +1697,7 @@ class OneLevelORAM(TreeORAM):
|
||||
|
||||
class BinaryORAM:
|
||||
def __init__(self, size, value_type=None, **kwargs):
|
||||
import circuit_oram
|
||||
from Compiler import circuit_oram
|
||||
from Compiler.GC import types
|
||||
n_bits = int(get_program().options.binary)
|
||||
self.value_type = value_type or types.sbitintvec.get_type(n_bits)
|
||||
|
||||
@@ -18,7 +18,7 @@ from functools import reduce
|
||||
import Compiler.instructions
|
||||
import Compiler.instructions_base
|
||||
import Compiler.instructions_base as inst_base
|
||||
from Compiler.config import REG_MAX, USER_MEM, COST
|
||||
from Compiler.config import REG_MAX, USER_MEM, COST, MEM_MAX
|
||||
from Compiler.exceptions import CompilerError
|
||||
from Compiler.instructions_base import RegType
|
||||
|
||||
@@ -103,6 +103,38 @@ class Program(object):
|
||||
self.bit_length = int(options.binary) or int(options.field)
|
||||
if options.prime:
|
||||
self.prime = int(options.prime)
|
||||
print("WARNING: --prime/-P activates code that usually isn't "
|
||||
"the most efficient variant. Consider using --field/-F "
|
||||
"and set the prime only during the actual computation.")
|
||||
if not self.rabbit_gap() and self.prime > 2 ** 50:
|
||||
print("The chosen prime is particularly inefficient. "
|
||||
"Consider using a prime that is closer to a power "
|
||||
"of two", end='')
|
||||
try:
|
||||
import gmpy2
|
||||
bad_prime = self.prime
|
||||
self.prime = 2 ** int(
|
||||
round(math.log(self.prime, 2))) + 1
|
||||
while True:
|
||||
if self.prime > 2 ** 59:
|
||||
# LWE compatibility
|
||||
step = 2 ** 15
|
||||
else:
|
||||
step = 1
|
||||
if self.prime < bad_prime:
|
||||
self.prime += step
|
||||
else:
|
||||
self.prime -= step
|
||||
if gmpy2.is_prime(self.prime):
|
||||
break
|
||||
assert self.rabbit_gap()
|
||||
print(", for example, %d." % self.prime)
|
||||
self.prime = bad_prime
|
||||
except ImportError:
|
||||
print(".")
|
||||
if options.execute:
|
||||
print("Use '-- --prime <prime>' to specify the prime for "
|
||||
"execution only.")
|
||||
max_bit_length = int(options.prime).bit_length() - 2
|
||||
if self.bit_length > max_bit_length:
|
||||
raise CompilerError(
|
||||
@@ -111,7 +143,7 @@ class Program(object):
|
||||
self.bit_length = self.bit_length or max_bit_length
|
||||
self.non_linear = KnownPrime(self.prime)
|
||||
else:
|
||||
self.non_linear = Prime(self.security)
|
||||
self.non_linear = Prime()
|
||||
if not self.bit_length:
|
||||
self.bit_length = 64
|
||||
print("Default bit length for compilation:", self.bit_length)
|
||||
@@ -197,6 +229,8 @@ class Program(object):
|
||||
self.cisc_to_function = True
|
||||
if not self.options.cisc:
|
||||
self.options.cisc = not self.options.optimize_hard
|
||||
self.use_tape_calls = True
|
||||
self.force_cisc_tape = False
|
||||
|
||||
Program.prog = self
|
||||
from . import comparison, instructions, instructions_base, types
|
||||
@@ -278,7 +312,8 @@ class Program(object):
|
||||
self.non_linear = Ring(ring_size)
|
||||
self.options.ring = str(ring_size)
|
||||
|
||||
def new_tape(self, function, args=[], name=None, single_thread=False):
|
||||
def new_tape(self, function, args=[], name=None, single_thread=False,
|
||||
finalize=True, **kwargs):
|
||||
"""
|
||||
Create a new tape from a function. See
|
||||
:py:func:`~Compiler.library.multithread` and
|
||||
@@ -309,11 +344,12 @@ class Program(object):
|
||||
self.curr_tape
|
||||
tape_index = len(self.tapes)
|
||||
self.tape_stack.append(self.curr_tape)
|
||||
self.curr_tape = Tape(name, self)
|
||||
self.curr_tape = Tape(name, self, **kwargs)
|
||||
self.curr_tape.singular = single_thread
|
||||
self.tapes.append(self.curr_tape)
|
||||
function(*args)
|
||||
self.finalize_tape(self.curr_tape)
|
||||
if finalize:
|
||||
self.finalize_tape(self.curr_tape)
|
||||
if self.tape_stack:
|
||||
self.curr_tape = self.tape_stack.pop()
|
||||
return tape_index
|
||||
@@ -346,6 +382,7 @@ class Program(object):
|
||||
thread_numbers = []
|
||||
while len(thread_numbers) < len(args):
|
||||
free_threads = self.curr_tape.free_threads
|
||||
self.curr_tape.ran_threads = True
|
||||
if free_threads:
|
||||
thread_numbers.append(min(free_threads))
|
||||
free_threads.remove(thread_numbers[-1])
|
||||
@@ -417,7 +454,10 @@ class Program(object):
|
||||
|
||||
def finalize_tape(self, tape):
|
||||
if not tape.purged:
|
||||
curr_tape = self.curr_tape
|
||||
self.curr_tape = tape
|
||||
tape.optimize(self.options)
|
||||
self.curr_tape = curr_tape
|
||||
tape.write_bytes()
|
||||
if self.options.asmoutfile:
|
||||
tape.write_str(self.options.asmoutfile + "-" + tape.name)
|
||||
@@ -472,16 +512,18 @@ class Program(object):
|
||||
self.allocated_mem[mem_type] += size
|
||||
if len(str(addr)) != len(str(addr + size)) and self.verbose:
|
||||
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
|
||||
if addr + size >= 2**64:
|
||||
raise CompilerError("allocation exceeded for type '%s'" % mem_type)
|
||||
if addr + size >= MEM_MAX:
|
||||
raise CompilerError(
|
||||
"allocation exceeded for type '%s' after adding %d" % \
|
||||
(mem_type, size))
|
||||
self.allocated_mem_blocks[addr, mem_type] = size, self.curr_block.alloc_pool
|
||||
if single_size:
|
||||
from .library import get_thread_number, runtime_error_if
|
||||
from .library import get_arg, runtime_error_if
|
||||
bak = self.curr_tape.active_basicblock
|
||||
self.curr_tape.active_basicblock = self.curr_tape.basicblocks[0]
|
||||
tn = get_thread_number()
|
||||
runtime_error_if(tn > self.n_running_threads, "malloc")
|
||||
res = addr + single_size * (tn - 1)
|
||||
arg = get_arg()
|
||||
runtime_error_if(arg >= self.n_running_threads, "malloc")
|
||||
res = addr + single_size * arg
|
||||
self.curr_tape.active_basicblock = bak
|
||||
self.base_addresses[res] = addr
|
||||
return res
|
||||
@@ -577,7 +619,6 @@ class Program(object):
|
||||
def set_security(self, security):
|
||||
changed = self._security != security
|
||||
self._security = security
|
||||
self.non_linear.set_security(security)
|
||||
if changed:
|
||||
print("Changed statistical security for comparison etc. to",
|
||||
security)
|
||||
@@ -783,7 +824,7 @@ class Program(object):
|
||||
class Tape:
|
||||
"""A tape contains a list of basic blocks, onto which instructions are added."""
|
||||
|
||||
def __init__(self, name, program):
|
||||
def __init__(self, name, program, thread_pool=None):
|
||||
"""Set prime p and the initial instructions and registers."""
|
||||
self.program = program
|
||||
name += "-%d" % program.get_tape_counter()
|
||||
@@ -800,12 +841,15 @@ class Tape:
|
||||
self.merge_opens = True
|
||||
self.if_states = []
|
||||
self.req_bit_length = defaultdict(lambda: 0)
|
||||
self.bit_length_reason = None
|
||||
self.function_basicblocks = {}
|
||||
self.functions = []
|
||||
self.singular = True
|
||||
self.free_threads = set()
|
||||
self.free_threads = set() if thread_pool is None else thread_pool
|
||||
self.loop_breaks = []
|
||||
self.warned_about_mem = False
|
||||
self.return_values = []
|
||||
self.ran_threads = False
|
||||
|
||||
class BasicBlock(object):
|
||||
def __init__(self, parent, name, scope, exit_condition=None,
|
||||
@@ -880,7 +924,8 @@ class Tape:
|
||||
self.usage_instructions = list(filter(relevant, self.instructions))
|
||||
else:
|
||||
self.usage_instructions = []
|
||||
if len(self.usage_instructions) > 1000:
|
||||
if len(self.usage_instructions) > 1000 and \
|
||||
self.parent.program.verbose:
|
||||
print("Retaining %d instructions" % len(self.usage_instructions))
|
||||
del self.instructions
|
||||
self.purged = True
|
||||
@@ -1107,6 +1152,9 @@ class Tape:
|
||||
if addr.program == self and self.basicblocks:
|
||||
allocator.alloc_reg(addr, self.basicblocks[-1].alloc_pool)
|
||||
|
||||
for reg in self.return_values:
|
||||
allocator.alloc_reg(reg, self.basicblocks[-1].alloc_pool)
|
||||
|
||||
seen = set()
|
||||
|
||||
def alloc(block):
|
||||
@@ -1214,7 +1262,10 @@ class Tape:
|
||||
Compiler.instructions.reqbl(bl, add_to_prog=False)
|
||||
)
|
||||
if self.program.verbose:
|
||||
print("Tape requires prime bit length", self.req_bit_length["p"])
|
||||
print("Tape requires prime bit length",
|
||||
self.req_bit_length["p"],
|
||||
('for %s' % self.bit_length_reason
|
||||
if self.bit_length_reason else ''))
|
||||
print("Tape requires galois bit length", self.req_bit_length["2"])
|
||||
|
||||
@unpurged
|
||||
@@ -1287,6 +1338,7 @@ class Tape:
|
||||
if "Bytecode" not in filename:
|
||||
filename = self.program.programs_dir + "/Bytecode/" + filename
|
||||
print("Writing to", filename)
|
||||
sys.stdout.flush()
|
||||
f = open(filename, "wb")
|
||||
h = hashlib.sha256()
|
||||
for i in self._get_instructions():
|
||||
@@ -1395,13 +1447,12 @@ class Tape:
|
||||
return repr(dict(self))
|
||||
|
||||
class ReqNode(object):
|
||||
__slots__ = ["num", "_children", "name", "blocks", "aggregated"]
|
||||
|
||||
def __init__(self, name):
|
||||
self._children = []
|
||||
self.name = name
|
||||
self.blocks = []
|
||||
self.aggregated = None
|
||||
self.num = None
|
||||
|
||||
@property
|
||||
def children(self):
|
||||
@@ -1411,12 +1462,17 @@ class Tape:
|
||||
def aggregate(self, *args):
|
||||
if self.aggregated is not None:
|
||||
return self.aggregated
|
||||
self.recursion = self.num is not None
|
||||
if self.recursion:
|
||||
return Tape.ReqNum()
|
||||
self.num = Tape.ReqNum()
|
||||
for block in self.blocks:
|
||||
block.add_usage(self)
|
||||
res = reduce(
|
||||
lambda x, y: x + y.aggregate(self.name), self.children, self.num
|
||||
)
|
||||
if self.recursion:
|
||||
res *= float('inf')
|
||||
self.aggregated = res
|
||||
return res
|
||||
|
||||
@@ -1442,7 +1498,7 @@ class Tape:
|
||||
n_reps = self.aggregator([1])
|
||||
n_rounds = res["all", "round"]
|
||||
n_invs = res["all", "inv"]
|
||||
if (n_invs / n_rounds) * 1000 < n_reps:
|
||||
if (n_invs / n_rounds) * 1000 < n_reps and Program.prog.verbose:
|
||||
print(
|
||||
self.nodes[0].blocks[0].name,
|
||||
"blowing up rounds: ",
|
||||
@@ -1468,15 +1524,19 @@ class Tape:
|
||||
def close_scope(self, outer_scope, parent_req_node, name):
|
||||
self.start_new_basicblock(outer_scope, name, req_node=parent_req_node)
|
||||
|
||||
def require_bit_length(self, bit_length, t="p"):
|
||||
def require_bit_length(self, bit_length, t="p", reason=None):
|
||||
if t == "p":
|
||||
if self.program.prime:
|
||||
if bit_length >= self.program.prime.bit_length() - 1:
|
||||
raise CompilerError(
|
||||
"required bit length %d too much for %d"
|
||||
% (bit_length, self.program.prime)
|
||||
+ ('(for %s)' % reason if reason else '')
|
||||
)
|
||||
self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t])
|
||||
bit_length += 1
|
||||
if bit_length > self.req_bit_length[t]:
|
||||
self.req_bit_length[t] = bit_length
|
||||
self.bit_length_reason = reason
|
||||
else:
|
||||
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
|
||||
|
||||
@@ -1498,6 +1558,21 @@ class Tape:
|
||||
"In some cases, you can fix this by using 'compile.py -l'."
|
||||
)
|
||||
|
||||
def __int__(self):
|
||||
raise CompilerError(
|
||||
"It is impossible to convert run-time types to compile-time "
|
||||
"Python types like int or float. The reason for this is that "
|
||||
"%s objects are only a placeholder during the execution in "
|
||||
"Python, the actual value of which is only defined in the "
|
||||
"virtual machine at a later time. See "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/journey.html "
|
||||
"to get an understanding of the overall design. "
|
||||
"In rare cases, you can fix this by using 'compile.py -l'." % \
|
||||
type(self).__name__
|
||||
)
|
||||
|
||||
__float__ = __int__
|
||||
|
||||
class Register(_no_truth):
|
||||
"""
|
||||
Class for creating new registers. The register's index is automatically assigned
|
||||
@@ -1619,6 +1694,9 @@ class Tape:
|
||||
def copy(self):
|
||||
return Tape.Register(self.reg_type, Program.prog.curr_tape)
|
||||
|
||||
def same_type(self):
|
||||
return type(self)(size=self.size)
|
||||
|
||||
def link(self, other):
|
||||
if Program.prog.options.noreallocate:
|
||||
raise CompilerError("reallocation necessary for linking, "
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
from Compiler import types, library, instructions
|
||||
from Compiler import comparison, util
|
||||
|
||||
def dest_comp(B):
|
||||
Bt = B.transpose()
|
||||
@@ -20,6 +21,7 @@ def reveal_sort(k, D, reverse=False):
|
||||
backward order
|
||||
|
||||
"""
|
||||
comparison.require_ring_size(util.log2(len(k)) + 1, 'sorting')
|
||||
assert len(k) == len(D)
|
||||
library.break_point()
|
||||
shuffle = types.sint.get_secure_shuffle(len(k))
|
||||
|
||||
@@ -109,7 +109,7 @@ class SqrtOram(Generic[T, B]):
|
||||
self.shuffle_used = cint.Array(self.n)
|
||||
# Random permutation on the data
|
||||
self.shufflei = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
self.index_type(regint.inc(self.n)))
|
||||
# Calculate the period if not given
|
||||
# upon recursion, the period should stay the same ("in sync"),
|
||||
# therefore it can be passed as a constructor parameter
|
||||
@@ -122,7 +122,7 @@ class SqrtOram(Generic[T, B]):
|
||||
# Note that self.shuffle_the_shuffle mutates this field
|
||||
# Why don't we pass it as an argument then? Well, this way we don't have to allocate memory while shuffling, which keeps open the possibility for multithreading
|
||||
self.permutation = Array.create_from(
|
||||
[self.index_type(i) for i in range(self.n)])
|
||||
self.index_type(regint.inc(self.n)))
|
||||
# We allow the caller to postpone the initialization of the shuffle
|
||||
# This is the most expensive operation, and can be done in a thread (only if you know what you're doing)
|
||||
# Note that if you do not initialize, the ORAM is insecure
|
||||
@@ -256,7 +256,6 @@ class SqrtOram(Generic[T, B]):
|
||||
|
||||
return result
|
||||
|
||||
@lib.method_block
|
||||
def write(self, index: T, *value: T):
|
||||
global trace, n_parallel
|
||||
if trace:
|
||||
@@ -271,7 +270,12 @@ class SqrtOram(Generic[T, B]):
|
||||
else:
|
||||
raise Exception("Cannot handle type of value passed")
|
||||
print(self.entry_length, value, type(value),len(value))
|
||||
value = MemValue(value)
|
||||
|
||||
self._write(index, *value)
|
||||
|
||||
@lib.method_block
|
||||
def _write(self, index: T, *value: T):
|
||||
value = MemValue(self.value_type(value))
|
||||
index = MemValue(index)
|
||||
|
||||
# Refresh if we have performed T (period) accesses
|
||||
@@ -513,14 +517,14 @@ class SqrtOram(Generic[T, B]):
|
||||
# Since the underlying memory of the position map is already aligned in
|
||||
# this packed structure, we can simply overwrite the memory while
|
||||
# maintaining the structure.
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
self.position_map.reinitialize(self.permutation)
|
||||
|
||||
def reinitialize(self, *data: T):
|
||||
def reinitialize(self, data: T):
|
||||
# Note that this method is only used during refresh, and as such is
|
||||
# only called with a permutation as data.
|
||||
|
||||
# The logical addresses of some previous permutation are irrelevant and must be reset
|
||||
self.shufflei.assign([self.index_type(i) for i in range(self.n)])
|
||||
self.shufflei.assign_vector(self.index_type(regint.inc(self.n)))
|
||||
# Reset the clock
|
||||
self.t.write(0)
|
||||
# Reset shuffle_used
|
||||
@@ -530,10 +534,10 @@ class SqrtOram(Generic[T, B]):
|
||||
# This structure is preserved while overwriting the values using
|
||||
# assign_vector
|
||||
self.shuffle.assign_vector(self.value_type(
|
||||
data, size=self.n * self.entry_length))
|
||||
data[:], size=self.n * self.entry_length))
|
||||
# Note that this updates self.permutation (see constructor for explanation)
|
||||
self.shuffle_the_shuffle()
|
||||
self.position_map.reinitialize(*self.permutation)
|
||||
self.position_map.reinitialize(self.permutation)
|
||||
|
||||
def _reset_shuffle_used(self):
|
||||
global allow_memory_allocation
|
||||
@@ -568,7 +572,7 @@ class PositionMap(Generic[T, B]):
|
||||
print_at_depth(self.depth, 'Scanning %s for logical address %s (fake=%s)',
|
||||
self.__class__.__name__, logical_address.reveal(), sintbit(fake).reveal())
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
def reinitialize(self, permutation: T):
|
||||
"""Reinitialize this PositionMap.
|
||||
|
||||
Since the reinitialization occurs at runtime (`on SqrtORAM.refresh()`),
|
||||
@@ -613,9 +617,10 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
packed_size = int(math.ceil(self.n / pack))
|
||||
packed_structure = MultiArray(
|
||||
(packed_size, pack), value_type=value_type)
|
||||
for i in range(packed_size):
|
||||
@lib.for_range(packed_size)
|
||||
def _(i):
|
||||
packed_structure[i] = Array.create_from(
|
||||
permutation[i*pack:(i+1)*pack])
|
||||
permutation.get_vector(base=i * pack, size=pack))
|
||||
|
||||
SqrtOram.__init__(self, packed_structure, value_type=value_type,
|
||||
period=period, entry_length=pack, k=self.depth,
|
||||
@@ -720,8 +725,8 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
|
||||
|
||||
return p.reveal()
|
||||
|
||||
def reinitialize(self, *permutation: T):
|
||||
SqrtOram.reinitialize(self, *permutation)
|
||||
def reinitialize(self, permutation: T):
|
||||
SqrtOram.reinitialize(self, permutation)
|
||||
|
||||
|
||||
class LinearPositionMap(PositionMap):
|
||||
@@ -790,8 +795,8 @@ class LinearPositionMap(PositionMap):
|
||||
|
||||
return p.reveal()
|
||||
|
||||
def reinitialize(self, *data: T):
|
||||
self.physical.assign_vector(data)
|
||||
def reinitialize(self, data : T):
|
||||
self.physical.assign(data)
|
||||
|
||||
global allow_memory_allocation
|
||||
if allow_memory_allocation:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -28,11 +28,11 @@ def greater_than(a, b, bits):
|
||||
else:
|
||||
return a.greater_than(b, bits)
|
||||
|
||||
def pow2_value(a, bit_length=None, security=None):
|
||||
def pow2_value(a, bit_length=None):
|
||||
if is_constant_float(a):
|
||||
return 2**a
|
||||
else:
|
||||
return a.pow2(bit_length, security)
|
||||
return a.pow2(bit_length)
|
||||
|
||||
def mod2m(a, b, bits, signed):
|
||||
if isinstance(a, int):
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
int main()
|
||||
{
|
||||
P256Element::init();
|
||||
P256Element::Scalar key;
|
||||
KeySetup<Share<P256Element::Scalar>> key;
|
||||
string prefix = PREP_DIR "ECDSA/";
|
||||
mkdir_p(prefix.c_str());
|
||||
write_online_setup(prefix, P256Element::Scalar::pr());
|
||||
|
||||
@@ -44,6 +44,7 @@ int main(int argc, const char** argv)
|
||||
typedef Share<P256Element::Scalar> pShare;
|
||||
string prefix = get_prep_sub_dir<pShare>(PREP_DIR "ECDSA/", 2);
|
||||
read_mac_key(prefix, N, keyp);
|
||||
pShare::set_mac_key(keyp);
|
||||
|
||||
pShare::MAC_Check::setup(P);
|
||||
Share<P256Element>::MAC_Check::setup(P);
|
||||
|
||||
@@ -37,6 +37,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
EcdsaOptions opts)
|
||||
{
|
||||
bool prep_mul = opts.prep_mul;
|
||||
if (prep_mul)
|
||||
proc.protocol.init_mul();
|
||||
|
||||
Timer timer;
|
||||
timer.start();
|
||||
Player& P = proc.P;
|
||||
@@ -77,7 +80,6 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
|
||||
if (prep_mul)
|
||||
{
|
||||
protocol.init_mul();
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
protocol.prepare_mul(inv_ks[i], sk);
|
||||
protocol.start_exchange();
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "Protocols/SpdzWisePrep.hpp"
|
||||
#include "Protocols/SpdzWiseInput.hpp"
|
||||
#include "Protocols/SpdzWiseShare.hpp"
|
||||
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
|
||||
@@ -24,6 +24,13 @@ Client::Client(const vector<string>& hostnames, int port_base,
|
||||
"P" + to_string(i), "C" + to_string(my_client_id), true);
|
||||
if (i == 0)
|
||||
specification.Receive(sockets[0]);
|
||||
else
|
||||
{
|
||||
octetStream spec;
|
||||
spec.Receive(sockets[i]);
|
||||
if (spec != specification)
|
||||
throw runtime_error("inconsistent specification");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,15 @@ def set_keepalive_osx(sock, after_idle_sec=1, interval_sec=3, max_fails=5):
|
||||
sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, interval_sec)
|
||||
|
||||
class Client:
|
||||
"""Client to servers running secure computation. Works both as a client
|
||||
to all parties or a trusted client to a single party.
|
||||
|
||||
:param hostnames: hostnames or IP addresses to connect to
|
||||
:param port_base: port number for first hostname,
|
||||
increases by one for every additional hostname
|
||||
:param my_client_id: number to identify client
|
||||
|
||||
"""
|
||||
def __init__(self, hostnames, port_base, my_client_id):
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
|
||||
name = 'C%d' % my_client_id
|
||||
@@ -62,6 +71,11 @@ class Client:
|
||||
|
||||
self.specification = octetStream()
|
||||
self.specification.Receive(self.sockets[0])
|
||||
for sock in self.sockets[1:]:
|
||||
specification = octetStream()
|
||||
specification.Receive(sock)
|
||||
if specification.buf != self.specification.buf:
|
||||
raise Exception('inconsistent specification')
|
||||
type = self.specification.get_int(4)
|
||||
if type == ord('R'):
|
||||
self.domain = Z2(self.specification.get_int(4))
|
||||
@@ -99,6 +113,12 @@ class Client:
|
||||
return triples
|
||||
|
||||
def send_private_inputs(self, values):
|
||||
""" Send inputs privately to the computation servers.
|
||||
This assumes that the client is connected to all servers.
|
||||
|
||||
:param values: list of input values
|
||||
|
||||
"""
|
||||
T = self.domain
|
||||
triples = self.receive_triples(T, len(values))
|
||||
os = octetStream()
|
||||
@@ -109,10 +129,46 @@ class Client:
|
||||
os.Send(socket)
|
||||
|
||||
def receive_outputs(self, n):
|
||||
""" Receive outputs privately from the computation servers.
|
||||
This assumes that the client is connected to all servers.
|
||||
|
||||
:param n: number of outputs
|
||||
|
||||
"""
|
||||
T = self.domain
|
||||
triples = self.receive_triples(T, n)
|
||||
return [int(self.clear_domain(triple[0].v)) for triple in triples]
|
||||
|
||||
def send_public_inputs(self, values):
|
||||
""" Send values in the clear. This works for public inputs
|
||||
to all servers or to send shares to a single server.
|
||||
|
||||
:param values: list of values
|
||||
|
||||
"""
|
||||
os = octetStream()
|
||||
for value in values:
|
||||
self.domain(value).pack(os)
|
||||
for socket in self.sockets:
|
||||
os.Send(socket)
|
||||
|
||||
def receive_plain_values(self, socket=None):
|
||||
""" Receive values in the clear. This works for public inputs
|
||||
to all servers or to send shares to a single server.
|
||||
|
||||
:param socket: socket to use (need to specify it there is more than one)
|
||||
|
||||
"""
|
||||
if socket is None:
|
||||
if len(self.sockets) != 1:
|
||||
raise Exception('need to specify socket')
|
||||
socket = self.sockets[0]
|
||||
os = octetStream()
|
||||
os.Receive(socket)
|
||||
assert len(os) % self.domain.size() == 0
|
||||
return [int(os.get(self.domain))
|
||||
for i in range(len(os) // self.domain.size())]
|
||||
|
||||
class octetStream:
|
||||
def __init__(self, value=None):
|
||||
self.buf = b''
|
||||
@@ -123,6 +179,8 @@ class octetStream:
|
||||
def get_length(self):
|
||||
return len(self.buf)
|
||||
|
||||
__len__ = get_length
|
||||
|
||||
def reset_write_head(self):
|
||||
self.buf = b''
|
||||
self.ptr = 0
|
||||
@@ -164,6 +222,11 @@ class octetStream:
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get(self, type):
|
||||
res = type()
|
||||
res.unpack(self)
|
||||
return res
|
||||
|
||||
def consume(self, length):
|
||||
self.ptr += length
|
||||
assert self.ptr <= len(self.buf)
|
||||
|
||||
@@ -7,7 +7,7 @@ class Domain:
|
||||
|
||||
def __int__(self):
|
||||
res = self.v % self.modulus
|
||||
return res if res < self.modulus / 2 else res - self.modulus
|
||||
return int(res if res < self.modulus / 2 else res - self.modulus)
|
||||
|
||||
def __add__(self, other):
|
||||
try:
|
||||
|
||||
19
ExternalIO/personal-client-example.py
Executable file
19
ExternalIO/personal-client-example.py
Executable file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/python3
|
||||
|
||||
import sys, random
|
||||
|
||||
sys.path.insert(0, 'ExternalIO')
|
||||
|
||||
from client import *
|
||||
|
||||
party = int(sys.argv[1])
|
||||
|
||||
client = Client(['localhost'], 15000 + party, 0)
|
||||
|
||||
n = 1000
|
||||
|
||||
if party < 2:
|
||||
client.send_public_inputs(random.gauss(0, 1) * 2 ** 16 for i in range(n))
|
||||
|
||||
x = [client.receive_plain_values() for i in range(2)]
|
||||
client.send_public_inputs(a + b for a, b in zip(*x))
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "SimpleMachine.h"
|
||||
#include "Tools/mkpath.h"
|
||||
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
template<class FD>
|
||||
Producer<FD>::Producer(int output_thread, bool write_output) :
|
||||
n_slots(0), output_thread(output_thread), write_output(write_output),
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
#include "BitAdder.h"
|
||||
|
||||
#include "Protocols/BufferScope.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
template<class T>
|
||||
@@ -69,6 +71,9 @@ void BitAdder::add(vector<vector<T> >& res,
|
||||
|
||||
size_t n_items = end - begin;
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%lu ANDs in bit adder\n", length * n_items * n_bits);
|
||||
|
||||
if (supply)
|
||||
{
|
||||
#ifdef VERBOSE_EDA
|
||||
@@ -85,6 +90,7 @@ void BitAdder::add(vector<vector<T> >& res,
|
||||
vector<T> carries(n_items);
|
||||
vector<T> a(n_items), b(n_items);
|
||||
auto& protocol = proc.protocol;
|
||||
BufferScope scope(proc.DataF, n_items * length * n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
{
|
||||
assert(summands[i].size() == 2);
|
||||
|
||||
@@ -24,14 +24,14 @@ void FakeSecret::load_clear(int n, const Integer& x)
|
||||
*this = x;
|
||||
}
|
||||
|
||||
void FakeSecret::bitcom(Memory<FakeSecret>& S, const vector<int>& regs)
|
||||
void FakeSecret::bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
*this ^= (S[regs[i]] << i);
|
||||
}
|
||||
|
||||
void FakeSecret::bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const
|
||||
void FakeSecret::bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
|
||||
@@ -142,8 +142,8 @@ public:
|
||||
template <class T>
|
||||
void store(Memory<T>& mem, size_t address) { mem[address] = *this; }
|
||||
|
||||
void bitcom(Memory<FakeSecret>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<FakeSecret>& S, const vector<int>& regs) const;
|
||||
void bitcom(StackedVector<FakeSecret>& S, const vector<int>& regs);
|
||||
void bitdec(StackedVector<FakeSecret>& S, const vector<int>& regs) const;
|
||||
|
||||
template <class T>
|
||||
void xor_(int n, const FakeSecret& x, const T& y)
|
||||
|
||||
@@ -84,8 +84,8 @@ void Instruction::parse(istream& s, int pos)
|
||||
ostringstream os;
|
||||
os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl;
|
||||
os << "This virtual machine executes binary circuits only." << endl;
|
||||
os << "Use 'compile.py -B'." << endl;
|
||||
throw Invalid_Instruction(os.str());
|
||||
os << "Use 'compile.py -B'.";
|
||||
exit_error(os.str());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
15
GC/NoShare.h
15
GC/NoShare.h
@@ -52,7 +52,7 @@ public:
|
||||
|
||||
static DataFieldType field_type()
|
||||
{
|
||||
throw not_implemented();
|
||||
return DATA_GF2;
|
||||
}
|
||||
|
||||
static void init_minimum(int)
|
||||
@@ -80,7 +80,8 @@ public:
|
||||
|
||||
bool operator!=(NoValue) const { return false; }
|
||||
|
||||
bool operator==(int) { fail(); return false; }
|
||||
bool operator==(int) const { fail(); return false; }
|
||||
bool operator==(NoValue) const { fail(); return false; }
|
||||
|
||||
bool get_bit(int) { fail(); return 0; }
|
||||
|
||||
@@ -92,6 +93,8 @@ public:
|
||||
|
||||
void input(istream&, bool) { fail(); }
|
||||
void output(ostream&, bool) {}
|
||||
|
||||
void pack(octetStream&) const { fail(); }
|
||||
};
|
||||
|
||||
inline ostream& operator<<(ostream& o, NoValue)
|
||||
@@ -169,8 +172,8 @@ public:
|
||||
|
||||
void load_clear(Integer, Integer) { fail(); }
|
||||
void random_bit() { fail(); }
|
||||
void bitdec(vector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void bitcom(vector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void bitdec(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
void bitcom(StackedVector<NoShare>&, const vector<int>&) const { fail(); }
|
||||
|
||||
void assign(const char*) { fail(); }
|
||||
|
||||
@@ -190,6 +193,8 @@ public:
|
||||
|
||||
NoShare& operator+=(const NoShare&) { fail(); return *this; }
|
||||
|
||||
bool operator==(NoShare) const { fail(); return false; }
|
||||
|
||||
NoShare get_bit(int) const { fail(); return {}; }
|
||||
|
||||
void xor_bit(int, NoShare) const { fail(); }
|
||||
@@ -201,6 +206,8 @@ public:
|
||||
|
||||
void input(istream&, bool) { fail(); }
|
||||
void output(ostream&, bool) { fail(); }
|
||||
|
||||
void pack(octetStream&) const { fail(); }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -79,10 +79,11 @@ template<class T>
|
||||
void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
|
||||
size_t begin, size_t end)
|
||||
{
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
|
||||
RunningTimer timer;
|
||||
#endif
|
||||
bool verbose = OnlineOptions::singleton.has_option("verbose_eda");
|
||||
if (verbose)
|
||||
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
|
||||
|
||||
auto& party = ShareThread<typename T::whole_type>::s();
|
||||
auto& MC = party.MC->get_part_MC();
|
||||
auto& P = *party.P;
|
||||
@@ -102,9 +103,9 @@ void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
|
||||
input.exchange();
|
||||
for (size_t i = begin; i < end; i++)
|
||||
triples[i][2] = input.finalize(input_player, T::default_length);
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
|
||||
#endif
|
||||
|
||||
if (verbose)
|
||||
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ using namespace std;
|
||||
#include "Math/Integer.h"
|
||||
#include "Processor/ProcessorBase.h"
|
||||
#include "Processor/Instruction.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
@@ -38,9 +39,9 @@ public:
|
||||
// rough measure for the memory usage
|
||||
size_t complexity;
|
||||
|
||||
Memory<T> S;
|
||||
Memory<Clear> C;
|
||||
Memory<Integer> I;
|
||||
StackedVector<T> S;
|
||||
StackedVector<Clear> C;
|
||||
StackedVector<Integer> I;
|
||||
|
||||
Timer xor_timer;
|
||||
|
||||
@@ -78,8 +79,8 @@ public:
|
||||
template<class U>
|
||||
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
|
||||
|
||||
template<class U>
|
||||
void mem_op(int n, Memory<U>& dest, const Memory<U>& source,
|
||||
template<class U, class V>
|
||||
void mem_op(int n, U& dest, const V& source,
|
||||
Integer dest_address, Integer source_address);
|
||||
|
||||
void xors(const vector<int>& args);
|
||||
@@ -105,6 +106,9 @@ public:
|
||||
template<int = 0>
|
||||
void convcbit2s(const BaseInstruction& instruction);
|
||||
|
||||
void convcbitvec(const BaseInstruction& instruction, StackedVector<Integer>& Ci,
|
||||
Player* P);
|
||||
|
||||
void print_reg(int reg, int n, int size);
|
||||
void print_reg_plain(Clear& value);
|
||||
void print_reg_signed(unsigned n_bits, Integer value);
|
||||
@@ -114,6 +118,13 @@ public:
|
||||
void print_float_prec(int n);
|
||||
|
||||
void incint(const BaseInstruction& instruction);
|
||||
|
||||
void push_stack();
|
||||
void push_args(const vector<int>& args);
|
||||
void pop_stack(const vector<int>& results);
|
||||
|
||||
template<class U>
|
||||
void call_tape(const BaseInstruction& instruction, U& dynamic_memory);
|
||||
};
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -18,7 +18,7 @@ using namespace std;
|
||||
#include "Math/BitVec.h"
|
||||
|
||||
#include "GC/Machine.hpp"
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
#include "Processor/Processor.hpp"
|
||||
#include "Processor/IntInput.hpp"
|
||||
#include "Math/bigint.hpp"
|
||||
|
||||
@@ -53,9 +53,9 @@ template <class T>
|
||||
template <class U>
|
||||
void Processor<T>::reset(const U& program, int arg)
|
||||
{
|
||||
S.resize(program.num_reg(SBIT), "registers");
|
||||
C.resize(program.num_reg(CBIT), "registers");
|
||||
I.resize(program.num_reg(INT), "registers");
|
||||
S.resize(program.num_reg(SBIT));
|
||||
C.resize(program.num_reg(CBIT));
|
||||
I.resize(program.num_reg(INT));
|
||||
set_arg(arg);
|
||||
PC = 0;
|
||||
}
|
||||
@@ -202,14 +202,14 @@ void GC::Processor<T>::store_clear_in_dynamic(const vector<int>& args,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void Processor<T>::mem_op(int n, Memory<U>& dest, const Memory<U>& source,
|
||||
template<class U, class V>
|
||||
void Processor<T>::mem_op(int n, U& dest, const V& source,
|
||||
Integer dest_address, Integer source_address)
|
||||
{
|
||||
dest.check_index(dest_address + n - 1);
|
||||
source.check_index(source_address + n - 1);
|
||||
auto d = &dest[dest_address];
|
||||
auto s = &source[source_address];
|
||||
auto d = &dest[dest_address.get()];
|
||||
auto s = &source[source_address.get()];
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
*d++ = *s++;
|
||||
@@ -388,6 +388,30 @@ void Processor<T>::convcbit2s(const BaseInstruction& instruction)
|
||||
min(size_t(unit), instruction.get_n() - i * unit));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::convcbitvec(const BaseInstruction& instruction,
|
||||
StackedVector<Integer>& Ci, Player* P)
|
||||
{
|
||||
vector<Integer> bits;
|
||||
auto n = instruction.get_n();
|
||||
bits.reserve(n);
|
||||
for (size_t i = 0; i < instruction.get_n(); i++)
|
||||
{
|
||||
int i1 = i / GC::Clear::N_BITS;
|
||||
int i2 = i % GC::Clear::N_BITS;
|
||||
auto bit = C[instruction.get_r(1) + i1].get_bit(i2);
|
||||
bits.push_back(bit);
|
||||
}
|
||||
|
||||
if (P)
|
||||
sync<T>(bits, *P);
|
||||
else if (not T::symmetric)
|
||||
sync<T>(bits, *Thread<T>::s().P);
|
||||
|
||||
for (size_t i = 0; i < n; i++)
|
||||
Ci[instruction.get_r(0) + i] = bits[i];
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_reg(int reg, int n, int size)
|
||||
{
|
||||
@@ -417,7 +441,7 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
|
||||
{
|
||||
if (n_bits <= Clear::N_BITS)
|
||||
{
|
||||
auto value = C[reg];
|
||||
auto value = C[reg.get()];
|
||||
unsigned n_shift = 0;
|
||||
if (n_bits > 1)
|
||||
n_shift = sizeof(value.get()) * 8 - n_bits;
|
||||
@@ -477,6 +501,56 @@ void Processor<T>::incint(const BaseInstruction& instruction)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GC::Processor<T>::push_stack()
|
||||
{
|
||||
S.push_stack();
|
||||
C.push_stack();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GC::Processor<T>::push_args(const vector<int>& args)
|
||||
{
|
||||
S.push_args(args, SBIT);
|
||||
C.push_args(args, CBIT);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void GC::Processor<T>::pop_stack(const vector<int>& results)
|
||||
{
|
||||
S.pop_stack(results, SBIT);
|
||||
C.pop_stack(results, CBIT);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void Processor<T>::call_tape(const BaseInstruction& instruction, U& dynamic_memory)
|
||||
{
|
||||
auto new_arg = I.at(instruction.get_r(1)).get();
|
||||
|
||||
PC_stack.push_back(PC);
|
||||
arg_stack.push_back(this->arg);
|
||||
push_stack();
|
||||
I.push_stack();
|
||||
|
||||
auto& tape = machine->progs.at(instruction.get_r(0));
|
||||
reset(tape, new_arg);
|
||||
|
||||
auto& args = instruction.get_start();
|
||||
push_args(args);
|
||||
I.push_args(args, INT);
|
||||
|
||||
tape.execute(*this, dynamic_memory, PC);
|
||||
|
||||
pop_stack(args);
|
||||
I.pop_stack(args, INT);
|
||||
|
||||
PC = PC_stack.back();
|
||||
PC_stack.pop_back();
|
||||
this->arg = arg_stack.back();
|
||||
arg_stack.pop_back();
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "Protocols/Rep4.hpp"
|
||||
#include "Protocols/Rep4Input.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
|
||||
namespace GC
|
||||
|
||||
@@ -76,6 +76,10 @@ public:
|
||||
|
||||
static const bool actual_inputs = T::actual_inputs;
|
||||
|
||||
static const bool symmetric = true;
|
||||
|
||||
static bool real_shares(const Player&) { return true; }
|
||||
|
||||
static int threshold(int nplayers) { return T::threshold(nplayers); }
|
||||
|
||||
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
|
||||
@@ -148,9 +152,9 @@ public:
|
||||
Secret<T> operator>>(int i) const;
|
||||
|
||||
template<class U>
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
void bitcom(StackedVector<U>& S, const vector<int>& regs);
|
||||
template<class U>
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
|
||||
|
||||
Secret<T> operator+(const Secret<T>& x) const;
|
||||
Secret<T>& operator+=(const Secret<T>& x) { *this = *this + x; return *this; }
|
||||
|
||||
@@ -197,7 +197,7 @@ Secret<T> Secret<T>::operator>>(int i) const
|
||||
|
||||
template <class T>
|
||||
template <class U>
|
||||
void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
void Secret<T>::bitcom(StackedVector<U>& S, const vector<int>& regs)
|
||||
{
|
||||
registers.clear();
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -210,7 +210,7 @@ void Secret<T>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
|
||||
template <class T>
|
||||
template <class U>
|
||||
void Secret<T>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
void Secret<T>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
if (regs.size() > registers.size())
|
||||
throw overflow("not enough bits for bit decomposition", regs.size(),
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace GC
|
||||
void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
|
||||
bool repeat)
|
||||
{
|
||||
if (repeat and OnlineOptions::singleton.live_prep)
|
||||
if (repeat and OnlineOptions::singleton.live_prep and (n < 0 or n > 1))
|
||||
{
|
||||
this->triples.push_back({{}});
|
||||
auto& triple = this->triples.back();
|
||||
@@ -35,6 +35,8 @@ void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
|
||||
|
||||
void Semi::prepare_mul(const SemiSecret& x, const SemiSecret& y, int n)
|
||||
{
|
||||
if (n == -1)
|
||||
n = SemiSecret::default_length;
|
||||
super::prepare_mul(x.mask(n), y.mask(n), n);
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ public:
|
||||
|
||||
void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
|
||||
bool repeat);
|
||||
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n);
|
||||
void prepare_mul(const SemiSecret& x, const SemiSecret& y, int n = -1);
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -79,6 +79,9 @@ array<SemiSecret, 3> SemiPrep::get_mixed_triple(int n)
|
||||
if (mixed_triples.empty())
|
||||
{
|
||||
assert(this->triple_generator);
|
||||
this->triple_generator->set_batch_size(
|
||||
BaseMachine::batch_size<SemiSecret>(DATA_MIXED,
|
||||
this->buffer_size));
|
||||
this->triple_generator->generateMixedTriples();
|
||||
for (auto& x : this->triple_generator->mixedTriples)
|
||||
{
|
||||
|
||||
@@ -38,6 +38,13 @@ public:
|
||||
|
||||
static const int default_length = sizeof(BitVec) * 8;
|
||||
|
||||
static const bool symmetric = V::symmetric;
|
||||
|
||||
static bool real_shares(const Player& P)
|
||||
{
|
||||
return V::real_shares(P);
|
||||
}
|
||||
|
||||
static string type_string() { return "binary secret"; }
|
||||
static string phase_name() { return "Binary computation"; }
|
||||
|
||||
@@ -64,8 +71,8 @@ public:
|
||||
|
||||
void load_clear(int n, const Integer& x);
|
||||
|
||||
void bitcom(Memory<T>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<T>& S, const vector<int>& regs) const;
|
||||
void bitcom(StackedVector<T>& S, const vector<int>& regs);
|
||||
void bitdec(StackedVector<T>& S, const vector<int>& regs) const;
|
||||
|
||||
void xor_(int n, const T& x, const T& y)
|
||||
{ *this = BitVec(x ^ y).mask(n); }
|
||||
|
||||
@@ -70,30 +70,40 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
|
||||
assert(protocol);
|
||||
protocol->init_mul();
|
||||
auto it = args.begin();
|
||||
int total_bits = 0, total_ops = 0;
|
||||
while (it < args.end())
|
||||
{
|
||||
int n_args = (*it++ - 3) / 2;
|
||||
int size = *it++;
|
||||
total_bits += n_args * size;
|
||||
it += n_args;
|
||||
int base = *it++;
|
||||
assert(n_args <= N_BITS);
|
||||
for (int i = 0; i < size; i += N_BITS)
|
||||
{
|
||||
square64 square;
|
||||
for (int j = 0; j < n_args; j++)
|
||||
square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get();
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
square.transpose(n_args, n_ops);
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
for (int k = 0; k < n_args; k += N_BITS)
|
||||
{
|
||||
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
|
||||
auto y_ext = SemiSecret(bit).extend_bit();
|
||||
protocol->prepare_mult(square.rows[j], y_ext, n_args, true);
|
||||
int left = min(N_BITS, n_args - k);
|
||||
square64 square;
|
||||
for (int j = 0; j < left; j++)
|
||||
square.rows[j] = processor.S.at(
|
||||
*(it + k + j) + i / N_BITS).get();
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
total_ops += n_ops;
|
||||
square.transpose(left, n_ops);
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
{
|
||||
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
|
||||
auto y_ext = SemiSecret(bit).extend_bit();
|
||||
protocol->prepare_mult(square.rows[j], y_ext, left, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
it += n_args;
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%d/%d repeat ANDs\n", total_bits, total_ops);
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
it = args.begin();
|
||||
@@ -103,13 +113,18 @@ void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
|
||||
int size = *it++;
|
||||
for (int i = 0; i < size; i += N_BITS)
|
||||
{
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
square64 square;
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
square.rows[j] = protocol->finalize_mul(n_args).get();
|
||||
square.transpose(n_ops, n_args);
|
||||
for (int j = 0; j < n_args; j++)
|
||||
processor.S.at(*(it + j) + i / N_BITS) = square.rows[j];
|
||||
for (int base = 0; base < n_args; base += N_BITS)
|
||||
{
|
||||
int left = min(N_BITS, n_args - base);
|
||||
int n_ops = min(N_BITS, size - i);
|
||||
square64 square;
|
||||
for (int j = 0; j < n_ops; j++)
|
||||
square.rows[j] = protocol->finalize_mul(left).get();
|
||||
square.transpose(n_ops, left);
|
||||
for (int j = 0; j < left; j++)
|
||||
processor.S.at(*(it + base + j) + i / N_BITS) =
|
||||
square.rows[j];
|
||||
}
|
||||
}
|
||||
it += 2 * n_args + 1;
|
||||
}
|
||||
@@ -123,7 +138,7 @@ void SemiSecretBase<T, V>::load_clear(int n, const Integer& x)
|
||||
}
|
||||
|
||||
template<class T, class V>
|
||||
void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
|
||||
void SemiSecretBase<T, V>::bitcom(StackedVector<T>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -131,7 +146,7 @@ void SemiSecretBase<T, V>::bitcom(Memory<T>& S, const vector<int>& regs)
|
||||
}
|
||||
|
||||
template<class T, class V>
|
||||
void SemiSecretBase<T, V>::bitdec(Memory<T>& S,
|
||||
void SemiSecretBase<T, V>::bitdec(StackedVector<T>& S,
|
||||
const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
|
||||
@@ -108,18 +108,9 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
else
|
||||
P = new PlainPlayer(this->N, "shareparty");
|
||||
|
||||
try
|
||||
{
|
||||
read_mac_key(
|
||||
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
|
||||
this->N,
|
||||
this->mac_key);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
SeededPRNG G;
|
||||
this->mac_key.randomize(G);
|
||||
}
|
||||
T::read_or_generate_mac_key(
|
||||
get_prep_sub_dir<typename T::part_type>(PREP_DIR, network_opts.nplayers),
|
||||
*P, this->mac_key);
|
||||
|
||||
T::MC::setup(*P);
|
||||
|
||||
|
||||
@@ -46,6 +46,9 @@ public:
|
||||
|
||||
static const bool is_real = true;
|
||||
static const bool actual_inputs = true;
|
||||
static const bool symmetric = true;
|
||||
|
||||
static bool real_shares(const Player&) { return true; }
|
||||
|
||||
static ShareThread<U>& get_party()
|
||||
{
|
||||
@@ -118,6 +121,7 @@ public:
|
||||
typedef BitVec open_type;
|
||||
typedef NoShare mac_type;
|
||||
typedef NoValue mac_key_type;
|
||||
typedef NoShare mac_share_type;
|
||||
|
||||
typedef NoShare bit_type;
|
||||
|
||||
@@ -151,6 +155,11 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
static GC::NoValue get_mac_key()
|
||||
{
|
||||
throw runtime_error("no MAC");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
static string proto_fake_opts()
|
||||
{
|
||||
@@ -166,8 +175,8 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
void bitcom(Memory<U>& S, const vector<int>& regs);
|
||||
void bitdec(Memory<U>& S, const vector<int>& regs) const;
|
||||
void bitcom(StackedVector<U>& S, const vector<int>& regs);
|
||||
void bitdec(StackedVector<U>& S, const vector<int>& regs) const;
|
||||
|
||||
void xor_(int n, const This& x, const This& y)
|
||||
{ *this = (x ^ y).mask(n); }
|
||||
|
||||
@@ -54,7 +54,7 @@ void ReplicatedSecret<U>::load_clear(int n, const Integer& x)
|
||||
}
|
||||
|
||||
template<class U, int L>
|
||||
void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
void RepSecretBase<U, L>::bitcom(StackedVector<U>& S, const vector<int>& regs)
|
||||
{
|
||||
*this = 0;
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
@@ -62,7 +62,7 @@ void RepSecretBase<U, L>::bitcom(Memory<U>& S, const vector<int>& regs)
|
||||
}
|
||||
|
||||
template<class U, int L>
|
||||
void RepSecretBase<U, L>::bitdec(Memory<U>& S, const vector<int>& regs) const
|
||||
void RepSecretBase<U, L>::bitdec(StackedVector<U>& S, const vector<int>& regs) const
|
||||
{
|
||||
for (unsigned int i = 0; i < regs.size(); i++)
|
||||
S[regs[i]] = (*this >> i) & 1;
|
||||
|
||||
@@ -94,23 +94,35 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
processor.check_args(args, 4);
|
||||
protocol->init_mul();
|
||||
T x_ext, y_ext;
|
||||
int total_bits = 0;
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
int n_bits = args[i];
|
||||
total_bits += n_bits;
|
||||
int left = args[i + 2];
|
||||
int right = args[i + 3];
|
||||
for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++)
|
||||
{
|
||||
int n = min(T::default_length, n_bits - j * T::default_length);
|
||||
|
||||
if (not repeat and n == T::default_length)
|
||||
{
|
||||
protocol->prepare_mul(processor.S[left + j], processor.S[right + j]);
|
||||
continue;
|
||||
}
|
||||
|
||||
processor.S[left + j].mask(x_ext, n);
|
||||
if (repeat)
|
||||
processor.S[right].extend_bit(y_ext, n);
|
||||
else
|
||||
processor.S[right + j].mask(y_ext, n);
|
||||
processor.S[left + j].mask(x_ext, n);
|
||||
protocol->prepare_mult(x_ext, y_ext, n, repeat);
|
||||
}
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%d%s ANDs\n", total_bits, repeat ? " repeat" : "");
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
@@ -121,6 +133,13 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
{
|
||||
int n = min(T::default_length, n_bits - j * T::default_length);
|
||||
auto& res = processor.S[out + j];
|
||||
|
||||
if (not repeat and n == T::default_length)
|
||||
{
|
||||
res = protocol->finalize_mul();
|
||||
continue;
|
||||
}
|
||||
|
||||
protocol->finalize_mult(res, n);
|
||||
res.mask(res, n);
|
||||
}
|
||||
@@ -136,10 +155,12 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
|
||||
protocol->init_mul();
|
||||
auto it = args.begin();
|
||||
T x_ext, y_ext;
|
||||
int total_bits = 0;
|
||||
while (it < args.end())
|
||||
{
|
||||
int n_args = (*it++ - 3) / 2;
|
||||
int size = *it++;
|
||||
total_bits += size * n_args;
|
||||
it += n_args;
|
||||
int base = *it++;
|
||||
for (int i = 0; i < size; i += N_BITS)
|
||||
@@ -155,6 +176,9 @@ void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
|
||||
it += n_args;
|
||||
}
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_and"))
|
||||
fprintf(stderr, "%d repeat ANDs\n", total_bits);
|
||||
|
||||
protocol->exchange();
|
||||
|
||||
it = args.begin();
|
||||
|
||||
@@ -32,6 +32,8 @@ public:
|
||||
array<T, 3> get_triple_no_count(int n_bits)
|
||||
{
|
||||
int max_n_bits = T::default_length;
|
||||
if (n_bits == -1)
|
||||
n_bits = max_n_bits;
|
||||
assert(n_bits <= max_n_bits);
|
||||
assert(n_bits > 0);
|
||||
array<T, 3> res;
|
||||
|
||||
@@ -53,9 +53,15 @@ public:
|
||||
static const bool malicious = T::malicious;
|
||||
static const bool expensive_triples = T::expensive_triples;
|
||||
static const bool randoms_for_opens = false;
|
||||
static const bool symmetric = true;
|
||||
|
||||
static const int default_length = 64;
|
||||
|
||||
static bool real_shares(const Player&)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
static int size()
|
||||
{
|
||||
return part_type::size() * default_length;
|
||||
@@ -72,6 +78,11 @@ public:
|
||||
T::read_or_generate_mac_key(directory, P, key);
|
||||
}
|
||||
|
||||
static typename T::mac_type get_mac_key()
|
||||
{
|
||||
return T::get_mac_key();
|
||||
}
|
||||
|
||||
template<class U>
|
||||
static void reveal_inst(U& processor, const vector<int>& args)
|
||||
{
|
||||
|
||||
@@ -80,14 +80,15 @@
|
||||
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
|
||||
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
|
||||
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
|
||||
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \
|
||||
X(CONVCBIT, Proc.write_Ci(R0, Proc.sync(PC1.get()))) \
|
||||
X(CONVCINTVEC, Proc.convcintvec(instruction)) \
|
||||
X(CONVCBITVEC, Proc.convcbitvec(instruction)) \
|
||||
X(CONVCBITVEC, Proc.Procb.convcbitvec(instruction, Proc.get_Ci(), &Proc.P)) \
|
||||
X(CONVCBIT2S, PROC.convcbit2s(instruction)) \
|
||||
X(DABIT, Proc.dabit(INST)) \
|
||||
X(EDABIT, Proc.edabit(INST)) \
|
||||
X(SEDABIT, Proc.edabit(INST, true)) \
|
||||
X(SPLIT, Proc.split(INST)) \
|
||||
X(CALL_ARG, ) \
|
||||
|
||||
#define GC_INSTRUCTIONS \
|
||||
X(INPUTB, T::inputb(PROC, EXTRA)) \
|
||||
@@ -101,6 +102,7 @@
|
||||
X(CONVCINT, C0 = PI1) \
|
||||
X(CONVCBIT, T::convcbit(I0, PC1, PROC)) \
|
||||
X(CONVCBIT2S, T::convcbit2s(PROC, instruction)) \
|
||||
X(CONVCBITVEC, PROC.convcbitvec(instruction, Ci, 0)) \
|
||||
X(PRINTCHR, PROC.print_chr(IMM)) \
|
||||
X(PRINTSTR, PROC.print_str(IMM)) \
|
||||
X(PRINTFLOATPREC, PROC.print_float_prec(IMM)) \
|
||||
@@ -146,8 +148,11 @@
|
||||
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
|
||||
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
|
||||
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
|
||||
X(CRASH, if (I0.get()) throw crash_requested()) \
|
||||
X(CRASH, if (I0.get() and T::actual_inputs) throw crash_requested()) \
|
||||
X(ACTIVE, ) \
|
||||
X(LDTN, I0 = BaseMachine::thread_num) \
|
||||
X(CALL_TAPE, PROC.call_tape(INST, MD)) \
|
||||
X(CALL_ARG, ) \
|
||||
|
||||
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
The Software is copyright (c) 2023, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
|
||||
The Software is copyright (c) 2024, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
|
||||
|
||||
CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence.
|
||||
|
||||
|
||||
@@ -232,9 +232,9 @@ OTMachine::OTMachine(int argc, const char** argv)
|
||||
gettimeofday(&baseOTstart, NULL);
|
||||
// swap role for base OTs
|
||||
if (opt.isSet("-r"))
|
||||
bot_ = new BaseOT(nbase, 128, P, INV_ROLE(ot_role));
|
||||
bot_ = new BaseOT(nbase, P, INV_ROLE(ot_role));
|
||||
else
|
||||
bot_ = new FakeOT(nbase, 128, P, INV_ROLE(ot_role));
|
||||
bot_ = new FakeOT(nbase, P, INV_ROLE(ot_role));
|
||||
cout << "real mode " << opt.isSet("-r") << endl;
|
||||
BaseOT& bot = *bot_;
|
||||
bot.exec_base();
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "GC/PersonalPrep.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
//template class GC::ShareParty<GC::TinierSecret<gf2n_mac_key>>;
|
||||
template class GC::CcdPrep<GC::TinierSecret<gf2n_mac_key>>;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "Math/BitVec.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
|
||||
#include "Protocols/Share.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
@@ -39,5 +39,6 @@ int main(int argc, const char** argv)
|
||||
return run<2, 1>(machine);
|
||||
|
||||
cerr << "Not compiled for choice of parameters" << endl;
|
||||
cerr << "Try using '-lgp 128'" << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "Protocols/SpdzWisePrep.hpp"
|
||||
#include "Protocols/SpdzWiseInput.hpp"
|
||||
#include "Protocols/SpdzWiseShare.hpp"
|
||||
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "Protocols/SpdzWisePrep.hpp"
|
||||
#include "Protocols/SpdzWiseInput.hpp"
|
||||
#include "Protocols/SpdzWiseShare.hpp"
|
||||
#include "Protocols/SpdzWiseRep3Shuffler.hpp"
|
||||
#include "Protocols/PostSacrifice.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "Protocols/MaliciousRepPrep.hpp"
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "Protocols/Share.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
10
Makefile
10
Makefile
@@ -52,7 +52,7 @@ endif
|
||||
endif
|
||||
|
||||
# used for dependency generation
|
||||
OBJS = $(patsubst %.cpp,%.o,$(wildcard */*.cpp)) $(STATIC_OTE)
|
||||
OBJS = $(patsubst %.cpp,%.o,$(wildcard */*.cpp */*/*.cpp)) $(STATIC_OTE)
|
||||
DEPS := $(wildcard */*.d */*/*.d)
|
||||
|
||||
# never delete
|
||||
@@ -150,13 +150,17 @@ static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a
|
||||
$(MAKE) static-dir
|
||||
$(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
|
||||
|
||||
static/%.x: Machines/BMR/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a local/lib/liblibOTe.a
|
||||
$(MAKE) static-dir
|
||||
$(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
|
||||
|
||||
static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
|
||||
|
||||
static-dir:
|
||||
@ mkdir static 2> /dev/null; true
|
||||
|
||||
static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) static/emulate.x
|
||||
static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Machines/*-party.cpp)) $(patsubst Machines/BMR/%.cpp, static/%.x, $(wildcard Machines/BMR/*-party.cpp)) static/emulate.x
|
||||
|
||||
Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) Processor/PrepBase.o
|
||||
$(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS)
|
||||
@@ -352,7 +356,7 @@ cmake:
|
||||
wget https://github.com/Kitware/CMake/releases/download/v3.24.1/cmake-3.24.1.tar.gz
|
||||
tar xzvf cmake-3.24.1.tar.gz
|
||||
cd cmake-3.24.1; \
|
||||
./bootstrap --parallel=8 --prefix=../local && make && make install
|
||||
./bootstrap --parallel=8 --prefix=../local && make -j8 && make install
|
||||
|
||||
mac-setup: mac-machine-setup
|
||||
brew install openssl boost libsodium gmp yasm ntl cmake
|
||||
|
||||
@@ -67,8 +67,8 @@ public:
|
||||
{
|
||||
if (n == -1)
|
||||
pack(os);
|
||||
else if (n == 1)
|
||||
os.store_bit(this->a);
|
||||
else if (n < 8)
|
||||
os.store_bits(this->a, n);
|
||||
else
|
||||
os.store_int(super::mask(n).get(), DIV_CEIL(n, 8));
|
||||
}
|
||||
@@ -77,8 +77,8 @@ public:
|
||||
{
|
||||
if (n == -1)
|
||||
unpack(os);
|
||||
else if (n == 1)
|
||||
this->a = os.get_bit();
|
||||
else if (n < 8)
|
||||
this->a = os.get_bits(n);
|
||||
else
|
||||
this->a = os.get_int(DIV_CEIL(n, 8));
|
||||
}
|
||||
|
||||
@@ -161,6 +161,11 @@ public:
|
||||
return equal(1);
|
||||
}
|
||||
|
||||
bool operator==(const FixedVec<T, L>& other) const
|
||||
{
|
||||
return equal(other);
|
||||
}
|
||||
|
||||
bool operator!=(const FixedVec<T, L>& other) const
|
||||
{
|
||||
return not equal(other);
|
||||
|
||||
15
Math/Z2k.h
15
Math/Z2k.h
@@ -306,7 +306,7 @@ public:
|
||||
return operator*(SignedZ2<64>(other));
|
||||
}
|
||||
|
||||
void output(ostream& s, bool human = true) const;
|
||||
void output(ostream& s, bool human = true, bool signed_ = true) const;
|
||||
};
|
||||
|
||||
template<int K>
|
||||
@@ -479,12 +479,17 @@ SignedZ2<K> abs(const SignedZ2<K>& x)
|
||||
}
|
||||
|
||||
template<int K>
|
||||
void SignedZ2<K>::output(ostream& s, bool human) const
|
||||
void SignedZ2<K>::output(ostream& s, bool human, bool signed_) const
|
||||
{
|
||||
if (human)
|
||||
{
|
||||
bigint::tmp = *this;
|
||||
s << bigint::tmp;
|
||||
if (signed_)
|
||||
{
|
||||
bigint::tmp = *this;
|
||||
s << bigint::tmp;
|
||||
}
|
||||
else
|
||||
Z2<K>::output(s, human);
|
||||
}
|
||||
else
|
||||
Z2<K>::output(s, false);
|
||||
@@ -493,7 +498,7 @@ void SignedZ2<K>::output(ostream& s, bool human) const
|
||||
template<int K>
|
||||
ostream& operator<<(ostream& o, const SignedZ2<K>& x)
|
||||
{
|
||||
x.output(o, true);
|
||||
x.output(o, true, false);
|
||||
return o;
|
||||
}
|
||||
|
||||
|
||||
@@ -510,6 +510,7 @@ gf2n_short::gf2n_short(const int128& a)
|
||||
// Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40)
|
||||
void expand_byte(gf2n_short& a,int b)
|
||||
{
|
||||
gf2n_short::init_field(40);
|
||||
gf2n_short x,xp;
|
||||
x = (32+1);
|
||||
xp.assign_one();
|
||||
|
||||
17
Math/gfp.h
17
Math/gfp.h
@@ -107,7 +107,7 @@ class gfp_ : public ValueInterface
|
||||
static void write_setup(string dir)
|
||||
{ write_online_setup(dir, pr()); }
|
||||
static void check_setup(string dir);
|
||||
static string fake_opts() { return " -lgp " + to_string(length()); }
|
||||
static string fake_opts() { return " -P " + to_string(pr()); }
|
||||
|
||||
/**
|
||||
* Get the prime modulus
|
||||
@@ -229,18 +229,25 @@ class gfp_ : public ValueInterface
|
||||
// faster randomization, see implementation for explanation
|
||||
void almost_randomize(PRNG& G);
|
||||
|
||||
void output(ostream& s,bool human) const
|
||||
{ a.output(s,ZpD,human); }
|
||||
/**
|
||||
* Output.
|
||||
* @param s output stream
|
||||
* @param human human-readable or binary
|
||||
* @param signed_ signed representation (range `[-p/2,p/2]` instead of `[0,p]`)
|
||||
*/
|
||||
void output(ostream& s, bool human, bool signed_ = false) const
|
||||
{ a.output(s,ZpD, human, signed_); }
|
||||
void input(istream& s,bool human)
|
||||
{ a.input(s,ZpD,human); }
|
||||
|
||||
/**
|
||||
* Human-readable output in the range `[-p/2, p/2]`.
|
||||
* Human-readable output in the range `[0, p]`.
|
||||
* @param s output stream
|
||||
* @param x value
|
||||
*/
|
||||
friend ostream& operator<<(ostream& s,const gfp_& x)
|
||||
{ x.output(s,true);
|
||||
{
|
||||
x.output(s, true, false);
|
||||
return s;
|
||||
}
|
||||
/**
|
||||
|
||||
@@ -82,7 +82,7 @@ public:
|
||||
{
|
||||
write_setup(get_prep_sub_dir<T>(nplayers));
|
||||
}
|
||||
static string fake_opts() { return " -lgp " + to_string(length()); }
|
||||
static string fake_opts() { return " -P " + to_string(pr()); }
|
||||
|
||||
gfpvar_();
|
||||
gfpvar_(int other);
|
||||
|
||||
@@ -132,7 +132,7 @@ class modp_
|
||||
// - Can do in human or machine only format (later should be faster)
|
||||
// - If human output appends a space to help with reading
|
||||
// and also convert back/forth from Montgomery if needed
|
||||
void output(ostream& s,const Zp_Data& ZpD,bool human) const;
|
||||
void output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_ = false) const;
|
||||
void input(istream& s,const Zp_Data& ZpD,bool human);
|
||||
|
||||
template<int X, int K>
|
||||
|
||||
@@ -327,12 +327,12 @@ void Power(modp_<L>& ans,const modp_<L>& x,const bigint& exp,const Zp_Data& ZpD)
|
||||
|
||||
|
||||
template<int L>
|
||||
void modp_<L>::output(ostream& s,const Zp_Data& ZpD,bool human) const
|
||||
void modp_<L>::output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_) const
|
||||
{
|
||||
if (human)
|
||||
{ bigint te;
|
||||
to_bigint(te, ZpD);
|
||||
if (te < ZpD.pr / 2)
|
||||
if (te < ZpD.pr / 2 or not signed_)
|
||||
s << te;
|
||||
else
|
||||
s << (te - ZpD.pr);
|
||||
|
||||
@@ -10,12 +10,12 @@
|
||||
void check_ssl_file(string filename)
|
||||
{
|
||||
if (not ifstream(filename))
|
||||
throw runtime_error("Cannot access " + filename
|
||||
exit_error("Cannot access " + filename
|
||||
+ ". Have you set up SSL?\n"
|
||||
"You can use `Scripts/setup-ssl.sh <nparties>`.");
|
||||
}
|
||||
|
||||
void ssl_error(string side, string other, string me)
|
||||
void ssl_error(string side, string other, string me, exception& e)
|
||||
{
|
||||
cerr << side << "-side handshake with " << other
|
||||
<< " failed. Make sure both sides "
|
||||
@@ -48,6 +48,8 @@ void ssl_error(string side, string other, string me)
|
||||
cerr << "/";
|
||||
}
|
||||
cerr << endl;
|
||||
cerr << "SSL error: " << e.what() << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "Networking/Server.h"
|
||||
#include "Networking/ServerSocket.h"
|
||||
#include "Networking/Exchanger.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
#include <sys/select.h>
|
||||
#include <utility>
|
||||
@@ -78,7 +79,7 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
|
||||
}
|
||||
}
|
||||
if (nplayers_wanted > 0 and nplayers_wanted != nplayers)
|
||||
throw runtime_error("not enough hosts in " + filename);
|
||||
exit_error("not enough hosts in " + filename);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Got list of " << nplayers << " players from file: " << endl;
|
||||
for (unsigned int i = 0; i < names.size(); i++)
|
||||
@@ -127,7 +128,17 @@ void Names::setup_names(const char *servername, int my_port)
|
||||
|
||||
int socket_num;
|
||||
int pn = portnum_base;
|
||||
set_up_client_socket(socket_num, servername, pn);
|
||||
|
||||
try
|
||||
{
|
||||
set_up_client_socket(socket_num, servername, pn);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
exit_error(
|
||||
string("cannot reach coordination server: ") + e.what());
|
||||
}
|
||||
|
||||
octetStream("P" + to_string(player_no)).Send(socket_num);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
|
||||
@@ -155,11 +166,11 @@ void Names::setup_names(const char *servername, int my_port)
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
throw runtime_error(string("error in network setup: ") + e.what());
|
||||
exit_error(string("error in network setup: ") + e.what());
|
||||
}
|
||||
|
||||
if (names.size() != ports.size())
|
||||
throw runtime_error("invalid network setup");
|
||||
exit_error("invalid network setup");
|
||||
nplayers = names.size();
|
||||
#ifdef VERBOSE
|
||||
for (int i = 0; i < nplayers; i++)
|
||||
@@ -288,7 +299,15 @@ void PlainPlayer::setup_sockets(const vector<string>& names,
|
||||
"Setting up send to self socket to %s:%d with id %s\n",
|
||||
localhost, ports[i], pn.c_str());
|
||||
#endif
|
||||
set_up_client_socket(sockets[i],localhost,ports[i]);
|
||||
try
|
||||
{
|
||||
set_up_client_socket(sockets[i],localhost,ports[i]);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
exit_error("cannot connect to myself, "
|
||||
"maybe check your firewall configuration");
|
||||
}
|
||||
} else {
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "Setting up client to %s:%d with id %s\n",
|
||||
@@ -762,7 +781,7 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other)
|
||||
{
|
||||
sent += other.sent;
|
||||
for (auto it = other.begin(); it != other.end(); it++)
|
||||
(*this)[it->first] += it->second;
|
||||
map<string, CommStats>::operator[](it->first) += it->second;
|
||||
return *this;
|
||||
}
|
||||
|
||||
@@ -786,7 +805,7 @@ NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const
|
||||
NamedCommStats res = *this;
|
||||
res.sent = sent - other.sent;
|
||||
for (auto it = other.begin(); it != other.end(); it++)
|
||||
res[it->first] -= it->second;
|
||||
res.map<string, CommStats>::operator[](it->first) -= it->second;
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -818,9 +837,25 @@ Timer& NamedCommStats::add_to_last_round(const string& name, size_t length)
|
||||
}
|
||||
}
|
||||
|
||||
void PlayerBase::reset_stats()
|
||||
Timer& CommStatsWithName::add_length_only(size_t length)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("verbose_comm"))
|
||||
fprintf(stderr, "%s %zu bytes in same round\n", name.c_str(), length);
|
||||
return stats.add_length_only(length);
|
||||
}
|
||||
|
||||
Timer& CommStatsWithName::add(const octetStream& os)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("verbose_comm"))
|
||||
fprintf(stderr, "%s %zu bytes\n", name.c_str(), os.get_length());
|
||||
return stats.add(os);
|
||||
}
|
||||
|
||||
void Player::reset_stats()
|
||||
{
|
||||
comm_stats.reset();
|
||||
for (auto& x : thread_stats)
|
||||
x.reset();
|
||||
}
|
||||
|
||||
NamedCommStats Player::total_comm() const
|
||||
|
||||
@@ -141,18 +141,28 @@ struct CommStats
|
||||
}
|
||||
Timer& add_length_only(size_t length)
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cout << "add " << length << endl;
|
||||
#endif
|
||||
data += length;
|
||||
return timer;
|
||||
}
|
||||
Timer& add(const octetStream& os) { return add(os.get_length()); }
|
||||
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
|
||||
CommStats& operator+=(const CommStats& other);
|
||||
CommStats& operator-=(const CommStats& other);
|
||||
};
|
||||
|
||||
class CommStatsWithName
|
||||
{
|
||||
const string& name;
|
||||
CommStats& stats;
|
||||
|
||||
public:
|
||||
CommStatsWithName(const string& name, CommStats& stats) :
|
||||
name(name), stats(stats) {}
|
||||
|
||||
Timer& add_length_only(size_t length);
|
||||
Timer& add(const octetStream& os);
|
||||
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
|
||||
};
|
||||
|
||||
class NamedCommStats : public map<string, CommStats>
|
||||
{
|
||||
public:
|
||||
@@ -167,14 +177,8 @@ public:
|
||||
void print(bool newline = false);
|
||||
void reset();
|
||||
Timer& add_to_last_round(const string& name, size_t length);
|
||||
#ifdef VERBOSE_COMM
|
||||
CommStats& operator[](const string& name)
|
||||
{
|
||||
auto& res = map<string, CommStats>::operator[](name);
|
||||
cout << name << " after " << res.data << endl;
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
CommStatsWithName operator[](const string& name)
|
||||
{ return {name, map<string, CommStats>::operator[](name)}; }
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -209,8 +213,6 @@ public:
|
||||
virtual void send_receive_all(const vector<octetStream>&,
|
||||
vector<octetStream>&) const
|
||||
{ throw not_implemented(); }
|
||||
|
||||
void reset_stats();
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -394,6 +396,7 @@ public:
|
||||
{ receive_player(i, o); }
|
||||
|
||||
NamedCommStats total_comm() const;
|
||||
void reset_stats();
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -168,8 +168,13 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers,
|
||||
cerr << "Starting networking for " << my_num << "/" << nplayers
|
||||
<< " with server on " << hostname << ":" << (portnum) << endl;
|
||||
#endif
|
||||
assert(my_num >= 0);
|
||||
assert(my_num < nplayers);
|
||||
if (my_num < 0 or my_num >= nplayers)
|
||||
{
|
||||
cerr << "Player number " << my_num << " outside range: 0-"
|
||||
<< nplayers - 1 << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
Server* server = 0;
|
||||
pthread_t thread;
|
||||
if (my_num == 0)
|
||||
|
||||
@@ -205,7 +205,7 @@ int ServerSocket::get_connection_socket(const string& id)
|
||||
while (clients.find(id) == clients.end())
|
||||
{
|
||||
if (data_signal.wait(CONNECTION_TIMEOUT) == ETIMEDOUT)
|
||||
throw runtime_error("Timed out waiting for peer. See "
|
||||
exit_error("Timed out waiting for peer. See "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/networking.html "
|
||||
"for details on networking.");
|
||||
}
|
||||
@@ -230,7 +230,7 @@ void AnonymousServerSocket::init()
|
||||
void AnonymousServerSocket::process_client(const string& client_id)
|
||||
{
|
||||
if (clients.find(client_id) != clients.end())
|
||||
throw runtime_error("client " + client_id + " already connected");
|
||||
exit_error("client " + client_id + " already connected");
|
||||
client_connection_queue.push(client_id);
|
||||
}
|
||||
|
||||
@@ -242,7 +242,7 @@ int AnonymousServerSocket::get_connection_socket(string& client_id)
|
||||
{
|
||||
int res = data_signal.wait(CONNECTION_TIMEOUT);
|
||||
if (res == ETIMEDOUT)
|
||||
throw runtime_error("timed out while waiting for client");
|
||||
exit_error("timed out while waiting for client");
|
||||
else if (res)
|
||||
throw runtime_error("waiting error");
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ void error(const char *str)
|
||||
gethostname(err,1000);
|
||||
strcat(err," : ");
|
||||
strcat(err,str);
|
||||
throw runtime_error(string() + err + " : " + strerror(old_errno));
|
||||
exit_error(string() + err + " : " + strerror(old_errno));
|
||||
}
|
||||
|
||||
void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
@@ -62,7 +62,7 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
{
|
||||
for (rp = ai; rp != NULL; rp = rp->ai_next)
|
||||
cerr << "Family on offer: " << ai->ai_family << endl;
|
||||
runtime_error(string("No AF_INET for ") + (char*)hostname + " on " + (char*)my_name);
|
||||
exit_error(string("No AF_INET for ") + (char*)hostname + " on " + (char*)my_name);
|
||||
}
|
||||
|
||||
|
||||
@@ -106,10 +106,11 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
|
||||
if (fl < 0)
|
||||
{
|
||||
throw runtime_error(
|
||||
exit_error(
|
||||
string() + "cannot connect from " + my_name + " to " + hostname + ":"
|
||||
+ to_string(Portnum) + " after " + to_string(attempts)
|
||||
+ " attempts in one minute because " + strerror(connect_errno) + ". "
|
||||
+ " attempts in " + to_string(CONNECTION_TIMEOUT)
|
||||
+ " seconds because " + strerror(connect_errno) + ". "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#"
|
||||
"connection-failures has more information on port requirements.");
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
typedef boost::asio::io_service ssl_service;
|
||||
|
||||
void check_ssl_file(string filename);
|
||||
void ssl_error(string side, string other, string server);
|
||||
void ssl_error(string side, string other, string server, exception& e);
|
||||
|
||||
class ssl_ctx : public boost::asio::ssl::context
|
||||
{
|
||||
@@ -62,9 +62,9 @@ public:
|
||||
try
|
||||
{
|
||||
handshake(ssl_socket::client);
|
||||
} catch (...)
|
||||
} catch (exception& e)
|
||||
{
|
||||
ssl_error("Client", other, me);
|
||||
ssl_error("Client", other, me, e);
|
||||
throw;
|
||||
}
|
||||
else
|
||||
@@ -72,9 +72,9 @@ public:
|
||||
try
|
||||
{
|
||||
handshake(ssl_socket::server);
|
||||
} catch (...)
|
||||
} catch (exception& e)
|
||||
{
|
||||
ssl_error("Server", other, me);
|
||||
ssl_error("Server", other, me, e);
|
||||
throw;
|
||||
}
|
||||
|
||||
|
||||
14
OT/BaseOT.h
14
OT/BaseOT.h
@@ -47,12 +47,12 @@ public:
|
||||
vector<BitVector> receiver_outputs;
|
||||
TwoPartyPlayer* P;
|
||||
/// Number of OTs
|
||||
int nOT, ot_length;
|
||||
int nOT;
|
||||
/// Which role(s) on this side
|
||||
OT_ROLE ot_role;
|
||||
|
||||
BaseOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH)
|
||||
: P(player), nOT(nOT), ot_length(ot_length), ot_role(role)
|
||||
BaseOT(int nOT, TwoPartyPlayer* player, OT_ROLE role=BOTH)
|
||||
: P(player), nOT(nOT), ot_role(role)
|
||||
{
|
||||
receiver_inputs.resize(nOT);
|
||||
sender_inputs.resize(nOT);
|
||||
@@ -69,14 +69,12 @@ public:
|
||||
}
|
||||
|
||||
BaseOT(TwoPartyPlayer* player, OT_ROLE role) :
|
||||
BaseOT(128, 128, player, role)
|
||||
BaseOT(128, player, role)
|
||||
{
|
||||
}
|
||||
|
||||
virtual ~BaseOT() {}
|
||||
|
||||
int length() { return ot_length; }
|
||||
|
||||
/// Set choice bits
|
||||
void set_receiver_inputs(const BitVector& new_inputs)
|
||||
{
|
||||
@@ -126,8 +124,8 @@ protected:
|
||||
class FakeOT : public BaseOT
|
||||
{
|
||||
public:
|
||||
FakeOT(int nOT, int ot_length, TwoPartyPlayer* player, OT_ROLE role=BOTH) :
|
||||
BaseOT(nOT, ot_length, player, role) {}
|
||||
FakeOT(int nOT, TwoPartyPlayer* player, OT_ROLE role=BOTH) :
|
||||
BaseOT(nOT, player, role) {}
|
||||
void exec_base(bool new_receiver_inputs=true);
|
||||
};
|
||||
|
||||
|
||||
@@ -8,12 +8,13 @@
|
||||
void BitDiagonal::pack(octetStream& os) const
|
||||
{
|
||||
for (int i = 0; i < N_ROWS; i++)
|
||||
os.store_int(rows[i].get_bit(i), 1);
|
||||
os.store_bit(rows[i].get_bit(i));
|
||||
os.append(0);
|
||||
}
|
||||
|
||||
void BitDiagonal::unpack(octetStream& os)
|
||||
{
|
||||
*this = {};
|
||||
for (int i = 0; i < N_ROWS; i++)
|
||||
rows[i] = os.get_int(1) << i;
|
||||
rows[i] = RowType(os.get_bit()) << i;
|
||||
}
|
||||
|
||||
@@ -72,6 +72,11 @@ Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(const OTTripleSetup& setup,
|
||||
template<class T>
|
||||
void OTTripleGenerator<T>::set_batch_size(int batch_size)
|
||||
{
|
||||
// limit to ~1 GB
|
||||
batch_size = min(batch_size, int(1e7 / sizeof(T) / sizeof(T)));
|
||||
if (OnlineOptions::singleton.has_option("verbose_ot"))
|
||||
fprintf(stderr, "OT batch size %d (share size %d)\n", batch_size,
|
||||
int(sizeof(T)));
|
||||
nTriplesPerLoop = DIV_CEIL(batch_size, nloops);
|
||||
nTriples = nTriplesPerLoop * nloops;
|
||||
nPreampTriplesPerLoop = nTriplesPerLoop * nAmplify;
|
||||
@@ -198,9 +203,9 @@ void NPartyTripleGenerator<T>::generate()
|
||||
{
|
||||
outputFile.open(ss.str().c_str());
|
||||
if (machine.generateMACs or not T::clear::invertible)
|
||||
file_signature<T>().output(outputFile);
|
||||
file_signature<T>(this->mac_key).output(outputFile);
|
||||
else
|
||||
file_signature<typename T::clear>().output(outputFile);
|
||||
file_signature<SemiShare<typename T::clear>>().output(outputFile);
|
||||
}
|
||||
|
||||
if (machine.generateBits)
|
||||
@@ -250,6 +255,7 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
|
||||
inputs.resize(toCheck);
|
||||
auto mac_key = this->get_mac_key();
|
||||
SemiInput<SemiShare<T>> input(0, globalPlayer);
|
||||
input.maybe_init(globalPlayer);
|
||||
input.reset_all(globalPlayer);
|
||||
vector<T> secrets(toCheck);
|
||||
if (mine)
|
||||
@@ -528,8 +534,10 @@ void OTTripleGenerator<T>::generateMixedTriples()
|
||||
machine.set_passive();
|
||||
machine.output = false;
|
||||
|
||||
int n = multiple_minimum(100 * nPreampTriplesPerLoop,
|
||||
T::open_type::size_in_bits());
|
||||
int n = multiple_minimum(nPreampTriplesPerLoop, 8);
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_mixed"))
|
||||
fprintf(stderr, "generating %d mixed triples\n", n);
|
||||
|
||||
valueBits.resize(2);
|
||||
valueBits[0].resize(n);
|
||||
@@ -556,6 +564,9 @@ void OTTripleGenerator<T>::generateMixedTriples()
|
||||
template<class U>
|
||||
void OTTripleGenerator<U>::plainTripleRound(int k)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("verbose_triples"))
|
||||
fprintf(stderr, "generating %d triples\n", nPreampTriplesPerLoop);
|
||||
|
||||
typedef typename U::open_type T;
|
||||
|
||||
if (not (machine.amplify or machine.output))
|
||||
|
||||
@@ -78,6 +78,11 @@ void OTCorrelator<U>::correlate(int start, int slice,
|
||||
Slice<U> t1Slice(t1, start, slice);
|
||||
Slice<U> uSlice(u, start, slice);
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_correlate"))
|
||||
fprintf(stderr, "correlate %d matrices of size %d*%d, %u bits\n", slice,
|
||||
int(U::PartType::n_rows()), int(U::PartType::n_columns()),
|
||||
newReceiverInput.size());
|
||||
|
||||
// create correlation
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
|
||||
@@ -20,7 +20,7 @@ osuCrypto::IOService ot_extension_ios;
|
||||
OTExtensionWithMatrix OTExtensionWithMatrix::setup(TwoPartyPlayer& player,
|
||||
int128 delta, OT_ROLE role, bool passive)
|
||||
{
|
||||
BaseOT baseOT(128, 128, &player, INV_ROLE(role));
|
||||
BaseOT baseOT(128, &player, INV_ROLE(role));
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
baseOT.set_receiver_inputs(delta);
|
||||
@@ -30,6 +30,11 @@ OTExtensionWithMatrix OTExtensionWithMatrix::setup(TwoPartyPlayer& player,
|
||||
|
||||
OTExtensionWithMatrix::OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* player,
|
||||
bool passive) : OTCorrelator(baseOT, player, passive)
|
||||
{
|
||||
init_me();
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::init_me()
|
||||
{
|
||||
G.ReSeed();
|
||||
nsubloops = 1;
|
||||
@@ -37,6 +42,7 @@ OTExtensionWithMatrix::OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* pla
|
||||
#ifndef USE_KOS
|
||||
channel = 0;
|
||||
#endif
|
||||
softspoken_k = 2;
|
||||
}
|
||||
|
||||
OTExtensionWithMatrix::~OTExtensionWithMatrix()
|
||||
@@ -47,27 +53,43 @@ OTExtensionWithMatrix::~OTExtensionWithMatrix()
|
||||
#endif
|
||||
}
|
||||
|
||||
bool OTExtensionWithMatrix::use_kos()
|
||||
{
|
||||
#ifdef USE_KOS
|
||||
return true;
|
||||
#else
|
||||
return OnlineOptions::singleton.has_option("use_kos");
|
||||
#endif
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::protocol_agreement()
|
||||
{
|
||||
if (agreed)
|
||||
return;
|
||||
|
||||
Bundle<octetStream> bundle(*player);
|
||||
#ifdef USE_KOS
|
||||
bundle.mine = string("KOS15");
|
||||
#else
|
||||
bundle.mine = string("SoftSpokenOT");
|
||||
#endif
|
||||
if (use_kos())
|
||||
bundle.mine = string("KOS15");
|
||||
else
|
||||
bundle.mine = string("SoftSpokenOT");
|
||||
|
||||
if (OnlineOptions::singleton.has_option("high_softspoken"))
|
||||
softspoken_k = 8;
|
||||
|
||||
bundle.mine.store(softspoken_k);
|
||||
|
||||
player->unchecked_broadcast(bundle);
|
||||
|
||||
try
|
||||
{
|
||||
bundle.compare(*player);
|
||||
agreed = true;
|
||||
}
|
||||
catch (mismatch_among_parties&)
|
||||
{
|
||||
cerr << "Parties compiled with different OT extensions" << endl;
|
||||
cerr << "Set \"USE_KOS\" to the same value on all parties" << endl;
|
||||
cerr << "and make sure that the SoftSpokenOT parameter is the same" << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
@@ -104,16 +126,27 @@ void OTExtensionWithMatrix::transfer(int nOTs,
|
||||
#endif
|
||||
}
|
||||
|
||||
void OTExtensionWithMatrix::extend(int nOTs_requested, const BitVector& newReceiverInput)
|
||||
void OTExtensionWithMatrix::extend(int nOTs_requested,
|
||||
const BitVector& newReceiverInput, bool hash)
|
||||
{
|
||||
protocol_agreement();
|
||||
|
||||
if (use_kos())
|
||||
{
|
||||
extend_correlated(nOTs_requested, newReceiverInput);
|
||||
if (hash)
|
||||
hash_outputs(nOTs_requested);
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_KOS
|
||||
extend_correlated(nOTs_requested, newReceiverInput);
|
||||
hash_outputs(nOTs_requested);
|
||||
assert(use_kos());
|
||||
#else
|
||||
resize(nOTs_requested);
|
||||
|
||||
if (nOTs_requested == 0)
|
||||
return;
|
||||
|
||||
if (not channel)
|
||||
channel = new osuCrypto::Channel(ot_extension_ios, new PlayerCtSocket(*player));
|
||||
|
||||
@@ -141,14 +174,18 @@ void OTExtensionWithMatrix::soft_sender(size_t n)
|
||||
if (not (ot_role & SENDER))
|
||||
return;
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_ot"))
|
||||
fprintf(stderr, "%zu OTs as sender\n", n);
|
||||
|
||||
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
|
||||
osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(2);
|
||||
osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(softspoken_k);
|
||||
|
||||
vector<osuCrypto::block> outputs;
|
||||
for (auto& x : G_receiver)
|
||||
{
|
||||
outputs.push_back(x.get_doubleword());
|
||||
}
|
||||
sender.malicious = not passive_only;
|
||||
sender.setBaseOts(outputs,
|
||||
{baseReceiverInput.get_ptr(), sender.baseOtCount()}, prng,
|
||||
*channel);
|
||||
@@ -171,8 +208,11 @@ void OTExtensionWithMatrix::soft_receiver(size_t n,
|
||||
if (not (ot_role & RECEIVER))
|
||||
return;
|
||||
|
||||
if (OnlineOptions::singleton.has_option("verbose_ot"))
|
||||
fprintf(stderr, "%zu OTs as receiver\n", n);
|
||||
|
||||
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
|
||||
osuCrypto::SoftSpokenOT::TwoOneMaliciousReceiver recver(2);
|
||||
osuCrypto::SoftSpokenOT::TwoOneMaliciousReceiver recver(softspoken_k);
|
||||
|
||||
vector<array<osuCrypto::block, 2>> inputs;
|
||||
for (auto& x : G_sender)
|
||||
@@ -181,6 +221,7 @@ void OTExtensionWithMatrix::soft_receiver(size_t n,
|
||||
for (int i = 0; i < 2; i++)
|
||||
inputs.back()[i] = x[i].get_doubleword();
|
||||
}
|
||||
recver.malicious = not passive_only;
|
||||
recver.setBaseOts(inputs, prng, *channel);
|
||||
|
||||
// Choose which messages should be received.
|
||||
|
||||
@@ -63,6 +63,10 @@ class OTExtensionWithMatrix : public OTCorrelator<BitMatrix>
|
||||
|
||||
bool agreed;
|
||||
|
||||
int softspoken_k;
|
||||
|
||||
void init_me();
|
||||
|
||||
public:
|
||||
PRNG G;
|
||||
|
||||
@@ -76,21 +80,18 @@ public:
|
||||
: OTCorrelator<BitMatrix>(player, role, passive),
|
||||
nsubloops(nsubloops)
|
||||
{
|
||||
G.ReSeed();
|
||||
agreed = false;
|
||||
#ifndef USE_KOS
|
||||
channel = 0;
|
||||
#endif
|
||||
init_me();
|
||||
}
|
||||
|
||||
OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* player, bool passive);
|
||||
|
||||
~OTExtensionWithMatrix();
|
||||
|
||||
bool use_kos();
|
||||
void protocol_agreement();
|
||||
|
||||
void transfer(int nOTs, const BitVector& receiverInput, int nloops);
|
||||
void extend(int nOTs, const BitVector& newReceiverInput);
|
||||
void extend(int nOTs, const BitVector& newReceiverInput, bool hash = true);
|
||||
void extend_correlated(const BitVector& newReceiverInput);
|
||||
void extend_correlated(int nOTs, const BitVector& newReceiverInput);
|
||||
void transpose(int start = 0, int slice = -1);
|
||||
|
||||
@@ -77,6 +77,8 @@ public:
|
||||
|
||||
OTMultiplier(OTTripleGenerator<T>& generator, int thread_num);
|
||||
virtual ~OTMultiplier();
|
||||
|
||||
void init();
|
||||
void multiply();
|
||||
};
|
||||
|
||||
|
||||
@@ -88,11 +88,10 @@ OTMultiplier<T>::~OTMultiplier()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void OTMultiplier<T>::multiply()
|
||||
void OTMultiplier<T>::init()
|
||||
{
|
||||
keyBits.set(generator.get_mac_key());
|
||||
rot_ext.extend(keyBits.size(), keyBits);
|
||||
this->outbox.push({});
|
||||
senderOutput.resize(keyBits.size());
|
||||
for (size_t j = 0; j < keyBits.size(); j++)
|
||||
{
|
||||
@@ -106,10 +105,18 @@ void OTMultiplier<T>::multiply()
|
||||
assert(receiverOutput.size() >= keyBits.size());
|
||||
receiverOutput.resize(keyBits.size());
|
||||
init_authenticator(keyBits, senderOutput, receiverOutput);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void OTMultiplier<T>::multiply()
|
||||
{
|
||||
this->outbox.push({});
|
||||
MultJob job;
|
||||
while (this->inbox.pop(job))
|
||||
{
|
||||
if (receiverOutput.empty())
|
||||
init();
|
||||
|
||||
if (job.input)
|
||||
{
|
||||
if (job.player == generator.my_num
|
||||
@@ -155,13 +162,14 @@ void SemiMultiplier<T>::multiplyForBits()
|
||||
otCorrelator.set_role(role);
|
||||
|
||||
BitVector aBits = this->generator.valueBits[0];
|
||||
rot_ext.extend_correlated(aBits);
|
||||
rot_ext.extend(aBits.size(), aBits, not rot_ext.use_kos());
|
||||
|
||||
typedef typename T::Rectangle X;
|
||||
vector<Matrix<X> >& baseSenderOutputs = otCorrelator.matrices;
|
||||
Matrix<X>& baseReceiverOutput = otCorrelator.senderOutputMatrices[0];
|
||||
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput,
|
||||
rot_ext.use_kos());
|
||||
|
||||
int n_squares = otCorrelator.receiverOutputMatrix.squares.size();
|
||||
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
|
||||
@@ -201,12 +209,13 @@ void SemiMultiplier<T>::multiplyForMixed()
|
||||
this->generator.players[this->thread_num], BOTH, true);
|
||||
|
||||
BitVector aBits = this->generator.valueBits[0];
|
||||
rot_ext.extend_correlated(aBits);
|
||||
rot_ext.extend(aBits.size(), aBits, not rot_ext.use_kos());
|
||||
|
||||
auto& baseSenderOutputs = otCorrelator.matrices;
|
||||
auto& baseReceiverOutput = otCorrelator.senderOutputMatrices[0];
|
||||
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput,
|
||||
rot_ext.use_kos());
|
||||
|
||||
if (this->generator.get_player().num_players() == 2)
|
||||
{
|
||||
@@ -265,16 +274,17 @@ void OTMultiplier<W>::multiplyForTriples()
|
||||
//timers["Extension"].start();
|
||||
if (generator.machine.use_extension)
|
||||
{
|
||||
#ifdef USE_KOS
|
||||
rot_ext.extend_correlated(aBits);
|
||||
#else
|
||||
rot_ext.extend(aBits.size(), aBits);
|
||||
corr_hash = false;
|
||||
#endif
|
||||
if (rot_ext.use_kos())
|
||||
rot_ext.extend_correlated(aBits);
|
||||
else
|
||||
{
|
||||
rot_ext.extend(aBits.size(), aBits);
|
||||
corr_hash = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
BaseOT bot(aBits.size(), -1, generator.players[thread_num]);
|
||||
BaseOT bot(aBits.size(), generator.players[thread_num]);
|
||||
bot.set_receiver_inputs(aBits);
|
||||
bot.exec_base(false);
|
||||
for (size_t i = 0; i < aBits.size(); i++)
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
#include "OTTripleSetup.h"
|
||||
|
||||
void* run_ot(void* job)
|
||||
{
|
||||
((OTTripleSetup::SetupJob*)job)->run();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void OTTripleSetup::setup()
|
||||
{
|
||||
timeval baseOTstart, baseOTend;
|
||||
@@ -12,13 +18,13 @@ void OTTripleSetup::setup()
|
||||
}
|
||||
//baseReceiverInput.randomize(G);
|
||||
|
||||
vector<SetupJob> threads;
|
||||
for (int i = 0; i < nparties - 1; i++)
|
||||
{
|
||||
baseOTs[i]->set_receiver_inputs(base_receiver_inputs);
|
||||
baseOTs[i]->exec_base(false);
|
||||
baseSenderInputs[i] = baseOTs[i]->sender_inputs;
|
||||
baseReceiverOutputs[i] = baseOTs[i]->receiver_outputs;
|
||||
}
|
||||
threads.push_back({*this, i});
|
||||
for (int i = 0; i < nparties - 1; i++)
|
||||
pthread_create(&threads[i].thread, 0, run_ot, &threads[i]);
|
||||
for (int i = 0; i < nparties - 1; i++)
|
||||
pthread_join(threads[i].thread, 0);
|
||||
gettimeofday(&baseOTend, NULL);
|
||||
#ifdef VERBOSE_BASEOT
|
||||
double basetime = timeval_diff(&baseOTstart, &baseOTend);
|
||||
@@ -34,6 +40,14 @@ void OTTripleSetup::setup()
|
||||
// (since Sender finishes baseOTs before Receiver)
|
||||
}
|
||||
|
||||
void OTTripleSetup::run(int i)
|
||||
{
|
||||
baseOTs[i]->set_receiver_inputs(base_receiver_inputs);
|
||||
baseOTs[i]->exec_base(false);
|
||||
baseSenderInputs[i] = baseOTs[i]->sender_inputs;
|
||||
baseReceiverOutputs[i] = baseOTs[i]->receiver_outputs;
|
||||
}
|
||||
|
||||
void OTTripleSetup::close_connections()
|
||||
{
|
||||
for (size_t i = 0; i < players.size(); i++)
|
||||
@@ -47,7 +61,7 @@ OTTripleSetup OTTripleSetup::get_fresh()
|
||||
OTTripleSetup res = *this;
|
||||
for (int i = 0; i < nparties - 1; i++)
|
||||
{
|
||||
BaseOT bot(nbase, 128, 0);
|
||||
BaseOT bot(nbase, 0);
|
||||
bot.sender_inputs = baseSenderInputs[i];
|
||||
bot.receiver_outputs = baseReceiverOutputs[i];
|
||||
bot.set_seeds();
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
*/
|
||||
class OTTripleSetup
|
||||
{
|
||||
void run(int i);
|
||||
|
||||
BitVector base_receiver_inputs;
|
||||
vector<BaseOT*> baseOTs;
|
||||
|
||||
@@ -22,8 +24,27 @@ class OTTripleSetup
|
||||
int nbase;
|
||||
|
||||
public:
|
||||
class SetupJob
|
||||
{
|
||||
OTTripleSetup& setup;
|
||||
int i;
|
||||
|
||||
public:
|
||||
pthread_t thread;
|
||||
|
||||
SetupJob(OTTripleSetup& setup, int i) :
|
||||
setup(setup), i(i), thread(0)
|
||||
{
|
||||
}
|
||||
|
||||
void run()
|
||||
{
|
||||
setup.run(i);
|
||||
}
|
||||
};
|
||||
|
||||
map<string,Timer> timers;
|
||||
vector<OffsetPlayer*> players;
|
||||
vector<TwoPartyPlayer*> players;
|
||||
vector< vector< array<BitVector, 2> > > baseSenderInputs;
|
||||
vector< vector<BitVector> > baseReceiverOutputs;
|
||||
|
||||
@@ -56,16 +77,16 @@ public:
|
||||
else
|
||||
other_player = i;
|
||||
|
||||
players.push_back(new OffsetPlayer(N, N.get_offset(other_player)));
|
||||
players.push_back(new VirtualTwoPartyPlayer(N, other_player));
|
||||
|
||||
// sets up a pair of base OTs, playing both roles
|
||||
if (real_OTs)
|
||||
{
|
||||
baseOTs[i] = new BaseOT(nbase, 128, players[i]);
|
||||
baseOTs[i] = new BaseOT(nbase, players[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
baseOTs[i] = new FakeOT(nbase, 128, players[i]);
|
||||
baseOTs[i] = new FakeOT(nbase, players[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +70,29 @@ int BaseMachine::bucket_size(size_t usage)
|
||||
return res;
|
||||
}
|
||||
|
||||
int BaseMachine::matrix_batch_size(int n_rows, int n_inner, int n_cols)
|
||||
{
|
||||
unsigned res = min(100, OnlineOptions::singleton.batch_size);
|
||||
if (has_program())
|
||||
res = min(res, (unsigned) matrix_requirement(n_rows, n_inner, n_cols));
|
||||
return res;
|
||||
}
|
||||
|
||||
int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols)
|
||||
{
|
||||
if (has_program())
|
||||
{
|
||||
auto res = s().progs[0].get_offline_data_used().matmuls[
|
||||
{n_rows, n_inner, n_cols}];
|
||||
if (res)
|
||||
return res;
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
|
||||
BaseMachine::BaseMachine() : nthreads(0)
|
||||
{
|
||||
if (sodium_init() == -1)
|
||||
@@ -287,7 +310,7 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << P.my_num() << " only";
|
||||
if (nthreads > 1)
|
||||
if (multithread)
|
||||
cerr << "; rounds counted double due to multi-threading";
|
||||
if (not OnlineOptions::singleton.verbose)
|
||||
cerr << "; use '-v' for more details";
|
||||
|
||||
@@ -44,6 +44,7 @@ public:
|
||||
|
||||
string progname;
|
||||
int nthreads;
|
||||
bool multithread;
|
||||
|
||||
ThreadQueues queues;
|
||||
|
||||
@@ -66,10 +67,14 @@ public:
|
||||
template<class T>
|
||||
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0);
|
||||
template<class T>
|
||||
static int input_batch_size(int player, int buffer_size = 0);
|
||||
template<class T>
|
||||
static int edabit_batch_size(int n_bits, int buffer_size = 0);
|
||||
static int edabit_bucket_size(int n_bits);
|
||||
static int triple_bucket_size(DataFieldType type);
|
||||
static int bucket_size(size_t usage);
|
||||
static int matrix_batch_size(int n_rows, int n_inner, int n_cols);
|
||||
static int matrix_requirement(int n_rows, int n_inner, int n_cols);
|
||||
|
||||
BaseMachine();
|
||||
virtual ~BaseMachine() {}
|
||||
@@ -105,6 +110,10 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
template<class T>
|
||||
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
{
|
||||
if (OnlineOptions::singleton.has_option("debug_batch_size"))
|
||||
fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size,
|
||||
fallback);
|
||||
|
||||
int n_opts;
|
||||
int n = 0;
|
||||
int res = 0;
|
||||
@@ -114,7 +123,7 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
else if (fallback > 0)
|
||||
n_opts = fallback;
|
||||
else
|
||||
n_opts = OnlineOptions::singleton.batch_size;
|
||||
n_opts = OnlineOptions::singleton.batch_size * T::default_length;
|
||||
|
||||
if (buffer_size <= 0 and has_program())
|
||||
{
|
||||
@@ -132,7 +141,6 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
{
|
||||
n = buffer_size;
|
||||
buffer_size = 0;
|
||||
n_opts = OnlineOptions::singleton.batch_size;
|
||||
}
|
||||
|
||||
if (n > 0 and not (buffer_size > 0))
|
||||
@@ -161,16 +169,33 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback)
|
||||
else
|
||||
res = n_opts;
|
||||
|
||||
#ifdef DEBUG_BATCH_SIZE
|
||||
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
|
||||
<< " res=" << res << " n="
|
||||
<< n << " n_opts=" << n_opts << " buffer_size=" << buffer_size << endl;
|
||||
#endif
|
||||
if (OnlineOptions::singleton.has_option("debug_batch_size"))
|
||||
cerr << DataPositions::dtype_names[type] << " " << T::type_string()
|
||||
<< " res=" << res << " n=" << n << " n_opts=" << n_opts
|
||||
<< " buffer_size=" << buffer_size << endl;
|
||||
|
||||
assert(res > 0);
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int BaseMachine::input_batch_size(int player, int buffer_size)
|
||||
{
|
||||
if (buffer_size)
|
||||
return buffer_size;
|
||||
|
||||
if (has_program())
|
||||
{
|
||||
auto res =
|
||||
s().progs[0].get_offline_data_used(
|
||||
).inputs[player][T::clear::field_type()];
|
||||
if (res > 0)
|
||||
return res;
|
||||
}
|
||||
|
||||
return OnlineOptions::singleton.batch_size;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
|
||||
{
|
||||
|
||||
@@ -29,10 +29,12 @@ public:
|
||||
|
||||
Conv2dTuple(const vector<int>& args, int start);
|
||||
|
||||
array<int, 3> matrix_dimensions();
|
||||
|
||||
template<class T>
|
||||
void pre(vector<T>& S, typename T::Protocol& protocol);
|
||||
void pre(StackedVector<T>& S, typename T::Protocol& protocol);
|
||||
template<class T>
|
||||
void post(vector<T>& S, typename T::Protocol& protocol);
|
||||
void post(StackedVector<T>& S, typename T::Protocol& protocol);
|
||||
|
||||
template<class T>
|
||||
void run_matrix(SubProcessor<T>& processor);
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
#include "Networking/Player.h"
|
||||
#include "Protocols/edabit.h"
|
||||
#include "PrepBase.h"
|
||||
#include "PrepBuffer.h"
|
||||
#include "EdabitBuffer.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
@@ -104,8 +106,6 @@ class Preprocessing : public PrepBase
|
||||
protected:
|
||||
static const bool use_part = false;
|
||||
|
||||
DataPositions& usage;
|
||||
|
||||
bool do_count;
|
||||
|
||||
void count(Dtype dtype, int n = 1)
|
||||
@@ -115,9 +115,9 @@ protected:
|
||||
|
||||
template<int>
|
||||
void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
|
||||
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs, false_type);
|
||||
template<int>
|
||||
void get_edabits(bool, size_t, T*, vector<typename T::bit_type>&,
|
||||
void get_edabits(bool, size_t, T*, StackedVector<typename T::bit_type>&,
|
||||
const vector<int>&, true_type)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
@@ -126,6 +126,8 @@ protected:
|
||||
T get_random_from_inputs(int nplayers);
|
||||
|
||||
public:
|
||||
int buffer_size;
|
||||
|
||||
template<class U, class V>
|
||||
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
|
||||
SubProcessor<T>* proc);
|
||||
@@ -135,7 +137,8 @@ public:
|
||||
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
|
||||
DataPositions& usage);
|
||||
|
||||
Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {}
|
||||
Preprocessing(DataPositions& usage) :
|
||||
PrepBase(usage), do_count(true), buffer_size(0) {}
|
||||
virtual ~Preprocessing() {}
|
||||
|
||||
virtual void set_protocol(typename T::Protocol&) {};
|
||||
@@ -151,7 +154,7 @@ public:
|
||||
virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); }
|
||||
virtual void get_input_no_count(T&, typename T::open_type&, int)
|
||||
{ throw not_implemented() ; }
|
||||
virtual void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
|
||||
virtual void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
void get(Dtype dtype, T* a);
|
||||
@@ -159,7 +162,7 @@ public:
|
||||
void get_two(Dtype dtype, T& a, T& b);
|
||||
void get_one(Dtype dtype, T& a);
|
||||
void get_input(T& a, typename T::open_type& x, int i);
|
||||
void get(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
|
||||
/// Get fresh random multiplication triple
|
||||
virtual array<T, 3> get_triple(int n_bits);
|
||||
@@ -174,7 +177,7 @@ public:
|
||||
virtual void get_dabit(T& a, typename T::bit_type& b);
|
||||
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
|
||||
virtual void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs)
|
||||
StackedVector<typename T::bit_type>& Sb, const vector<int>& regs)
|
||||
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
|
||||
virtual void get_edabit_no_count(bool, int, edabit<T>&)
|
||||
{ throw runtime_error("no edaBits"); }
|
||||
@@ -201,11 +204,11 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
|
||||
static int tuple_length(int dtype);
|
||||
|
||||
BufferOwner<T, T> buffers[N_DTYPE];
|
||||
vector<BufferOwner<T, T>> input_buffers;
|
||||
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
|
||||
map<DataTag, BufferOwner<T, T> > extended;
|
||||
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
|
||||
array<PrepBuffer<T>, N_DTYPE> buffers;
|
||||
vector<PrepBuffer<T>> input_buffers;
|
||||
PrepBuffer<InputTuple<T>, RefInputTuple<T>, T> my_input_buffers;
|
||||
map<DataTag, PrepBuffer<T> > extended;
|
||||
PrepBuffer<dabit<T>, dabit<T>, T> dabit_buffer;
|
||||
map<int, EdabitBuffer<T>> edabit_buffers;
|
||||
map<int, edabitvec<T>> my_edabits;
|
||||
|
||||
@@ -284,7 +287,7 @@ public:
|
||||
}
|
||||
|
||||
void setup_extended(const DataTag& tag, int tuple_size = 0);
|
||||
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get_dabit_no_count(T& a, typename T::bit_type& b);
|
||||
|
||||
part_type& get_part();
|
||||
@@ -397,7 +400,7 @@ inline void Preprocessing<T>::get_input(T& a, typename T::open_type& x, int i)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline void Preprocessing<T>::get(vector<T>& S, DataTag tag,
|
||||
inline void Preprocessing<T>::get(StackedVector<T>& S, DataTag tag,
|
||||
const vector<int>& regs, int vector_size)
|
||||
{
|
||||
usage.count(T::clear::field_type(), tag, vector_size);
|
||||
|
||||
@@ -143,14 +143,14 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
{
|
||||
if (T::clear::allows(Dtype(dtype)))
|
||||
{
|
||||
buffers[dtype].setup(
|
||||
buffers[dtype].setup(num_players,
|
||||
PrepBase::get_filename(prep_data_dir, Dtype(dtype), type_short,
|
||||
my_num, thread_num), tuple_length(dtype), type_string,
|
||||
DataPositions::dtype_names[dtype]);
|
||||
}
|
||||
}
|
||||
|
||||
dabit_buffer.setup(
|
||||
dabit_buffer.setup(num_players,
|
||||
PrepBase::get_filename(prep_data_dir, DATA_DABIT,
|
||||
type_short, my_num, thread_num), dabit<T>::size(), type_string,
|
||||
DataPositions::dtype_names[DATA_DABIT]);
|
||||
@@ -161,10 +161,10 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
string filename = PrepBase::get_input_filename(prep_data_dir,
|
||||
type_short, i, my_num, thread_num);
|
||||
if (i == my_num)
|
||||
my_input_buffers.setup(filename,
|
||||
my_input_buffers.setup(num_players, filename,
|
||||
InputTuple<T>::size(), type_string);
|
||||
else
|
||||
input_buffers[i].setup(filename,
|
||||
input_buffers[i].setup(num_players, filename,
|
||||
T::size(), type_string);
|
||||
}
|
||||
|
||||
@@ -344,14 +344,14 @@ void Sub_Data_Files<T>::setup_extended(const DataTag& tag, int tuple_size)
|
||||
{
|
||||
stringstream ss;
|
||||
ss << prep_data_dir << tag.get_string() << "-" << T::type_short() << "-P" << my_num;
|
||||
buffer.setup(ss.str(), tuple_length);
|
||||
buffer.setup(num_players, ss.str(), tuple_length);
|
||||
}
|
||||
|
||||
buffer.check_tuple_length(tuple_length);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
|
||||
void Sub_Data_Files<T>::get_no_count(StackedVector<T>& S, DataTag tag, const vector<int>& regs, int vector_size)
|
||||
{
|
||||
setup_extended(tag, regs.size());
|
||||
for (int j = 0; j < vector_size; j++)
|
||||
|
||||
@@ -12,6 +12,7 @@ using namespace std;
|
||||
#include "Math/BitVec.h"
|
||||
#include "Data_Files.h"
|
||||
#include "Protocols/Replicated.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
#include "Protocols/MAC_Check_Base.h"
|
||||
#include "Processor/Input.h"
|
||||
|
||||
@@ -109,7 +110,7 @@ public:
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class DummyLivePrep : public Preprocessing<T>
|
||||
class DummyLivePrep : public BufferPrep<T>
|
||||
{
|
||||
public:
|
||||
static const bool homomorphic = true;
|
||||
@@ -133,16 +134,16 @@ public:
|
||||
}
|
||||
|
||||
DummyLivePrep(DataPositions& usage, GC::ShareThread<T>&) :
|
||||
Preprocessing<T>(usage)
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
DummyLivePrep(DataPositions& usage, bool = true) :
|
||||
Preprocessing<T>(usage)
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
DummyLivePrep(SubProcessor<T>*, DataPositions& usage) :
|
||||
Preprocessing<T>(usage)
|
||||
BufferPrep<T>(usage)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -165,7 +166,7 @@ public:
|
||||
{
|
||||
fail();
|
||||
}
|
||||
void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
|
||||
void get_no_count(StackedVector<T>&, DataTag, const vector<int>&, int)
|
||||
{
|
||||
fail();
|
||||
}
|
||||
|
||||
@@ -81,7 +81,6 @@ int ExternalClients::init_client_connection(const string& host, int portnum,
|
||||
auto socket = new client_socket(io_service, *peer_ctxs[my_client_id],
|
||||
plain_socket, "P" + to_string(party_num), "C" + to_string(my_client_id),
|
||||
true);
|
||||
if (party_num == 0)
|
||||
{
|
||||
octetStream specification;
|
||||
specification.Receive(socket);
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include <iomanip>
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
|
||||
void Instruction::execute_clear_gf2n(StackedVector<cgf2n>& registers,
|
||||
MemoryPart<cgf2n>& memory, ArithmeticProcessor& Proc) const
|
||||
{
|
||||
auto& C2 = registers;
|
||||
@@ -30,7 +30,7 @@ void Instruction::execute_clear_gf2n(vector<cgf2n>& registers,
|
||||
}
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::gbitdec(vector<cgf2n>& registers) const
|
||||
void Instruction::gbitdec(StackedVector<cgf2n>& registers) const
|
||||
{
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
@@ -44,7 +44,7 @@ void Instruction::gbitdec(vector<cgf2n>& registers) const
|
||||
}
|
||||
|
||||
template<class cgf2n>
|
||||
void Instruction::gbitcom(vector<cgf2n>& registers) const
|
||||
void Instruction::gbitcom(StackedVector<cgf2n>& registers) const
|
||||
{
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
@@ -124,7 +124,7 @@ ostream& operator<<(ostream& s, const Instruction& instr)
|
||||
return s;
|
||||
}
|
||||
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_short>& registers,
|
||||
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_short>& registers,
|
||||
MemoryPart<gf2n_short>& memory, ArithmeticProcessor& Proc) const;
|
||||
template void Instruction::execute_clear_gf2n(vector<gf2n_long>& registers,
|
||||
template void Instruction::execute_clear_gf2n(StackedVector<gf2n_long>& registers,
|
||||
MemoryPart<gf2n_long>& memory, ArithmeticProcessor& Proc) const;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user