mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Maintenance.
This commit is contained in:
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -3,7 +3,7 @@
|
||||
url = https://github.com/mkskeller/SimpleOT
|
||||
[submodule "mpir"]
|
||||
path = mpir
|
||||
url = git://github.com/wbhart/mpir.git
|
||||
url = https://github.com/wbhart/mpir
|
||||
[submodule "Programs/Circuits"]
|
||||
path = Programs/Circuits
|
||||
url = https://github.com/mkskeller/bristol-fashion
|
||||
|
||||
@@ -259,7 +259,7 @@ ProgramParty::~ProgramParty()
|
||||
reset();
|
||||
if (P)
|
||||
{
|
||||
cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl;
|
||||
cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl;
|
||||
delete P;
|
||||
}
|
||||
delete[] eval_threads;
|
||||
|
||||
@@ -175,7 +175,7 @@ void GarbleInputter<T>::exchange()
|
||||
assert(party.P != 0);
|
||||
assert(party.MC != 0);
|
||||
auto& protocol = party.shared_proc->protocol;
|
||||
protocol.init_mul(party.shared_proc);
|
||||
protocol.init_mul();
|
||||
for (auto& tuple : tuples)
|
||||
protocol.prepare_mul(tuple.first->mask,
|
||||
T::constant(1, party.P->my_num(), party.mac_key)
|
||||
|
||||
@@ -155,7 +155,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
while (next != GC::DONE_BREAK);
|
||||
|
||||
MC->Check(*P);
|
||||
data_sent = P->comm_stats.total_data() + prep->data_sent();
|
||||
data_sent = P->total_comm().sent;
|
||||
|
||||
this->machine.write_memory(this->N.my_num());
|
||||
}
|
||||
@@ -173,7 +173,8 @@ void RealProgramParty<T>::garble()
|
||||
garble_jobs.clear();
|
||||
garble_inputter->reset_all(*P);
|
||||
auto& protocol = *garble_protocol;
|
||||
protocol.init_mul(shared_proc);
|
||||
protocol.init(*prep, shared_proc->MC);
|
||||
protocol.init_mul();
|
||||
|
||||
next = this->first_phase(program, garble_processor, this->garble_machine);
|
||||
|
||||
@@ -181,7 +182,8 @@ void RealProgramParty<T>::garble()
|
||||
protocol.exchange();
|
||||
|
||||
typename T::Protocol second_protocol(*P);
|
||||
second_protocol.init_mul(shared_proc);
|
||||
second_protocol.init(*prep, shared_proc->MC);
|
||||
second_protocol.init_mul();
|
||||
for (auto& job : garble_jobs)
|
||||
job.middle_round(*this, second_protocol);
|
||||
|
||||
|
||||
@@ -293,6 +293,9 @@ public:
|
||||
template<class U>
|
||||
static void convcbit2s(GC::Processor<U>&, const BaseInstruction&)
|
||||
{ throw runtime_error("convcbit2s not implemented"); }
|
||||
template<class U>
|
||||
static void andm(GC::Processor<U>&, const BaseInstruction&)
|
||||
{ throw runtime_error("andm not implemented"); }
|
||||
|
||||
// most BMR phases don't need actual input
|
||||
template<class T>
|
||||
|
||||
@@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty()
|
||||
_received_gc_received = 0;
|
||||
n_received = 0;
|
||||
randomfd = open("/dev/urandom", O_RDONLY);
|
||||
done_filling = false;
|
||||
}
|
||||
|
||||
BaseTrustedParty::~BaseTrustedParty()
|
||||
{
|
||||
close(randomfd);
|
||||
}
|
||||
|
||||
TrustedProgramParty::TrustedProgramParty(int argc, char** argv) :
|
||||
|
||||
@@ -20,7 +20,7 @@ public:
|
||||
vector<SendBuffer> msg_input_masks;
|
||||
|
||||
BaseTrustedParty();
|
||||
virtual ~BaseTrustedParty() {}
|
||||
virtual ~BaseTrustedParty();
|
||||
|
||||
/* From NodeUpdatable class */
|
||||
virtual void NodeReady();
|
||||
@@ -104,7 +104,6 @@ private:
|
||||
void add_all_keys(const Register& reg, bool external);
|
||||
};
|
||||
|
||||
|
||||
inline void BaseTrustedParty::add_keys(const Register& reg)
|
||||
{
|
||||
for(int p = 0; p < get_n_parties(); p++)
|
||||
|
||||
13
CHANGELOG.md
13
CHANGELOG.md
@@ -1,5 +1,18 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.2.9 (Jan 11, 2021)
|
||||
|
||||
- Disassembler
|
||||
- Run-time parameter for probabilistic truncation error
|
||||
- Probabilistic truncation for some protocols computing modulo a prime
|
||||
- Simplified C++ interface
|
||||
- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757)
|
||||
- More general scalar-vector multiplication
|
||||
- Complete memory support for clear bits
|
||||
- Extended clear bit functionality with Yao's garbled circuits
|
||||
- Allow preprocessing information to be supplied via named pipes
|
||||
- In-place operations for containers
|
||||
|
||||
## 0.2.8 (Nov 4, 2021)
|
||||
|
||||
- Tested on Apple laptop with ARM chip
|
||||
|
||||
@@ -112,10 +112,16 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return cls.load_dynamic_mem(address)
|
||||
else:
|
||||
for i in range(res.size):
|
||||
cls.load_inst[util.is_constant(address)](res[i], address + i)
|
||||
cls.mem_op(cls.load_inst, res[i], address + i)
|
||||
return res
|
||||
def store_in_mem(self, address):
|
||||
self.store_inst[isinstance(address, int)](self, address)
|
||||
self.mem_op(self.store_inst, self, address)
|
||||
@staticmethod
|
||||
def mem_op(inst, reg, address):
|
||||
direct = isinstance(address, int)
|
||||
if not direct:
|
||||
address = regint.conv(address)
|
||||
inst[direct](reg, address)
|
||||
@classmethod
|
||||
def new(cls, value=None, n=None):
|
||||
if util.is_constant(value):
|
||||
|
||||
@@ -77,13 +77,16 @@ def LTZ(s, a, k, kappa):
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
movs(s, program.non_linear.ltz(a, k, kappa))
|
||||
|
||||
def LtzRing(a, k):
|
||||
from .types import sint, _bitint
|
||||
from .GC.types import sbitvec
|
||||
if program.use_split():
|
||||
summands = a.split_to_two_summands(k)
|
||||
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
|
||||
msb = carry ^ summands[0][-1] ^ summands[1][-1]
|
||||
movs(s, sint.conv(msb))
|
||||
return sint.conv(msb)
|
||||
return
|
||||
elif program.options.ring:
|
||||
from . import floatingpoint
|
||||
@@ -96,11 +99,7 @@ def LTZ(s, a, k, kappa):
|
||||
a = r_bin[0].bit_decompose_clear(c_prime, m)
|
||||
b = r_bin[:m]
|
||||
u = CarryOutRaw(a[::-1], b[::-1])
|
||||
movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)))
|
||||
return
|
||||
t = sint()
|
||||
Trunc(t, a, k, k - 1, kappa, True)
|
||||
subsfi(s, t, 0)
|
||||
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
|
||||
|
||||
def LessThanZero(a, k, kappa):
|
||||
from . import types
|
||||
|
||||
@@ -82,7 +82,7 @@ def run(args, options):
|
||||
prog.finalize()
|
||||
|
||||
if prog.req_num:
|
||||
print('Program requires:')
|
||||
print('Program requires at most:')
|
||||
for x in prog.req_num.pretty():
|
||||
print(x)
|
||||
|
||||
|
||||
@@ -12,4 +12,7 @@ class ArgumentError(CompilerError):
|
||||
""" Exception raised for errors in instruction argument parsing. """
|
||||
def __init__(self, arg, msg):
|
||||
self.arg = arg
|
||||
self.msg = msg
|
||||
self.msg = msg
|
||||
|
||||
class VectorMismatch(CompilerError):
|
||||
pass
|
||||
|
||||
@@ -392,7 +392,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
|
||||
for i in range(1,l):
|
||||
ci[i] = c % two_power(i)
|
||||
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
|
||||
lts(d, c_dprime, r_prime, l, kappa)
|
||||
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
|
||||
if compute_modulo:
|
||||
b = c_dprime - r_prime + pow2m * d
|
||||
return b, pow2m
|
||||
@@ -629,12 +629,14 @@ def BITLT(a, b, bit_length):
|
||||
# - From the paper
|
||||
# Multiparty Computation for Interval, Equality, and Comparison without
|
||||
# Bit-Decomposition Protocol
|
||||
def BitDecFull(a, maybe_mixed=False):
|
||||
def BitDecFull(a, n_bits=None, maybe_mixed=False):
|
||||
from .library import get_program, do_while, if_, break_point
|
||||
from .types import sint, regint, longint, cint
|
||||
p = get_program().prime
|
||||
assert p
|
||||
bit_length = p.bit_length()
|
||||
n_bits = n_bits or bit_length
|
||||
assert n_bits <= bit_length
|
||||
logp = int(round(math.log(p, 2)))
|
||||
if abs(p - 2 ** logp) / p < 2 ** -get_program().security:
|
||||
# inspired by Rabbit (https://eprint.iacr.org/2021/119)
|
||||
@@ -677,12 +679,12 @@ def BitDecFull(a, maybe_mixed=False):
|
||||
czero = (c==0)
|
||||
q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t)
|
||||
fbar = [bbits[0].clear_type.conv(cint(x))
|
||||
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
|
||||
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
|
||||
g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
|
||||
for x in ((1<<bit_length)+c-p).bit_decompose(n_bits)]
|
||||
fbard = bbits[0].bit_decompose_clear(cmodp, n_bits)
|
||||
g = [q.if_else(fbar[i], fbard[i]) for i in range(n_bits)]
|
||||
h = bbits[0].bit_adder(bbits, g)
|
||||
abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
|
||||
for i in range(bit_length)]
|
||||
for i in range(n_bits)]
|
||||
if maybe_mixed:
|
||||
return abits
|
||||
else:
|
||||
|
||||
@@ -442,7 +442,7 @@ class join_tape(base.Instruction):
|
||||
arg_format = ['int']
|
||||
|
||||
class crash(base.IOInstruction):
|
||||
""" Crash runtime if the register's value is > 0.
|
||||
""" Crash runtime if the value in the register is not zero.
|
||||
|
||||
:param: Crash condition (regint)"""
|
||||
code = base.opcodes['CRASH']
|
||||
@@ -1275,7 +1275,7 @@ class prep(base.Instruction):
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, self.args[0]), 1)
|
||||
req_node.increment((self.field_type, self.args[0]), self.get_size())
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
@@ -2407,19 +2407,6 @@ class sqrs(base.CISC):
|
||||
subml(self.args[0], s[5], c[1])
|
||||
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class lts(base.CISC):
|
||||
""" Secret comparison $s_i = (s_j < s_k)$. """
|
||||
__slots__ = []
|
||||
arg_format = ['sw', 's', 's', 'int', 'int']
|
||||
|
||||
def expand(self):
|
||||
from .types import sint
|
||||
a = sint()
|
||||
subs(a, self.args[1], self.args[2])
|
||||
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
|
||||
|
||||
# placeholder for documentation
|
||||
class cisc:
|
||||
""" Meta instruction for emulation. This instruction is only generated
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import inspect
|
||||
import functools
|
||||
import copy
|
||||
import sys
|
||||
import struct
|
||||
from Compiler.exceptions import *
|
||||
from Compiler.config import *
|
||||
from Compiler import util
|
||||
@@ -299,11 +301,12 @@ def vectorize(instruction, global_dict=None):
|
||||
vectorized_name = 'v' + instruction.__name__
|
||||
Vectorized_Instruction.__name__ = vectorized_name
|
||||
global_dict[vectorized_name] = Vectorized_Instruction
|
||||
|
||||
if 'sphinx.extension' in sys.modules:
|
||||
return instruction
|
||||
|
||||
global_dict[instruction.__name__ + '_class'] = instruction
|
||||
instruction.__doc__ = ''
|
||||
# exclude GF(2^n) instructions from documentation
|
||||
if instruction.code and instruction.code >> 8 == 1:
|
||||
maybe_vectorized_instruction.__doc__ = ''
|
||||
maybe_vectorized_instruction.arg_format = instruction.arg_format
|
||||
return maybe_vectorized_instruction
|
||||
|
||||
|
||||
@@ -389,8 +392,11 @@ def gf2n(instruction):
|
||||
else:
|
||||
global_dict[GF2N_Instruction.__name__] = GF2N_Instruction
|
||||
|
||||
if 'sphinx.extension' in sys.modules:
|
||||
return instruction
|
||||
|
||||
global_dict[instruction.__name__ + '_class'] = instruction_cls
|
||||
instruction_cls.__doc__ = ''
|
||||
maybe_gf2n_instruction.arg_format = instruction.arg_format
|
||||
return maybe_gf2n_instruction
|
||||
#return instruction
|
||||
|
||||
@@ -661,6 +667,12 @@ class RegisterArgFormat(ArgFormat):
|
||||
assert arg.i >= 0
|
||||
return int_to_bytes(arg.i)
|
||||
|
||||
def __init__(self, f):
|
||||
self.i = struct.unpack('>I', f.read(4))[0]
|
||||
|
||||
def __str__(self):
|
||||
return self.reg_type + str(self.i)
|
||||
|
||||
class ClearModpAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearModp
|
||||
|
||||
@@ -686,6 +698,12 @@ class IntArgFormat(ArgFormat):
|
||||
def encode(cls, arg):
|
||||
return int_to_bytes(arg)
|
||||
|
||||
def __init__(self, f):
|
||||
self.i = struct.unpack('>i', f.read(4))[0]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.i)
|
||||
|
||||
class ImmediateModpAF(IntArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
@@ -722,6 +740,13 @@ class String(ArgFormat):
|
||||
def encode(cls, arg):
|
||||
return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))
|
||||
|
||||
def __init__(self, f):
|
||||
tmp = f.read(16)
|
||||
self.str = str(tmp[0:tmp.find(b'\0')], 'ascii')
|
||||
|
||||
def __str__(self):
|
||||
return self.str
|
||||
|
||||
ArgFormats = {
|
||||
'c': ClearModpAF,
|
||||
's': SecretModpAF,
|
||||
@@ -890,6 +915,54 @@ class Instruction(object):
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
|
||||
|
||||
class ParsedInstruction:
|
||||
reverse_opcodes = {}
|
||||
|
||||
def __init__(self, f):
|
||||
cls = type(self)
|
||||
from Compiler import instructions
|
||||
from Compiler.GC import instructions as gc_inst
|
||||
if not cls.reverse_opcodes:
|
||||
for module in instructions, gc_inst:
|
||||
for x, y in inspect.getmodule(module).__dict__.items():
|
||||
if inspect.isclass(y) and y.__name__[0] != 'v':
|
||||
try:
|
||||
cls.reverse_opcodes[y.code] = y
|
||||
except AttributeError:
|
||||
pass
|
||||
read = lambda: struct.unpack('>I', f.read(4))[0]
|
||||
full_code = read()
|
||||
code = full_code % (1 << Instruction.code_length)
|
||||
self.size = full_code >> Instruction.code_length
|
||||
self.type = cls.reverse_opcodes[code]
|
||||
t = self.type
|
||||
name = t.__name__
|
||||
try:
|
||||
n_args = len(t.arg_format)
|
||||
self.var_args = False
|
||||
except:
|
||||
n_args = read()
|
||||
self.var_args = True
|
||||
try:
|
||||
arg_format = iter(t.arg_format)
|
||||
except:
|
||||
if name == 'cisc':
|
||||
arg_format = itertools.chain(['str'], itertools.repeat('int'))
|
||||
else:
|
||||
arg_format = itertools.repeat('int')
|
||||
self.args = [ArgFormats[next(arg_format)](f)
|
||||
for i in range(n_args)]
|
||||
|
||||
def __str__(self):
|
||||
name = self.type.__name__
|
||||
res = name + ' '
|
||||
if self.size > 1:
|
||||
res = 'v' + res + str(self.size) + ', '
|
||||
if self.var_args:
|
||||
res += str(len(self.args)) + ', '
|
||||
res += ', '.join(str(arg) for arg in self.args)
|
||||
return res
|
||||
|
||||
class VarArgsInstruction(Instruction):
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
@@ -219,6 +219,9 @@ def crash(condition=None):
|
||||
:param condition: crash if true (default: true)
|
||||
|
||||
"""
|
||||
if isinstance(condition, localint):
|
||||
# allow crash on local values
|
||||
condition = condition._v
|
||||
if condition == None:
|
||||
condition = regint(1)
|
||||
instructions.crash(regint.conv(condition))
|
||||
@@ -1347,6 +1350,8 @@ def while_loop(loop_body, condition, arg, g=None):
|
||||
arg = regint(arg)
|
||||
def loop_fn():
|
||||
result = loop_body(arg)
|
||||
if isinstance(result, MemValue):
|
||||
result = result.read()
|
||||
result.link(arg)
|
||||
cont = condition(result)
|
||||
return cont
|
||||
@@ -1531,6 +1536,8 @@ def if_(condition):
|
||||
def if_e(condition):
|
||||
"""
|
||||
Conditional execution with else block.
|
||||
Use :py:class:`~Compiler.types.MemValue` to assign values that
|
||||
live beyond.
|
||||
|
||||
:param condition: regint/cint/int
|
||||
|
||||
@@ -1538,12 +1545,13 @@ def if_e(condition):
|
||||
|
||||
.. code::
|
||||
|
||||
y = MemValue(0)
|
||||
@if_e(x > 0)
|
||||
def _():
|
||||
...
|
||||
y.write(1)
|
||||
@else_
|
||||
def _():
|
||||
...
|
||||
y.write(0)
|
||||
"""
|
||||
try:
|
||||
condition = bool(condition)
|
||||
@@ -1647,11 +1655,18 @@ def get_player_id():
|
||||
return res
|
||||
|
||||
def listen_for_clients(port):
|
||||
""" Listen for clients on specific port. """
|
||||
""" Listen for clients on specific port base.
|
||||
|
||||
:param port: port base (int/regint/cint)
|
||||
"""
|
||||
instructions.listen(regint.conv(port))
|
||||
|
||||
def accept_client_connection(port):
|
||||
""" Listen for clients on specific port. """
|
||||
""" Accept client connection on specific port base.
|
||||
|
||||
:param port: port base (int/regint/cint)
|
||||
:returns: client id
|
||||
"""
|
||||
res = regint()
|
||||
instructions.acceptclientconnection(res, regint.conv(port))
|
||||
return res
|
||||
|
||||
@@ -1810,6 +1810,7 @@ class Optimizer:
|
||||
self.print_loss_reduction = False
|
||||
self.i_epoch = MemValue(0)
|
||||
self.stopped_on_loss = MemValue(0)
|
||||
self.stopped_on_low_loss = MemValue(0)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
@@ -1932,6 +1933,7 @@ class Optimizer:
|
||||
""" Run training.
|
||||
|
||||
:param batch_size: batch size (defaults to example size of first layer)
|
||||
:param stop_on_loss: stop when loss falls below this (default: 0)
|
||||
"""
|
||||
if self.n_epochs == 0:
|
||||
return
|
||||
@@ -2013,6 +2015,7 @@ class Optimizer:
|
||||
if self.tol > 0:
|
||||
res *= (1 - (loss_sum >= 0) * \
|
||||
(loss_sum < self.tol * n_per_epoch)).reveal()
|
||||
self.stopped_on_low_loss.write(1 - res)
|
||||
return res
|
||||
|
||||
def reveal_correctness(self, data, truth, batch_size):
|
||||
@@ -2138,6 +2141,7 @@ class Optimizer:
|
||||
if depreciation:
|
||||
self.gamma.imul(depreciation)
|
||||
print_ln('reducing learning rate to %s', self.gamma)
|
||||
return 1 - self.stopped_on_low_loss
|
||||
if 'model_output' in program.args:
|
||||
self.output_weights()
|
||||
|
||||
@@ -2386,6 +2390,7 @@ class keras:
|
||||
return list(self.opt.thetas)
|
||||
|
||||
def build(self, input_shape, batch_size=128):
|
||||
data_input_shape = input_shape
|
||||
if self.opt != None and \
|
||||
input_shape == self.opt.layers[0].X.sizes and \
|
||||
batch_size <= self.batch_size and \
|
||||
@@ -2458,9 +2463,10 @@ class keras:
|
||||
else:
|
||||
raise Exception(layer[0] + ' not supported')
|
||||
if layers[-1].d_out == 1:
|
||||
layers.append(Output(input_shape[0]))
|
||||
layers.append(Output(data_input_shape[0]))
|
||||
else:
|
||||
layers.append(MultiOutput(input_shape[0], layers[-1].d_out))
|
||||
layers.append(
|
||||
MultiOutput(data_input_shape[0], layers[-1].d_out))
|
||||
if self.optimizer[1]:
|
||||
raise Exception('use keyword arguments for optimizer')
|
||||
opt = self.optimizer[0]
|
||||
@@ -2504,7 +2510,7 @@ class keras:
|
||||
if x.total_size() != self.opt.layers[0].X.total_size():
|
||||
raise Exception('sample data size mismatch')
|
||||
if y.total_size() != self.opt.layers[-1].Y.total_size():
|
||||
print (y, layers[-1].Y)
|
||||
print (y, self.opt.layers[-1].Y)
|
||||
raise Exception('label size mismatch')
|
||||
if validation_data == None:
|
||||
validation_data = None, None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from .comparison import *
|
||||
from .floatingpoint import *
|
||||
from .types import *
|
||||
from . import comparison
|
||||
from . import comparison, program
|
||||
|
||||
class NonLinear:
|
||||
kappa = None
|
||||
@@ -30,6 +30,15 @@ class NonLinear:
|
||||
def trunc_pr(self, a, k, m, signed=True):
|
||||
if isinstance(a, types.cint):
|
||||
return shift_two(a, m)
|
||||
prog = program.Program.prog
|
||||
if prog.use_trunc_pr:
|
||||
if signed and prog.use_trunc_pr != -1:
|
||||
a += (1 << (k - 1))
|
||||
res = sint()
|
||||
trunc_pr(res, a, k, m)
|
||||
if signed and prog.use_trunc_pr != -1:
|
||||
res -= (1 << (k - m - 1))
|
||||
return res
|
||||
return self._trunc_pr(a, k, m, signed)
|
||||
|
||||
def trunc_round_nearest(self, a, k, m, signed):
|
||||
@@ -44,6 +53,9 @@ class NonLinear:
|
||||
return a
|
||||
return self._trunc(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
return -self.trunc(a, k, k - 1, kappa, True)
|
||||
|
||||
class Masking(NonLinear):
|
||||
def eqz(self, a, k):
|
||||
c, r = self._mask(a, k)
|
||||
@@ -100,42 +112,44 @@ class KnownPrime(NonLinear):
|
||||
def _mod2m(self, a, k, m, signed):
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])
|
||||
return sint.bit_compose(self.bit_dec(a, k, m, True))
|
||||
|
||||
def _trunc_pr(self, a, k, m, signed):
|
||||
# nearest truncation
|
||||
return self.trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
def _trunc(self, a, k, m, signed=None):
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
|
||||
if signed:
|
||||
res -= cint(1) << (k - 1 - m)
|
||||
return res
|
||||
return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed)
|
||||
|
||||
def trunc_round_nearest(self, a, k, m, signed):
|
||||
a += cint(1) << (m - 1)
|
||||
if signed:
|
||||
a += cint(1) << (k - 1)
|
||||
k += 1
|
||||
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
|
||||
res = self._trunc(a, k, m, False)
|
||||
if signed:
|
||||
res -= cint(1) << (k - m - 2)
|
||||
return res
|
||||
|
||||
def bit_dec(self, a, k, m, maybe_mixed=False):
|
||||
assert k < self.prime.bit_length()
|
||||
bits = BitDecFull(a, maybe_mixed=maybe_mixed)
|
||||
if len(bits) < m:
|
||||
raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
|
||||
return bits[:m]
|
||||
bits = BitDecFull(a, m, maybe_mixed=maybe_mixed)
|
||||
assert len(bits) == m
|
||||
return bits
|
||||
|
||||
def eqz(self, a, k):
|
||||
# always signed
|
||||
a += two_power(k)
|
||||
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
if k + 1 < self.prime.bit_length():
|
||||
# https://dl.acm.org/doi/10.1145/3474123.3486757
|
||||
# "negative" values wrap around when doubling, thus becoming odd
|
||||
return self.mod2m(2 * a, k + 1, 1, False)
|
||||
else:
|
||||
return super(KnownPrime, self).ltz(a, k, kappa)
|
||||
|
||||
class Ring(Masking):
|
||||
""" Non-linear functionality modulo a power of two known at compile time.
|
||||
"""
|
||||
@@ -172,3 +186,6 @@ class Ring(Masking):
|
||||
return TruncRing(None, tmp + 1, k - m + 1, 1, signed)
|
||||
else:
|
||||
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
|
||||
|
||||
def ltz(self, a, k, kappa=None):
|
||||
return LtzRing(a, k)
|
||||
|
||||
@@ -578,6 +578,15 @@ class Program(object):
|
||||
self.warn_about_mem.append(False)
|
||||
self.curr_block.warn_about_mem = False
|
||||
|
||||
@staticmethod
|
||||
def read_tapes(schedule):
|
||||
if not os.path.exists(schedule):
|
||||
schedule = 'Programs/Schedules/%s.sch' % schedule
|
||||
|
||||
lines = open(schedule).readlines()
|
||||
for tapename in lines[2].split(' '):
|
||||
yield tapename.strip()
|
||||
|
||||
class Tape:
|
||||
""" A tape contains a list of basic blocks, onto which instructions are added. """
|
||||
def __init__(self, name, program):
|
||||
@@ -1109,7 +1118,20 @@ class Tape:
|
||||
else:
|
||||
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
|
||||
|
||||
class Register(object):
|
||||
@staticmethod
|
||||
def read_instructions(tapename):
|
||||
tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb')
|
||||
while tape.peek():
|
||||
yield inst_base.ParsedInstruction(tape)
|
||||
|
||||
class _no_truth(object):
|
||||
__slots__ = []
|
||||
|
||||
def __bool__(self):
|
||||
raise CompilerError('Cannot derive truth value from register, '
|
||||
"consider using 'compile.py -l'")
|
||||
|
||||
class Register(_no_truth):
|
||||
"""
|
||||
Class for creating new registers. The register's index is automatically assigned
|
||||
based on the block's reg_counter dictionary.
|
||||
@@ -1233,10 +1255,6 @@ class Tape:
|
||||
self.reg_type == RegType.ClearGF2N or \
|
||||
self.reg_type == RegType.ClearInt
|
||||
|
||||
def __bool__(self):
|
||||
raise CompilerError('Cannot derive truth value from register, '
|
||||
"consider using 'compile.py -l'")
|
||||
|
||||
def __str__(self):
|
||||
return self.reg_type + str(self.i)
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ def vectorize(operation):
|
||||
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
|
||||
and not isinstance(args[0], bits) \
|
||||
and args[0].size != self.size:
|
||||
raise CompilerError('Different vector sizes of operands: %d/%d'
|
||||
raise VectorMismatch('Different vector sizes of operands: %d/%d'
|
||||
% (self.size, args[0].size))
|
||||
set_global_vector_size(self.size)
|
||||
try:
|
||||
@@ -221,7 +221,7 @@ def inputmixed(*args):
|
||||
else:
|
||||
instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),)))
|
||||
|
||||
class _number(object):
|
||||
class _number(Tape._no_truth):
|
||||
""" Number functionality. """
|
||||
|
||||
def square(self):
|
||||
@@ -246,7 +246,11 @@ class _number(object):
|
||||
elif is_one(other):
|
||||
return self
|
||||
else:
|
||||
return self.mul(other)
|
||||
try:
|
||||
return self.mul(other)
|
||||
except VectorMismatch:
|
||||
# try reverse multiplication
|
||||
return NotImplemented
|
||||
|
||||
__radd__ = __add__
|
||||
__rmul__ = __mul__
|
||||
@@ -320,7 +324,7 @@ class _number(object):
|
||||
def popcnt_bits(bits):
|
||||
return sum(bits)
|
||||
|
||||
class _int(object):
|
||||
class _int(Tape._no_truth):
|
||||
""" Integer functionality. """
|
||||
|
||||
@staticmethod
|
||||
@@ -408,7 +412,7 @@ class _int(object):
|
||||
def long_one():
|
||||
return 1
|
||||
|
||||
class _bit(object):
|
||||
class _bit(Tape._no_truth):
|
||||
""" Binary functionality. """
|
||||
|
||||
def bit_xor(self, other):
|
||||
@@ -474,7 +478,7 @@ class _gf2n(_bit):
|
||||
def bit_not(self):
|
||||
return self ^ 1
|
||||
|
||||
class _structure(object):
|
||||
class _structure(Tape._no_truth):
|
||||
""" Interface for type-dependent container types. """
|
||||
|
||||
MemValue = classmethod(lambda cls, value: MemValue(cls.conv(value)))
|
||||
@@ -591,7 +595,7 @@ class _secret_structure(_structure):
|
||||
res.input_from(player)
|
||||
return res
|
||||
|
||||
class _vec(object):
|
||||
class _vec(Tape._no_truth):
|
||||
def link(self, other):
|
||||
assert len(self.v) == len(other.v)
|
||||
for x, y in zip(self.v, other.v):
|
||||
@@ -726,7 +730,7 @@ class _register(Tape.Register, _number, _structure):
|
||||
assert self.size == 1
|
||||
res = type(self)(size=size)
|
||||
for i in range(size):
|
||||
movs(res[i], self)
|
||||
self.mov(res[i], self)
|
||||
return res
|
||||
|
||||
class _clear(_register):
|
||||
@@ -1010,9 +1014,10 @@ class cint(_clear, _int):
|
||||
if bit_length <= 64:
|
||||
return regint(self) < regint(other)
|
||||
else:
|
||||
sint.require_bit_length(bit_length + 1)
|
||||
diff = self - other
|
||||
diff += (1 << (bit_length - 1))
|
||||
shifted = diff >> (bit_length - 1)
|
||||
diff += 1 << bit_length
|
||||
shifted = diff >> bit_length
|
||||
res = 1 - regint(shifted & 1)
|
||||
return res
|
||||
|
||||
@@ -1646,7 +1651,7 @@ class regint(_register, _int):
|
||||
player = -1
|
||||
intoutput(player, self)
|
||||
|
||||
class localint(object):
|
||||
class localint(Tape._no_truth):
|
||||
""" Local integer that must prevented from leaking into the secure
|
||||
computation. Uses regint internally.
|
||||
|
||||
@@ -1669,7 +1674,7 @@ class localint(object):
|
||||
__eq__ = lambda self, other: localint(self._v == other)
|
||||
__ne__ = lambda self, other: localint(self._v != other)
|
||||
|
||||
class personal(object):
|
||||
class personal(Tape._no_truth):
|
||||
def __init__(self, player, value):
|
||||
assert value is not NotImplemented
|
||||
assert not isinstance(value, _secret)
|
||||
@@ -2003,9 +2008,11 @@ class _secret(_register, _secret_structure):
|
||||
size or one size 1 for a value-vector multiplication.
|
||||
|
||||
:param other: any compatible type """
|
||||
if isinstance(other, _secret) and (1 in (self.size, other.size)) \
|
||||
if isinstance(other, _register) and (1 in (self.size, other.size)) \
|
||||
and (self.size, other.size) != (1, 1):
|
||||
x, y = (other, self) if self.size < other.size else (self, other)
|
||||
if not isinstance(other, _secret):
|
||||
return y.expand_to_vector(x.size) * x
|
||||
res = type(self)(size=x.size)
|
||||
mulrs(res, x, y)
|
||||
return res
|
||||
@@ -2221,11 +2228,13 @@ class sint(_secret, _int):
|
||||
@vectorized_classmethod
|
||||
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
|
||||
""" Securely obtain shares of values input by a client.
|
||||
This uses the triple-based input protocol introduced by
|
||||
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_
|
||||
|
||||
:param n: number of inputs (int)
|
||||
:param client_id: regint
|
||||
:param size: vector size (default 1)
|
||||
|
||||
:returns: list of sint
|
||||
"""
|
||||
# send shares of a triple to client
|
||||
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
|
||||
@@ -2910,7 +2919,7 @@ for t in (sint, sgf2n):
|
||||
sint.bit_type = sintbit
|
||||
sgf2n.bit_type = sgf2n
|
||||
|
||||
class _bitint(object):
|
||||
class _bitint(Tape._no_truth):
|
||||
bits = None
|
||||
log_rounds = False
|
||||
linear_rounds = False
|
||||
@@ -3521,6 +3530,7 @@ class cfix(_number, _structure):
|
||||
|
||||
@classmethod
|
||||
def _new(cls, other, k=None, f=None):
|
||||
assert not isinstance(other, (list, tuple))
|
||||
res = cls(k=k, f=f)
|
||||
res.v = cint.conv(other)
|
||||
return res
|
||||
@@ -3567,6 +3577,8 @@ class cfix(_number, _structure):
|
||||
return len(self.v)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
return [self._new(x, k=self.k, f=self.f) for x in self.v[index]]
|
||||
return self._new(self.v[index], k=self.k, f=self.f)
|
||||
|
||||
@vectorize
|
||||
@@ -3608,7 +3620,6 @@ class cfix(_number, _structure):
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
@vectorize
|
||||
def mul(self, other):
|
||||
""" Clear fixed-point multiplication.
|
||||
|
||||
@@ -4045,7 +4056,8 @@ class _fix(_single):
|
||||
'for fixed-point computation')
|
||||
cls.round_nearest = True
|
||||
if adapt_ring and program.options.ring \
|
||||
and 'fix_ring' not in program.args:
|
||||
and 'fix_ring' not in program.args \
|
||||
and 2 * cls.k > int(program.options.ring):
|
||||
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
|
||||
if need != int(program.options.ring):
|
||||
print('Changing computation modulus to 2^%d' % need)
|
||||
@@ -4489,7 +4501,7 @@ class squant(_single):
|
||||
def __neg__(self):
|
||||
return self._new(-self.v + 2 * util.expand(self.Z, self.v.size))
|
||||
|
||||
class _unreduced_squant(object):
|
||||
class _unreduced_squant(Tape._no_truth):
|
||||
def __init__(self, v, params, res_params=None, n_summands=1):
|
||||
self.v = v
|
||||
self.params = params
|
||||
@@ -5011,7 +5023,7 @@ class sfloat(_number, _secret_structure):
|
||||
:return: cfloat """
|
||||
return cfloat(self.v.reveal(), self.p.reveal(), self.z.reveal(), self.s.reveal())
|
||||
|
||||
class cfloat(object):
|
||||
class cfloat(Tape._no_truth):
|
||||
""" Helper class for printing revealed sfloats. """
|
||||
__slots__ = ['v', 'p', 'z', 's', 'nan']
|
||||
|
||||
|
||||
@@ -52,10 +52,10 @@ void run(int argc, const char** argv)
|
||||
P.unchecked_broadcast(bundle);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
pShare sk = typename T<P256Element::Scalar>::Honest::Protocol(P).get_random();
|
||||
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
|
||||
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
|
||||
DataPositions usage;
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#define NO_MIXED_CIRCUITS
|
||||
|
||||
#define NO_SECURITY_CHECK
|
||||
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
@@ -113,10 +113,10 @@ void run(int argc, const char** argv)
|
||||
P.unchecked_broadcast(bundle);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
sk_prep.get_two(DATA_INVERSE, sk, __);
|
||||
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
|
||||
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
|
||||
typename pShare::TriplePrep prep(0, usage);
|
||||
|
||||
@@ -41,8 +41,8 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
timer.start();
|
||||
Player& P = proc.P;
|
||||
auto& prep = proc.DataF;
|
||||
size_t start = P.sent + prep.data_sent();
|
||||
auto stats = P.comm_stats + prep.comm_stats();
|
||||
size_t start = P.total_comm().sent;
|
||||
auto stats = P.total_comm();
|
||||
auto& extra_player = P;
|
||||
|
||||
auto& protocol = proc.protocol;
|
||||
@@ -77,7 +77,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
|
||||
if (prep_mul)
|
||||
{
|
||||
protocol.init_mul(&proc);
|
||||
protocol.init_mul();
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
protocol.prepare_mul(inv_ks[i], sk);
|
||||
protocol.start_exchange();
|
||||
@@ -106,9 +106,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
timer.stop();
|
||||
cout << "Generated " << buffer_size << " tuples in " << timer.elapsed()
|
||||
<< " seconds, throughput " << buffer_size / timer.elapsed() << ", "
|
||||
<< 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size
|
||||
<< 1e-3 * (P.total_comm().sent - start) / buffer_size
|
||||
<< " kbytes per tuple" << endl;
|
||||
(P.comm_stats + prep.comm_stats() - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
}
|
||||
|
||||
template<template<class U> class T>
|
||||
|
||||
@@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
(void) pk;
|
||||
Timer timer;
|
||||
timer.start();
|
||||
size_t start = P.sent;
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
EcSignature signature;
|
||||
vector<P256Element> opened_R;
|
||||
if (opts.R_after_msg)
|
||||
@@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
auto& protocol = proc->protocol;
|
||||
if (proc)
|
||||
{
|
||||
protocol.init_mul(proc);
|
||||
protocol.init_mul();
|
||||
protocol.prepare_mul(sk, tuple.a);
|
||||
protocol.start_exchange();
|
||||
}
|
||||
@@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
auto rx = tuple.R.x();
|
||||
signature.s = MC.open(
|
||||
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
|
||||
auto diff = (P.total_comm() - stats);
|
||||
cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending "
|
||||
<< (P.sent - start) << " bytes" << endl;
|
||||
auto diff = (P.comm_stats - stats);
|
||||
<< diff.sent << " bytes" << endl;
|
||||
diff.print(true);
|
||||
return signature;
|
||||
}
|
||||
@@ -139,11 +138,11 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
P.unchecked_broadcast(bundle);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
auto stats = P.total_comm();
|
||||
P256Element pk = MCc.open(sk, P);
|
||||
MCc.Check(P);
|
||||
cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
(P.total_comm() - stats).print(true);
|
||||
|
||||
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
|
||||
{
|
||||
@@ -154,13 +153,12 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto& check_player = MCp.get_check_player(P);
|
||||
auto stats = check_player.comm_stats;
|
||||
auto start = check_player.sent;
|
||||
auto stats = check_player.total_comm();
|
||||
MCp.Check(P);
|
||||
MCc.Check(P);
|
||||
auto diff = (check_player.total_comm() - stats);
|
||||
cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending "
|
||||
<< (check_player.sent - start) << " bytes" << endl;
|
||||
auto diff = (check_player.comm_stats - stats);
|
||||
<< diff.sent << " bytes" << endl;
|
||||
diff.print();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
|
||||
#include "Networking/ssl_sockets.h"
|
||||
|
||||
/**
|
||||
* Client-side interface
|
||||
*/
|
||||
class Client
|
||||
{
|
||||
vector<int> plain_sockets;
|
||||
@@ -15,15 +18,37 @@ class Client
|
||||
ssl_service io_service;
|
||||
|
||||
public:
|
||||
/**
|
||||
* Sockets for cleartext communication
|
||||
*/
|
||||
vector<ssl_socket*> sockets;
|
||||
|
||||
/**
|
||||
* Specification of computation domain
|
||||
*/
|
||||
octetStream specification;
|
||||
|
||||
/**
|
||||
* Start a new set of connections to computing parties.
|
||||
* @param hostnames location of computing parties
|
||||
* @param port_base port base
|
||||
* @param my_client_id client identifier
|
||||
*/
|
||||
Client(const vector<string>& hostnames, int port_base, int my_client_id);
|
||||
~Client();
|
||||
|
||||
/**
|
||||
* Securely input private values.
|
||||
* @param values vector of integer-like values
|
||||
*/
|
||||
template<class T>
|
||||
void send_private_inputs(const vector<T>& values);
|
||||
|
||||
/**
|
||||
* Securely receive output values.
|
||||
* @param n number of values
|
||||
* @returns vector of integer-like values
|
||||
*/
|
||||
template<class T>
|
||||
vector<T> receive_outputs(int n);
|
||||
};
|
||||
|
||||
@@ -19,6 +19,8 @@ Scripts/<protocol>.sh bankers_bonus-1 &
|
||||
./bankers-bonus-client.x 1 <nparties> 200 0 &
|
||||
./bankers-bonus-client.x 2 <nparties> 50 1
|
||||
```
|
||||
`<protocol>` can be any arithmetic protocol (e.g., `mascot`) but not a
|
||||
binary protocol (e.g., `yao`).
|
||||
This should output that the winning id is 1. Note that the ids have to
|
||||
be incremental, and the client with the highest id has to input 1 as
|
||||
the last argument while the others have to input 0 there. Furthermore,
|
||||
@@ -32,54 +34,21 @@ different hosts, you will have to distribute the `*.pem` files.
|
||||
|
||||
### Connection Setup
|
||||
|
||||
**listen**(*int port_num*)
|
||||
|
||||
Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background.
|
||||
|
||||
*port_num* - the port number to listen on.
|
||||
|
||||
**acceptclientconnection**(*regint client_socket_id*, *int port_num*)
|
||||
|
||||
Picks the first available client socket connection. Blocks if none available.
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*port_num* - the port number identifies the socket server to accept connections on.
|
||||
1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients)
|
||||
2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection)
|
||||
3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection)
|
||||
|
||||
### Data Exchange
|
||||
|
||||
Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py).
|
||||
Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types).
|
||||
|
||||
*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*)
|
||||
1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket)
|
||||
2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client)
|
||||
3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients)
|
||||
|
||||
Read a share of an input from a client, blocking on the client send.
|
||||
## Client-Side Interface
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*number_of_inputs* - the number of inputs expected
|
||||
|
||||
*[inputs]* - returned list of shares of private input.
|
||||
|
||||
**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*)
|
||||
|
||||
Write shares of values including macs to an external client.
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*[values]* - list of shares of values to send to client.
|
||||
|
||||
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
|
||||
|
||||
See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message.
|
||||
|
||||
*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*)
|
||||
|
||||
Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf)
|
||||
|
||||
*number_of_inputs* - the number of inputs expected
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
|
||||
|
||||
*[inputs]* - returned list of shares of private input.
|
||||
The example uses the `Client` class implemented in
|
||||
`ExternalIO/Client.hpp` to handle the communication, see
|
||||
https://mp-spdz.readthedocs.io/en/latest/io.html#reference for
|
||||
documentation.
|
||||
|
||||
@@ -33,9 +33,6 @@ class FHE_Params
|
||||
|
||||
int n_mults() const { return FFTData.size() - 1; }
|
||||
|
||||
// Rely on default copy assignment/constructor (not that they should
|
||||
// ever be needed)
|
||||
|
||||
void set(const Ring& R,const vector<bigint>& primes);
|
||||
void set(const vector<bigint>& primes);
|
||||
void set_sec(int sec);
|
||||
|
||||
@@ -178,12 +178,6 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
|
||||
|
||||
return extra_slack;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Here onwards needs NTL
|
||||
******************************************************************************/
|
||||
|
||||
|
||||
|
||||
@@ -345,6 +339,7 @@ ZZX Cyclotomic(int N)
|
||||
return F;
|
||||
}
|
||||
#else
|
||||
// simplified version powers of two
|
||||
int phi_N(int N)
|
||||
{
|
||||
if (((N - 1) & N) != 0)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#ifndef _NTL_Subs
|
||||
#define _NTL_Subs
|
||||
|
||||
/* All these routines use NTL on the inside */
|
||||
|
||||
#include "FHE/Ring.h"
|
||||
#include "FHE/FFT_Data.h"
|
||||
#include "FHE/P2Data.h"
|
||||
@@ -47,7 +45,7 @@ public:
|
||||
|
||||
};
|
||||
|
||||
// Main setup routine (need NTL if online_only is false)
|
||||
// Main setup routine
|
||||
void generate_setup(int nparties, int lgp, int lg2,
|
||||
int sec, bool skip_2 = false, int slack = 0, bool round_up = false);
|
||||
|
||||
@@ -60,7 +58,6 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,
|
||||
bool round_up);
|
||||
|
||||
// Everything else needs NTL
|
||||
void init(Ring& Rg, int m, bool generate_poly);
|
||||
void init(P2Data& P2D,const Ring& Rg);
|
||||
|
||||
|
||||
@@ -114,7 +114,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
|
||||
cout << "n: " << n << endl;
|
||||
cout << "sec: " << sec << endl;
|
||||
cout << "sigma: " << this->sigma << endl;
|
||||
cout << "h: " << h << endl;
|
||||
cout << "B_clean size: " << numBits(B_clean) << endl;
|
||||
cout << "B_scale size: " << numBits(B_scale) << endl;
|
||||
cout << "B_KS size: " << numBits(B_KS) << endl;
|
||||
|
||||
@@ -401,19 +401,29 @@ void Ring_Element::change_rep(RepType r)
|
||||
|
||||
bool Ring_Element::equals(const Ring_Element& a) const
|
||||
{
|
||||
if (element.empty() and a.element.empty())
|
||||
return true;
|
||||
else if (element.empty() or a.element.empty())
|
||||
throw not_implemented();
|
||||
|
||||
if (rep!=a.rep) { throw rep_mismatch(); }
|
||||
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
|
||||
|
||||
if (is_zero() or a.is_zero())
|
||||
return is_zero() and a.is_zero();
|
||||
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{ if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } }
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool Ring_Element::is_zero() const
|
||||
{
|
||||
if (element.empty())
|
||||
return true;
|
||||
for (auto& x : element)
|
||||
if (not ::isZero(x, FFTD->get_prD()))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
ConversionIterator Ring_Element::get_iterator() const
|
||||
{
|
||||
if (rep != polynomial)
|
||||
@@ -560,6 +570,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const
|
||||
{
|
||||
if (&FFTD != this->FFTD)
|
||||
throw params_mismatch();
|
||||
if (is_zero())
|
||||
throw runtime_error("element is zero");
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -95,6 +95,7 @@ class Ring_Element
|
||||
void randomize(PRNG& G,bool Diag=false);
|
||||
|
||||
bool equals(const Ring_Element& a) const;
|
||||
bool is_zero() const;
|
||||
|
||||
// This is a NOP in cases where we cannot do a FFT
|
||||
void change_rep(RepType r);
|
||||
|
||||
@@ -175,7 +175,7 @@ size_t PairwiseGenerator<FD>::report_size(ReportType type)
|
||||
template <class FD>
|
||||
size_t PairwiseGenerator<FD>::report_sent()
|
||||
{
|
||||
return P.sent;
|
||||
return P.total_comm().sent;
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
|
||||
@@ -71,7 +71,7 @@ public:
|
||||
void run(bool exhaust);
|
||||
size_t report_size(ReportType type);
|
||||
void report_size(ReportType type, MemoryUsage& res);
|
||||
size_t report_sent() { return P.sent; }
|
||||
size_t report_sent() { return P.total_comm().sent; }
|
||||
};
|
||||
|
||||
#endif /* FHEOFFLINE_SIMPLEGENERATOR_H_ */
|
||||
|
||||
@@ -96,7 +96,7 @@ void BitAdder::add(vector<vector<T> >& res,
|
||||
b[j] = summands[i][1][input_begin + j];
|
||||
}
|
||||
|
||||
protocol.init_mul(&proc);
|
||||
protocol.init_mul();
|
||||
for (size_t j = 0; j < n_items; j++)
|
||||
{
|
||||
res[begin + j][i] = a[j] + b[j] + carries[j];
|
||||
|
||||
@@ -91,11 +91,6 @@ public:
|
||||
(typename T::clear(tmp.get_bit(0)) << i);
|
||||
}
|
||||
}
|
||||
|
||||
NamedCommStats comm_stats()
|
||||
{
|
||||
return part_prep.comm_stats();
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -25,6 +25,14 @@ void CcdPrep<T>::set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
auto& thread = ShareThread<T>::s();
|
||||
assert(thread.MC);
|
||||
|
||||
if (part_proc)
|
||||
{
|
||||
assert(&part_proc->MC == &thread.MC->get_part_MC());
|
||||
assert(&part_proc->P == &protocol.get_part().P);
|
||||
return;
|
||||
}
|
||||
|
||||
part_proc = new SubProcessor<typename T::part_type>(
|
||||
thread.MC->get_part_MC(), part_prep, protocol.get_part().P);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ public:
|
||||
typedef ShamirInput<This> Input;
|
||||
|
||||
typedef ShamirMC<This> MAC_Check;
|
||||
typedef Shamir<This> Protocol;
|
||||
|
||||
typedef This small_type;
|
||||
|
||||
|
||||
@@ -108,6 +108,9 @@ public:
|
||||
template<class U>
|
||||
static void convcbit2s(GC::Processor<U>&, const BaseInstruction&)
|
||||
{ throw runtime_error("convcbit2s not implemented"); }
|
||||
template<class U>
|
||||
static void andm(GC::Processor<U>&, const BaseInstruction&)
|
||||
{ throw runtime_error("andm not implemented"); }
|
||||
|
||||
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
|
||||
static FakeSecret input(int from, word input, int n_bits);
|
||||
|
||||
@@ -84,7 +84,7 @@ void Instruction::parse(istream& s, int pos)
|
||||
ostringstream os;
|
||||
os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl;
|
||||
os << "This virtual machine executes binary circuits only." << endl;
|
||||
os << "Try compiling with '-B' or use only sbit* types." << endl;
|
||||
os << "Use 'compile.py -B'." << endl;
|
||||
throw Invalid_Instruction(os.str());
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define GC_NOSHARE_H_
|
||||
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "Processor/Instruction.h"
|
||||
#include "Protocols/ShareInterface.h"
|
||||
|
||||
class InputArgs;
|
||||
@@ -148,11 +149,14 @@ public:
|
||||
|
||||
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
|
||||
|
||||
static void andm(GC::Processor<NoShare>&, const BaseInstruction&) { fail(); }
|
||||
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; }
|
||||
|
||||
NoShare() {}
|
||||
|
||||
NoShare(int) { fail(); }
|
||||
template<class T>
|
||||
NoShare(T) { fail(); }
|
||||
|
||||
void load_clear(Integer, Integer) { fail(); }
|
||||
void random_bit() { fail(); }
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "ShareSecret.hpp"
|
||||
|
||||
namespace GC
|
||||
@@ -28,24 +29,19 @@ PostSacriBin::~PostSacriBin()
|
||||
}
|
||||
}
|
||||
|
||||
void PostSacriBin::init_mul(SubProcessor<T>* proc)
|
||||
{
|
||||
assert(proc != 0);
|
||||
init_mul(proc->DataF, proc->MC);
|
||||
}
|
||||
|
||||
void PostSacriBin::init_mul(Preprocessing<T>&, T::MC&)
|
||||
void PostSacriBin::init_mul()
|
||||
{
|
||||
if ((int) inputs.size() >= OnlineOptions::singleton.batch_size)
|
||||
check();
|
||||
honest.init_mul();
|
||||
}
|
||||
|
||||
PostSacriBin::T::clear PostSacriBin::prepare_mul(const T& x, const T& y, int n)
|
||||
void PostSacriBin::prepare_mul(const T& x, const T& y, int n)
|
||||
{
|
||||
if (n == -1)
|
||||
n = T::default_length;
|
||||
honest.prepare_mul(x, y, n);
|
||||
inputs.push_back({{x.mask(n), y.mask(n)}});
|
||||
return {};
|
||||
}
|
||||
|
||||
void PostSacriBin::exchange()
|
||||
@@ -55,6 +51,8 @@ void PostSacriBin::exchange()
|
||||
|
||||
PostSacriBin::T PostSacriBin::finalize_mul(int n)
|
||||
{
|
||||
if (n == -1)
|
||||
n = T::default_length;
|
||||
auto res = honest.finalize_mul(n);
|
||||
outputs.push_back({res, n});
|
||||
return res;
|
||||
|
||||
@@ -38,9 +38,8 @@ public:
|
||||
PostSacriBin(Player& P);
|
||||
~PostSacriBin();
|
||||
|
||||
void init_mul(Preprocessing<T>&, T::MC&);
|
||||
void init_mul(SubProcessor<T>* proc);
|
||||
T::clear prepare_mul(const T& x, const T& y, int n = -1);
|
||||
void init_mul();
|
||||
void prepare_mul(const T& x, const T& y, int n = -1);
|
||||
void exchange();
|
||||
T finalize_mul(int n = -1);
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_REPPREP_HPP_
|
||||
#define GC_REPPREP_HPP_
|
||||
|
||||
#include "RepPrep.h"
|
||||
#include "ShareThread.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
@@ -98,3 +101,5 @@ void RepPrep<T>::buffer_inputs(int player)
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -126,6 +126,9 @@ public:
|
||||
template<class U>
|
||||
static void convcbit2s(Processor<U>& processor, const BaseInstruction& instruction)
|
||||
{ T::convcbit2s(processor, instruction); }
|
||||
template<class U>
|
||||
static void andm(Processor<U>& processor, const BaseInstruction& instruction)
|
||||
{ T::andm(processor, instruction); }
|
||||
|
||||
Secret();
|
||||
Secret(const Integer& x) { *this = x; }
|
||||
|
||||
@@ -24,12 +24,15 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) :
|
||||
void SemiPrep::set_protocol(Beaver<SemiSecret>& protocol)
|
||||
{
|
||||
if (triple_generator)
|
||||
{
|
||||
assert(&triple_generator->get_player() == &protocol.P);
|
||||
return;
|
||||
}
|
||||
|
||||
(void) protocol;
|
||||
params.set_passive();
|
||||
triple_generator = new SemiSecret::TripleGenerator(
|
||||
BaseMachine::s().fresh_ot_setup(),
|
||||
BaseMachine::fresh_ot_setup(protocol.P),
|
||||
protocol.P.N, -1, OnlineOptions::singleton.batch_size,
|
||||
1, params, {}, &protocol.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
@@ -61,12 +64,4 @@ void SemiPrep::buffer_bits()
|
||||
}
|
||||
}
|
||||
|
||||
NamedCommStats SemiPrep::comm_stats()
|
||||
{
|
||||
if (triple_generator)
|
||||
return triple_generator->comm_stats();
|
||||
else
|
||||
return {};
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -44,6 +44,8 @@ public:
|
||||
|
||||
array<SemiSecret, 3> get_triple_no_count(int n_bits)
|
||||
{
|
||||
if (n_bits == -1)
|
||||
n_bits = SemiSecret::default_length;
|
||||
return ShiftableTripleBuffer<SemiSecret>::get_triple_no_count(n_bits);
|
||||
}
|
||||
|
||||
@@ -51,8 +53,6 @@ public:
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -78,6 +78,8 @@ public:
|
||||
template<class T>
|
||||
static void convcbit2s(Processor<T>& processor, const BaseInstruction& instruction)
|
||||
{ processor.convcbit2s(instruction); }
|
||||
static void andm(Processor<U>& processor, const BaseInstruction& instruction)
|
||||
{ processor.andm(instruction); }
|
||||
|
||||
static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); }
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ void ShareSecret<U>::invert(int n, const U& x)
|
||||
{
|
||||
U ones;
|
||||
ones.load_clear(64, -1);
|
||||
static_cast<U&>(*this) = U(x ^ ones) & get_mask(n);
|
||||
reinterpret_cast<U&>(*this) = U(x + ones) & get_mask(n);
|
||||
}
|
||||
|
||||
template<class U>
|
||||
@@ -92,8 +92,12 @@ template<class U>
|
||||
void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
|
||||
const vector<ClearWriteAccess>& accesses)
|
||||
{
|
||||
auto& thread = ShareThread<U>::s();
|
||||
assert(thread.P);
|
||||
assert(thread.MC);
|
||||
for (auto access : accesses)
|
||||
mem[access.address] = access.value;
|
||||
mem[access.address] = U::constant(access.value, thread.P->my_num(),
|
||||
thread.MC->get_alphai());
|
||||
}
|
||||
|
||||
template<class U>
|
||||
@@ -330,7 +334,7 @@ void ShareSecret<U>::random_bit()
|
||||
template<class U>
|
||||
U& GC::ShareSecret<U>::operator=(const U& other)
|
||||
{
|
||||
U& real_this = static_cast<U&>(*this);
|
||||
U& real_this = reinterpret_cast<U&>(*this);
|
||||
real_this = other;
|
||||
return real_this;
|
||||
}
|
||||
|
||||
@@ -58,9 +58,6 @@ public:
|
||||
|
||||
void pre_run();
|
||||
void post_run() { ShareThread<T>::post_run(); }
|
||||
|
||||
NamedCommStats comm_stats()
|
||||
{ return Thread<T>::comm_stats() + this->DataF.comm_stats(); }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -63,6 +63,7 @@ void ShareThread<T>::pre_run(Player& P, typename T::mac_key_type mac_key)
|
||||
protocol = new typename T::Protocol(*this->P);
|
||||
MC = this->new_mc(mac_key);
|
||||
DataF.set_protocol(*this->protocol);
|
||||
this->protocol->init(DataF, *MC);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -85,7 +86,7 @@ void ShareThread<T>::and_(Processor<T>& processor,
|
||||
{
|
||||
auto& protocol = this->protocol;
|
||||
processor.check_args(args, 4);
|
||||
protocol->init_mul(DataF, *this->MC);
|
||||
protocol->init_mul();
|
||||
T x_ext, y_ext;
|
||||
for (size_t i = 0; i < args.size(); i += 4)
|
||||
{
|
||||
|
||||
@@ -55,8 +55,6 @@ public:
|
||||
|
||||
void join_tape();
|
||||
void finish();
|
||||
|
||||
virtual NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -96,13 +96,6 @@ void Thread<T>::finish()
|
||||
pthread_join(thread, 0);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
NamedCommStats Thread<T>::comm_stats()
|
||||
{
|
||||
assert(P);
|
||||
return P->comm_stats;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
|
||||
|
||||
@@ -95,11 +95,11 @@ void ThreadMaster<T>::run()
|
||||
|
||||
post_run();
|
||||
|
||||
NamedCommStats stats = P->comm_stats;
|
||||
NamedCommStats stats = P->total_comm();
|
||||
ExecutionStats exe_stats;
|
||||
for (auto thread : threads)
|
||||
{
|
||||
stats += thread->P->comm_stats;
|
||||
stats += thread->P->total_comm();
|
||||
exe_stats += thread->processor.stats;
|
||||
delete thread;
|
||||
}
|
||||
|
||||
@@ -44,8 +44,6 @@ public:
|
||||
~TinierSharePrep();
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol);
|
||||
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "TinierSharePrep.h"
|
||||
|
||||
#include "PersonalPrep.hpp"
|
||||
#include "PersonalPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
@@ -39,14 +39,17 @@ template<class T>
|
||||
void TinierSharePrep<T>::set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
if (triple_generator)
|
||||
{
|
||||
assert(&triple_generator->get_player() == &protocol.P);
|
||||
return;
|
||||
}
|
||||
|
||||
params.generateMACs = true;
|
||||
params.amplify = false;
|
||||
params.check = false;
|
||||
auto& thread = ShareThread<typename T::whole_type>::s();
|
||||
triple_generator = new typename T::TripleGenerator(
|
||||
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
|
||||
BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1,
|
||||
OnlineOptions::singleton.batch_size, 1,
|
||||
params, thread.MC->get_alphai(), &protocol.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
@@ -84,17 +87,6 @@ void GC::TinierSharePrep<T>::buffer_bits()
|
||||
BufferPrep<T>::get_random_from_inputs(thread.P->num_players()));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
NamedCommStats TinierSharePrep<T>::comm_stats()
|
||||
{
|
||||
NamedCommStats res;
|
||||
if (triple_generator)
|
||||
res += triple_generator->comm_stats();
|
||||
if (real_triple_generator)
|
||||
res += real_triple_generator->comm_stats();
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -16,7 +16,7 @@ void TinierSharePrep<T>::init_real(Player& P)
|
||||
assert(real_triple_generator == 0);
|
||||
auto& thread = ShareThread<secret_type>::s();
|
||||
real_triple_generator = new typename T::whole_type::TripleGenerator(
|
||||
BaseMachine::s().fresh_ot_setup(), P.N, -1,
|
||||
BaseMachine::fresh_ot_setup(P), P.N, -1,
|
||||
OnlineOptions::singleton.batch_size, 1, params,
|
||||
thread.MC->get_alphai(), &P);
|
||||
real_triple_generator->multi_threaded = false;
|
||||
|
||||
@@ -36,6 +36,8 @@ public:
|
||||
|
||||
void add_mine(const typename T::open_type& input, int n_bits)
|
||||
{
|
||||
if (n_bits == -1)
|
||||
n_bits = T::default_length;
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
part_input.add_mine(input.get_bit(i));
|
||||
input_lengths.push_back(n_bits);
|
||||
@@ -43,6 +45,8 @@ public:
|
||||
|
||||
void add_other(int player, int n_bits)
|
||||
{
|
||||
if (n_bits == -1)
|
||||
n_bits = T::default_length;
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
part_input.add_other(player);
|
||||
}
|
||||
@@ -69,6 +73,8 @@ public:
|
||||
|
||||
void finalize_other(int player, T& target, octetStream&, int n_bits)
|
||||
{
|
||||
if (n_bits == -1)
|
||||
n_bits = T::default_length;
|
||||
target.resize_regs(n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
part_input.finalize_other(player, target.get_reg(i),
|
||||
|
||||
@@ -21,9 +21,10 @@ public:
|
||||
|
||||
VectorProtocol(Player& P);
|
||||
|
||||
void init_mul(SubProcessor<T>* proc);
|
||||
void init_mul(Preprocessing<T>& prep, typename T::MAC_Check& MC);
|
||||
typename T::clear prepare_mul(const T& x, const T& y, int n = -1);
|
||||
void init(Preprocessing<T>& prep, typename T::MAC_Check& MC);
|
||||
|
||||
void init_mul();
|
||||
void prepare_mul(const T& x, const T& y, int n = -1);
|
||||
void exchange();
|
||||
void finalize_mult(T& res, int n = -1);
|
||||
T finalize_mul(int n = -1);
|
||||
|
||||
@@ -18,26 +18,26 @@ VectorProtocol<T>::VectorProtocol(Player& P) :
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void VectorProtocol<T>::init_mul(SubProcessor<T>* proc)
|
||||
{
|
||||
assert(proc);
|
||||
init_mul(proc->DataF, proc->MC);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void VectorProtocol<T>::init_mul(Preprocessing<T>& prep,
|
||||
void VectorProtocol<T>::init(Preprocessing<T>& prep,
|
||||
typename T::MAC_Check& MC)
|
||||
{
|
||||
part_protocol.init_mul(prep.get_part(), MC.get_part_MC());
|
||||
part_protocol.init(prep.get_part(), MC.get_part_MC());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
typename T::clear VectorProtocol<T>::prepare_mul(const T& x,
|
||||
void VectorProtocol<T>::init_mul()
|
||||
{
|
||||
part_protocol.init_mul();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void VectorProtocol<T>::prepare_mul(const T& x,
|
||||
const T& y, int n)
|
||||
{
|
||||
if (n == -1)
|
||||
n = T::default_length;
|
||||
for (int i = 0; i < n; i++)
|
||||
part_protocol.prepare_mul(x.get_reg(i), y.get_reg(i), 1);
|
||||
return {};
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -57,6 +57,8 @@ T VectorProtocol<T>::finalize_mul(int n)
|
||||
template<class T>
|
||||
void VectorProtocol<T>::finalize_mult(T& res, int n)
|
||||
{
|
||||
if (n == -1)
|
||||
n = T::default_length;
|
||||
res.resize_regs(n);
|
||||
for (int i = 0; i < n; i++)
|
||||
res.get_reg(i) = part_protocol.finalize_mul(1);
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
X(NOTCB, processor.notcb(INST)) \
|
||||
X(ANDRS, T::andrs(PROC, EXTRA)) \
|
||||
X(ANDS, T::ands(PROC, EXTRA)) \
|
||||
X(ANDM, T::andm(PROC, instruction)) \
|
||||
X(ADDCB, C0 = PC1 + PC2) \
|
||||
X(ADDCBI, C0 = PC1 + int(IMM)) \
|
||||
X(MULCBI, C0 = PC1 * int(IMM)) \
|
||||
@@ -76,7 +77,6 @@
|
||||
#define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \
|
||||
X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \
|
||||
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
|
||||
X(ANDM, processor.andm(instruction)) \
|
||||
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
|
||||
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
|
||||
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License)
|
||||
Copyright (c) 2021, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
|
||||
Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
|
||||
All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material.
|
||||
Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
16
Machines/Atlas.hpp
Normal file
16
Machines/Atlas.hpp
Normal file
@@ -0,0 +1,16 @@
|
||||
/*
|
||||
* Atlas.hpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MACHINES_ATLAS_HPP_
|
||||
#define MACHINES_ATLAS_HPP_
|
||||
|
||||
#include "Protocols/AtlasShare.h"
|
||||
#include "Protocols/AtlasPrep.h"
|
||||
#include "GC/AtlasSecret.h"
|
||||
|
||||
#include "ShamirMachine.hpp"
|
||||
#include "Protocols/Atlas.hpp"
|
||||
|
||||
#endif /* MACHINES_ATLAS_HPP_ */
|
||||
@@ -4,6 +4,7 @@
|
||||
*/
|
||||
|
||||
#include "Protocols/MalRepRingPrep.h"
|
||||
#include "Protocols/ReplicatedPrep2k.h"
|
||||
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
|
||||
17
Machines/Rep4.hpp
Normal file
17
Machines/Rep4.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
/*
|
||||
* Rep4.hpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MACHINES_REP4_HPP_
|
||||
#define MACHINES_REP4_HPP_
|
||||
|
||||
#include "GC/Rep4Secret.h"
|
||||
#include "Protocols/Rep4Share2k.h"
|
||||
#include "Protocols/Rep4Prep.h"
|
||||
#include "Protocols/Rep4.hpp"
|
||||
#include "Protocols/Rep4MC.hpp"
|
||||
#include "Protocols/Rep4Input.hpp"
|
||||
#include "Protocols/Rep4Prep.hpp"
|
||||
|
||||
#endif /* MACHINES_REP4_HPP_ */
|
||||
@@ -21,13 +21,15 @@
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/VectorInput.h"
|
||||
#include "GC/VectorProtocol.h"
|
||||
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/ShareParty.h"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "GC/ShareSecret.h"
|
||||
#include "GC/TinierSharePrep.h"
|
||||
#include "GC/CcdPrep.h"
|
||||
|
||||
#include "GC/VectorProtocol.hpp"
|
||||
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
|
||||
@@ -23,9 +23,10 @@
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "Protocols/Spdz2kPrep.hpp"
|
||||
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/ShareParty.h"
|
||||
#include "GC/ShareSecret.h"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "GC/TinierSharePrep.h"
|
||||
#include "GC/CcdPrep.h"
|
||||
|
||||
#include "GC/VectorProtocol.hpp"
|
||||
|
||||
@@ -18,3 +18,4 @@
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
|
||||
15
Machines/Semi2k.hpp
Normal file
15
Machines/Semi2k.hpp
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
* Semi2.hpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef MACHINES_SEMI2K_HPP_
|
||||
#define MACHINES_SEMI2K_HPP_
|
||||
|
||||
#include "Protocols/Semi2kShare.h"
|
||||
#include "Protocols/SemiPrep2k.h"
|
||||
|
||||
#include "Semi.hpp"
|
||||
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
|
||||
|
||||
#endif /* MACHINES_SEMI2K_HPP_ */
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/Spdz2kPrep.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/VectorProtocol.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
|
||||
23
Machines/Tinier.cpp
Normal file
23
Machines/Tinier.cpp
Normal file
@@ -0,0 +1,23 @@
|
||||
/*
|
||||
* Tinier.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
#include "GC/ShareParty.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdPrep.hpp"
|
||||
#include "GC/PersonalPrep.hpp"
|
||||
|
||||
//template class GC::ShareParty<GC::TinierSecret<gf2n_short>>;
|
||||
template class GC::CcdPrep<GC::TinierSecret<gf2n_short>>;
|
||||
template class Preprocessing<GC::TinierSecret<gf2n_short>>;
|
||||
template class GC::TinierSharePrep<GC::TinierShare<gf2n_short>>;
|
||||
template class GC::ShareSecret<GC::TinierSecret<gf2n_short>>;
|
||||
template class TripleShuffleSacrifice<GC::TinierSecret<gf2n_short>>;
|
||||
@@ -3,12 +3,7 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/AtlasShare.h"
|
||||
#include "Protocols/AtlasPrep.h"
|
||||
#include "GC/AtlasSecret.h"
|
||||
|
||||
#include "ShamirMachine.hpp"
|
||||
#include "Protocols/Atlas.hpp"
|
||||
#include "Atlas.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
|
||||
@@ -10,11 +10,13 @@
|
||||
#include "Processor/RingOptions.h"
|
||||
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Processor/OnlineOptions.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/ShuffleSacrifice.hpp"
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/FakeShare.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
@@ -22,7 +24,7 @@ int main(int argc, const char** argv)
|
||||
Names N;
|
||||
ez::ezOptionParser opt;
|
||||
RingOptions ring_opts(opt, argc, argv);
|
||||
online_opts = {opt, argc, argv};
|
||||
online_opts = {opt, argc, argv, FakeShare<SignedZ2<64>>()};
|
||||
opt.parse(argc, argv);
|
||||
opt.syntax = string(argv[0]) + " <progname>";
|
||||
|
||||
@@ -44,9 +46,7 @@ int main(int argc, const char** argv)
|
||||
#ifdef ROUND_NEAREST_IN_EMULATION
|
||||
cerr << "Using nearest rounding instead of probabilistic truncation" << endl;
|
||||
#else
|
||||
#ifdef RISKY_TRUNCATION_IN_EMULATION
|
||||
cerr << "Using risky truncation" << endl;
|
||||
#endif
|
||||
online_opts.set_trunc_error(opt);
|
||||
#endif
|
||||
|
||||
int R = ring_opts.ring_size_from_opts_or_schedule(progname);
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/Hemi.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/SemiHonestRepPrep.h"
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "Processor/OnlineMachine.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "Math/gfp.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/SemiHonestRepPrep.h"
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
28
Makefile
28
Makefile
@@ -25,6 +25,8 @@ MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
VM = $(MINI_OT) $(SHAREDLIB)
|
||||
COMMON = $(SHAREDLIB)
|
||||
TINIER = Machines/Tinier.o $(OT)
|
||||
SPDZ = Machines/SPDZ.o $(TINIER)
|
||||
|
||||
|
||||
LIB = libSPDZ.a
|
||||
@@ -117,7 +119,7 @@ sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x
|
||||
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) Fake-ECDSA.x
|
||||
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
|
||||
|
||||
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(OT) $(GC)
|
||||
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(TINIER) $(GC)
|
||||
$(AR) -csr $@ $^
|
||||
|
||||
CFLAGS += -fPIC
|
||||
@@ -203,16 +205,16 @@ ps-rep-bin-party.x: GC/PostSacriBin.o
|
||||
semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
tiny-party.x: $(OT)
|
||||
tinier-party.x: $(OT)
|
||||
spdz2k-party.x: $(OT) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
|
||||
spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
|
||||
static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
|
||||
semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
|
||||
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
|
||||
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
|
||||
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
|
||||
lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
|
||||
highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
|
||||
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
|
||||
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
|
||||
lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
|
||||
highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
|
||||
atlas-party.x: GC/AtlasSecret.o
|
||||
static/hemi-party.x: $(FHEOBJS)
|
||||
static/soho-party.x: $(FHEOBJS)
|
||||
@@ -220,10 +222,10 @@ static/cowgear-party.x: $(FHEOBJS)
|
||||
static/chaigear-party.x: $(FHEOBJS)
|
||||
static/lowgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
|
||||
static/highgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
|
||||
mascot-party.x: Machines/SPDZ.o $(OT)
|
||||
static/mascot-party.x: Machines/SPDZ.o
|
||||
Player-Online.x: Machines/SPDZ.o $(OT)
|
||||
mama-party.x: $(OT)
|
||||
mascot-party.x: $(SPDZ)
|
||||
static/mascot-party.x: $(SPDZ)
|
||||
Player-Online.x: $(SPDZ)
|
||||
mama-party.x: $(TINIER)
|
||||
ps-rep-ring-party.x: Protocols/MalRepRingOptions.o
|
||||
malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o
|
||||
sy-rep-ring-party.x: Protocols/MalRepRingOptions.o
|
||||
@@ -236,8 +238,10 @@ emulate.x: GC/FakeSecret.o
|
||||
semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(OT)
|
||||
real-bmr-party.x: $(OT)
|
||||
paper-example.x: $(VM) $(OT) $(FHEOFFLINE)
|
||||
mascot-offline.x: $(VM) $(OT)
|
||||
cowgear-offline.x: $(OT) $(FHEOFFLINE)
|
||||
binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o
|
||||
mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o Machines/Tinier.o
|
||||
mascot-offline.x: $(VM) $(TINIER)
|
||||
cowgear-offline.x: $(TINIER) $(FHEOFFLINE)
|
||||
static/rep-bmr-party.x: $(BMR)
|
||||
static/mal-rep-bmr-party.x: $(BMR)
|
||||
static/shamir-bmr-party.x: $(BMR)
|
||||
|
||||
@@ -26,6 +26,7 @@ public:
|
||||
|
||||
static const false_type invertible;
|
||||
static const true_type characteristic_two;
|
||||
static const true_type binary;
|
||||
|
||||
static char type_char() { return 'B'; }
|
||||
static string type_short() { return "B"; }
|
||||
@@ -64,8 +65,21 @@ public:
|
||||
void pack(octetStream& os) const { os.store_int<sizeof(T)>(this->a); }
|
||||
void unpack(octetStream& os) { this->a = os.get_int<sizeof(T)>(); }
|
||||
|
||||
void pack(octetStream& os, int n) const { os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); }
|
||||
void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); }
|
||||
void pack(octetStream& os, int n) const
|
||||
{
|
||||
if (n == -1)
|
||||
pack(os);
|
||||
else
|
||||
os.store_int(super::mask(n).get(), DIV_CEIL(n, 8));
|
||||
}
|
||||
|
||||
void unpack(octetStream& os, int n)
|
||||
{
|
||||
if (n == -1)
|
||||
unpack(os);
|
||||
else
|
||||
this->a = os.get_int(DIV_CEIL(n, 8));
|
||||
}
|
||||
|
||||
static BitVec_ unpack_new(octetStream& os, int n = n_bits)
|
||||
{
|
||||
@@ -81,5 +95,7 @@ template<class T>
|
||||
const false_type BitVec_<T>::invertible;
|
||||
template<class T>
|
||||
const true_type BitVec_<T>::characteristic_two;
|
||||
template<class T>
|
||||
const true_type BitVec_<T>::binary;
|
||||
|
||||
#endif /* MATH_BITVEC_H_ */
|
||||
|
||||
@@ -36,8 +36,9 @@ void read_setup(const string& dir_prefix, int lgp = -1)
|
||||
{
|
||||
if (lgp > 0)
|
||||
{
|
||||
cerr << "No modulus found in " << filename << ", generating " << lgp
|
||||
<< "-bit prime" << endl;
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
cerr << "No modulus found in " << filename << ", generating "
|
||||
<< lgp << "-bit prime" << endl;
|
||||
T::init_default(lgp);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -20,6 +20,7 @@ public:
|
||||
static const false_type characteristic_two;
|
||||
static const false_type prime_field;
|
||||
static const false_type invertible;
|
||||
static const false_type binary;
|
||||
|
||||
template<class T>
|
||||
static void init(bool mont = true) { (void) mont; }
|
||||
|
||||
@@ -47,6 +47,7 @@ public:
|
||||
static int size_in_limbs() { return N_WORDS; }
|
||||
static int size_in_bits() { return size() * 8; }
|
||||
static int length() { return size_in_bits(); }
|
||||
static int n_bits() { return N_BITS; }
|
||||
static int t() { return 0; }
|
||||
|
||||
static char type_char() { return 'R'; }
|
||||
@@ -100,6 +101,8 @@ public:
|
||||
|
||||
int bit_length() const;
|
||||
|
||||
Z2 mask(int) const { return *this; }
|
||||
|
||||
Z2<K> operator+(const Z2<K>& other) const;
|
||||
Z2<K> operator-(const Z2<K>& other) const;
|
||||
|
||||
|
||||
@@ -86,6 +86,42 @@ void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y,int t
|
||||
{ inline_mpn_copyi(z,ans+t,t); }
|
||||
}
|
||||
|
||||
void Zp_Data::Mont_Mult_switch(mp_limb_t* z, const mp_limb_t* x,
|
||||
const mp_limb_t* y) const
|
||||
{
|
||||
switch (t)
|
||||
{
|
||||
#ifdef __BMI2__
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
Mont_Mult_<N>(z, x, y); \
|
||||
break;
|
||||
CASE(1)
|
||||
CASE(2)
|
||||
#if MAX_MOD_SZ >= 4
|
||||
CASE(3)
|
||||
CASE(4)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 5
|
||||
CASE(5)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 6
|
||||
CASE(6)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 10
|
||||
CASE(7)
|
||||
CASE(8)
|
||||
CASE(9)
|
||||
CASE(10)
|
||||
#endif
|
||||
#undef CASE
|
||||
#endif
|
||||
default:
|
||||
Mont_Mult_variable(z, x, y);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
ostream& operator<<(ostream& s,const Zp_Data& ZpD)
|
||||
|
||||
@@ -40,6 +40,7 @@ class Zp_Data
|
||||
template <int T>
|
||||
void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const;
|
||||
void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
|
||||
{ Mont_Mult(z, x, y, t); }
|
||||
@@ -242,37 +243,11 @@ inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
|
||||
{
|
||||
if (not cpu_has_bmi2())
|
||||
return Mont_Mult_variable(z, x, y);
|
||||
switch (t)
|
||||
{
|
||||
#ifdef __BMI2__
|
||||
#define CASE(N) \
|
||||
case N: \
|
||||
Mont_Mult_<N>(z, x, y); \
|
||||
break;
|
||||
CASE(1)
|
||||
CASE(2)
|
||||
#if MAX_MOD_SZ >= 4
|
||||
CASE(3)
|
||||
CASE(4)
|
||||
return Mont_Mult_switch(z, x, y);
|
||||
#else
|
||||
return Mont_Mult_variable(z, x, y);
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 5
|
||||
CASE(5)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 6
|
||||
CASE(6)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 10
|
||||
CASE(7)
|
||||
CASE(8)
|
||||
CASE(9)
|
||||
CASE(10)
|
||||
#endif
|
||||
#undef CASE
|
||||
#endif
|
||||
default:
|
||||
Mont_Mult_variable(z, x, y);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
inline void Zp_Data::Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x,
|
||||
|
||||
@@ -11,7 +11,6 @@ using namespace std;
|
||||
#include "Math/Bit.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Tools/random.h"
|
||||
#include "GC/NoShare.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
|
||||
#include "Math/modp.hpp"
|
||||
@@ -101,6 +100,7 @@ class gfp_ : public ValueInterface
|
||||
static int size() { return t() * sizeof(mp_limb_t); }
|
||||
static int size_in_bits() { return 8 * size(); }
|
||||
static int length() { return ZpD.pr_bit_length; }
|
||||
static int n_bits() { return length() - 1; }
|
||||
|
||||
static void reqbl(int n);
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "CryptoPlayer.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
void check_ssl_file(string filename)
|
||||
{
|
||||
@@ -124,12 +125,14 @@ CryptoPlayer::~CryptoPlayer()
|
||||
|
||||
void CryptoPlayer::send_to_no_stats(int other, const octetStream& o) const
|
||||
{
|
||||
assert(other != my_num());
|
||||
senders[other]->request(o);
|
||||
senders[other]->wait(o);
|
||||
}
|
||||
|
||||
void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const
|
||||
{
|
||||
assert(other != my_num());
|
||||
receivers[other]->request(o);
|
||||
receivers[other]->wait(o);
|
||||
}
|
||||
@@ -137,6 +140,7 @@ void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const
|
||||
void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send,
|
||||
octetStream& to_receive) const
|
||||
{
|
||||
assert(other != my_num());
|
||||
if (&to_send == &to_receive)
|
||||
{
|
||||
MultiPlayer<ssl_socket*>::exchange_no_stats(other, to_send, to_receive);
|
||||
@@ -153,6 +157,7 @@ void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send,
|
||||
void CryptoPlayer::pass_around_no_stats(const octetStream& to_send,
|
||||
octetStream& to_receive, int offset) const
|
||||
{
|
||||
assert(get_player(offset) != my_num());
|
||||
if (&to_send == &to_receive)
|
||||
{
|
||||
MultiPlayer<ssl_socket*>::pass_around_no_stats(to_send, to_receive, offset);
|
||||
|
||||
@@ -14,12 +14,14 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
void Names::init(int player,int pnb,int my_port,const char* servername)
|
||||
void Names::init(int player, int pnb, int my_port, const char* servername,
|
||||
bool setup_socket)
|
||||
{
|
||||
player_no=player;
|
||||
portnum_base=pnb;
|
||||
setup_names(servername, my_port);
|
||||
setup_server();
|
||||
if (setup_socket)
|
||||
setup_server();
|
||||
}
|
||||
|
||||
Names::Names(int player, int nplayers, const string& servername, int pnb,
|
||||
@@ -124,7 +126,7 @@ void Names::setup_names(const char *servername, int my_port)
|
||||
my_port = default_port(player_no);
|
||||
|
||||
int socket_num;
|
||||
int pn = portnum_base - 1;
|
||||
int pn = portnum_base;
|
||||
set_up_client_socket(socket_num, servername, pn);
|
||||
octetStream("P" + to_string(player_no)).Send(socket_num);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
@@ -132,15 +134,11 @@ void Names::setup_names(const char *servername, int my_port)
|
||||
#endif
|
||||
|
||||
// Send my name
|
||||
octet my_name[512];
|
||||
memset(my_name,0,512*sizeof(octet));
|
||||
sockaddr_in address;
|
||||
socklen_t size = sizeof address;
|
||||
getsockname(socket_num, (sockaddr*)&address, &size);
|
||||
char* name = inet_ntoa(address.sin_addr);
|
||||
// max length of IP address with ending 0
|
||||
strncpy((char*)my_name, name, 16);
|
||||
send(socket_num,my_name,512);
|
||||
char* my_name = inet_ntoa(address.sin_addr);
|
||||
octetStream(my_name).Send(socket_num);
|
||||
send(socket_num,(octet*)&my_port,4);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "My Name = %s\n",my_name);
|
||||
@@ -158,9 +156,10 @@ void Names::setup_names(const char *servername, int my_port)
|
||||
names.resize(nplayers);
|
||||
ports.resize(nplayers);
|
||||
for (i=0; i<nplayers; i++)
|
||||
{ octet tmp[512];
|
||||
receive(socket_num,tmp,512);
|
||||
names[i]=(char*)tmp;
|
||||
{
|
||||
octetStream os;
|
||||
os.Receive(socket_num);
|
||||
names[i] = os.str();
|
||||
receive(socket_num, (octet*)&ports[i], 4);
|
||||
#ifdef VERBOSE
|
||||
cerr << "Player " << i << " is running on machine " << names[i] << endl;
|
||||
@@ -176,6 +175,12 @@ void Names::setup_server()
|
||||
server->init();
|
||||
}
|
||||
|
||||
void Names::set_server(ServerSocket* socket)
|
||||
{
|
||||
assert(not server);
|
||||
server = socket;
|
||||
}
|
||||
|
||||
|
||||
Names::Names(const Names& other)
|
||||
{
|
||||
@@ -201,6 +206,7 @@ Player::Player(const Names& Nms) :
|
||||
{
|
||||
nplayers=Nms.nplayers;
|
||||
player_no=Nms.player_no;
|
||||
thread_stats.resize(nplayers);
|
||||
}
|
||||
|
||||
|
||||
@@ -243,6 +249,10 @@ MultiPlayer<T>::~MultiPlayer()
|
||||
|
||||
Player::~Player()
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
for (auto& x : thread_stats)
|
||||
x.print();
|
||||
#endif
|
||||
}
|
||||
|
||||
PlayerBase::~PlayerBase()
|
||||
@@ -685,7 +695,7 @@ void VirtualTwoPartyPlayer::send(octetStream& o) const
|
||||
{
|
||||
TimeScope ts(comm_stats["Sending one-to-one"].add(o));
|
||||
P.send_to_no_stats(other_player, o);
|
||||
sent += o.get_length();
|
||||
comm_stats.sent += o.get_length();
|
||||
}
|
||||
|
||||
void RealTwoPartyPlayer::receive(octetStream& o) const
|
||||
@@ -729,12 +739,13 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const
|
||||
void VirtualTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
|
||||
{
|
||||
TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0]));
|
||||
sent += o[0].get_length();
|
||||
comm_stats.sent += o[0].get_length();
|
||||
P.exchange_no_stats(other_player, o[0], o[1]);
|
||||
}
|
||||
|
||||
VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) :
|
||||
TwoPartyPlayer(P.my_num()), P(P), other_player(other_player)
|
||||
TwoPartyPlayer(P.my_num()), P(P), other_player(other_player), comm_stats(
|
||||
P.thread_stats.at(other_player))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -814,5 +825,13 @@ void NamedCommStats::print(bool newline)
|
||||
cerr << endl;
|
||||
}
|
||||
|
||||
NamedCommStats Player::total_comm() const
|
||||
{
|
||||
auto res = comm_stats;
|
||||
for (auto& x : thread_stats)
|
||||
res += x;
|
||||
return res;
|
||||
}
|
||||
|
||||
template class MultiPlayer<int>;
|
||||
template class MultiPlayer<ssl_socket*> ;
|
||||
|
||||
@@ -35,6 +35,7 @@ class Names
|
||||
friend class Player;
|
||||
friend class PlainPlayer;
|
||||
friend class RealTwoPartyPlayer;
|
||||
friend class Server;
|
||||
|
||||
vector<string> names;
|
||||
vector<int> ports;
|
||||
@@ -51,6 +52,8 @@ class Names
|
||||
|
||||
void setup_server();
|
||||
|
||||
void set_server(ServerSocket* socket);
|
||||
|
||||
public:
|
||||
|
||||
static const int DEFAULT_PORT = -1;
|
||||
@@ -62,8 +65,10 @@ class Names
|
||||
* @param my_port my port number (`DEFAULT_PORT` for default,
|
||||
* which is base port number plus player number)
|
||||
* @param servername location of server
|
||||
* @param setup_socket whether to start listening
|
||||
*/
|
||||
void init(int player,int pnb,int my_port,const char* servername);
|
||||
void init(int player, int pnb, int my_port, const char* servername,
|
||||
bool setup_socket = true);
|
||||
Names(int player,int pnb,int my_port,const char* servername) : Names()
|
||||
{ init(player,pnb,my_port,servername); }
|
||||
|
||||
@@ -172,11 +177,12 @@ class PlayerBase
|
||||
protected:
|
||||
int player_no;
|
||||
|
||||
public:
|
||||
size_t& sent;
|
||||
mutable Timer timer;
|
||||
mutable NamedCommStats comm_stats;
|
||||
|
||||
public:
|
||||
mutable Timer timer;
|
||||
|
||||
PlayerBase(int player_no) : player_no(player_no), sent(comm_stats.sent) {}
|
||||
virtual ~PlayerBase();
|
||||
|
||||
@@ -205,6 +211,8 @@ protected:
|
||||
public:
|
||||
const Names& N;
|
||||
|
||||
mutable vector<NamedCommStats> thread_stats;
|
||||
|
||||
Player(const Names& Nms);
|
||||
virtual ~Player();
|
||||
|
||||
@@ -358,6 +366,8 @@ public:
|
||||
virtual void request_receive(int i, octetStream& o) const { (void)i; (void)o; }
|
||||
virtual void wait_receive(int i, octetStream& o) const
|
||||
{ receive_player(i, o); }
|
||||
|
||||
NamedCommStats total_comm() const;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -500,6 +510,7 @@ class VirtualTwoPartyPlayer : public TwoPartyPlayer
|
||||
{
|
||||
Player& P;
|
||||
int other_player;
|
||||
NamedCommStats& comm_stats;
|
||||
|
||||
public:
|
||||
VirtualTwoPartyPlayer(Player& P, int other_player);
|
||||
|
||||
@@ -51,9 +51,17 @@ void Receiver<T>::run()
|
||||
while (in.pop(os))
|
||||
{
|
||||
os->reset_write_head();
|
||||
#ifdef VERBOSE_SSL
|
||||
timer.start();
|
||||
RunningTimer mytimer;
|
||||
#endif
|
||||
os->Receive(socket);
|
||||
#ifdef VERBOSE_SSL
|
||||
cout << "receiving " << os->get_length() * 1e-6 << " MB on " << socket
|
||||
<< " took " << mytimer.elapsed() << ", total "
|
||||
<< timer.elapsed() << endl;
|
||||
timer.stop();
|
||||
#endif
|
||||
out.push(os);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,9 +47,17 @@ void Sender<T>::run()
|
||||
const octetStream* os = 0;
|
||||
while (in.pop(os))
|
||||
{
|
||||
// timer.start();
|
||||
#ifdef VERBOSE_SSL
|
||||
timer.start();
|
||||
RunningTimer mytimer;
|
||||
#endif
|
||||
os->Send(socket);
|
||||
// timer.stop();
|
||||
#ifdef VERBOSE_SSL
|
||||
cout << "sending " << os->get_length() * 1e-6 << " MB on " << socket
|
||||
<< " took " << mytimer.elapsed() << ", total "
|
||||
<< timer.elapsed() << endl;
|
||||
timer.stop();
|
||||
#endif
|
||||
out.push(os);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,9 +28,7 @@ void Server::get_ip(int num)
|
||||
inet_ntop(AF_INET6, &s->sin6_addr, ipstr, sizeof ipstr);
|
||||
}
|
||||
|
||||
names[num]=new octet[512];
|
||||
memset(names[num], 0, 512);
|
||||
strncpy((char*)names[num], ipstr, INET6_ADDRSTRLEN);
|
||||
names[num] = ipstr;
|
||||
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Client IP address: " << names[num] << endl;
|
||||
@@ -45,11 +43,11 @@ void Server::get_name(int num)
|
||||
#endif
|
||||
|
||||
// Receive name sent by client (legacy) - not used here
|
||||
octet my_name[512];
|
||||
receive(socket_num[num],my_name,512);
|
||||
octetStream os;
|
||||
os.Receive(socket_num[num]);
|
||||
receive(socket_num[num],(octet*)&ports[num],4);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Player " << num << " sent (IP for info only) " << my_name << ":"
|
||||
cerr << "Player " << num << " sent (IP for info only) " << os.str() << ":"
|
||||
<< ports[num] << endl;
|
||||
#endif
|
||||
|
||||
@@ -66,7 +64,7 @@ void Server::send_names(int num)
|
||||
send(socket_num[num],nmachines,4);
|
||||
for (int i=0; i<nmachines; i++)
|
||||
{
|
||||
send(socket_num[num],names[i],512);
|
||||
octetStream(names[i]).Send(socket_num[num]);
|
||||
send(socket_num[num],(octet*)&ports[i],4);
|
||||
}
|
||||
}
|
||||
@@ -88,13 +86,20 @@ Server::Server(int argc,char **argv)
|
||||
}
|
||||
nmachines=atoi(argv[1]);
|
||||
PortnumBase=atoi(argv[2]);
|
||||
server_socket = 0;
|
||||
}
|
||||
|
||||
Server::Server(int nmachines, int PortnumBase) :
|
||||
nmachines(nmachines), PortnumBase(PortnumBase)
|
||||
nmachines(nmachines), PortnumBase(PortnumBase), server_socket(0)
|
||||
{
|
||||
}
|
||||
|
||||
Server::~Server()
|
||||
{
|
||||
if (server_socket)
|
||||
delete server_socket;
|
||||
}
|
||||
|
||||
void Server::start()
|
||||
{
|
||||
int i;
|
||||
@@ -107,7 +112,8 @@ void Server::start()
|
||||
for (i=0; i<nmachines; i++) { socket_num[i]=-1; }
|
||||
|
||||
// port number one lower to avoid conflict with players
|
||||
ServerSocket server(PortnumBase - 1);
|
||||
server_socket = new ServerSocket(PortnumBase);
|
||||
auto& server = *server_socket;
|
||||
server.init();
|
||||
|
||||
// set up connections
|
||||
@@ -130,7 +136,7 @@ void Server::start()
|
||||
bool all_on_local = true, none_on_local = true;
|
||||
for (i = 1; i < nmachines; i++)
|
||||
{
|
||||
bool on_local = string((char*)names[i]).compare("127.0.0.1");
|
||||
bool on_local = names[i].compare("127.0.0.1");
|
||||
all_on_local &= on_local;
|
||||
none_on_local &= not on_local;
|
||||
}
|
||||
@@ -144,9 +150,6 @@ void Server::start()
|
||||
for (i=0; i<nmachines; i++)
|
||||
send_names(i);
|
||||
|
||||
for (i=0; i<nmachines; i++)
|
||||
{ delete[] names[i]; }
|
||||
|
||||
for (int i = 0; i < nmachines; i++)
|
||||
close(socket_num[i]);
|
||||
}
|
||||
@@ -162,7 +165,7 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers,
|
||||
{
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Starting networking for " << my_num << "/" << nplayers
|
||||
<< " with server on " << hostname << ":" << (portnum - 1) << endl;
|
||||
<< " with server on " << hostname << ":" << (portnum) << endl;
|
||||
#endif
|
||||
assert(my_num >= 0);
|
||||
assert(my_num < nplayers);
|
||||
@@ -172,12 +175,19 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers,
|
||||
{
|
||||
pthread_create(&thread, 0, Server::start_in_thread,
|
||||
server = new Server(nplayers, portnum));
|
||||
}
|
||||
N.init(my_num, portnum, my_port, hostname.c_str());
|
||||
if (my_num == 0)
|
||||
{
|
||||
N.init(my_num, portnum, my_port, hostname.c_str(), false);
|
||||
pthread_join(thread, 0);
|
||||
N.set_server(server->get_socket());
|
||||
delete server;
|
||||
}
|
||||
else
|
||||
N.init(my_num, portnum, my_port, hostname.c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
ServerSocket* Server::get_socket()
|
||||
{
|
||||
auto res = server_socket;
|
||||
server_socket = 0;
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -14,10 +14,11 @@ using namespace std;
|
||||
class Server
|
||||
{
|
||||
vector<int> socket_num;
|
||||
vector<octet*> names;
|
||||
vector<string> names;
|
||||
vector<int> ports;
|
||||
int nmachines;
|
||||
int PortnumBase;
|
||||
ServerSocket* server_socket;
|
||||
|
||||
void get_ip(int num);
|
||||
void get_name(int num);
|
||||
@@ -31,7 +32,11 @@ public:
|
||||
|
||||
Server(int argc, char** argv);
|
||||
Server(int nmachines, int PortnumBase);
|
||||
~Server();
|
||||
|
||||
void start();
|
||||
|
||||
ServerSocket* get_socket();
|
||||
};
|
||||
|
||||
#endif /* NETWORKING_SERVER_H_ */
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define CRYPTO_SSL_SOCKETS_H_
|
||||
|
||||
#include "Tools/int.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "sockets.h"
|
||||
#include "Math/Setup.h"
|
||||
|
||||
@@ -46,6 +47,10 @@ public:
|
||||
string me, bool client) :
|
||||
parent(io_service, ctx)
|
||||
{
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << me << " setting up SSL to " << other << " as " <<
|
||||
(client ? "client" : "server") << endl;
|
||||
#endif
|
||||
lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket);
|
||||
set_verify_mode(boost::asio::ssl::verify_peer);
|
||||
set_verify_callback(boost::asio::ssl::rfc2818_verification(other));
|
||||
@@ -82,8 +87,16 @@ template<>
|
||||
inline void send(ssl_socket* socket, octet* data, size_t length)
|
||||
{
|
||||
size_t sent = 0;
|
||||
#ifdef VERBOSE_SSL
|
||||
RunningTimer timer;
|
||||
#endif
|
||||
while (sent < length)
|
||||
{
|
||||
sent += send_non_blocking(socket, data + sent, length - sent);
|
||||
#ifdef VERBOSE_SSL
|
||||
cout << "sent " << sent * 1e-6 << " MB at " << timer.elapsed() << endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "OT/BaseOT.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Tools/benchmarking.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <iostream>
|
||||
@@ -78,6 +79,23 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE rol
|
||||
|
||||
void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
{
|
||||
Bundle<octetStream> bundle(*P);
|
||||
#ifdef NO_AVX_OT
|
||||
bundle.mine = string("OT without AVX");
|
||||
#else
|
||||
bundle.mine = string("OT with AVX");
|
||||
#endif
|
||||
try
|
||||
{
|
||||
bundle.compare(*P);
|
||||
}
|
||||
catch (mismatch_among_parties&)
|
||||
{
|
||||
cerr << "Parties compiled with different base OT algorithms" << endl;
|
||||
cerr << "Set \"AVX_OT\" to the same value on all parties" << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
#ifdef NO_AVX_OT
|
||||
#ifdef USE_RISTRETTO
|
||||
typedef CurveElement Element;
|
||||
|
||||
@@ -116,7 +116,7 @@ public:
|
||||
|
||||
mac_key_type get_mac_key() const { return mac_key; }
|
||||
|
||||
NamedCommStats comm_stats();
|
||||
Player& get_player() { return globalPlayer; }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
@@ -209,15 +209,4 @@ public:
|
||||
void generateTriples();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
NamedCommStats OTTripleGenerator<T>::comm_stats()
|
||||
{
|
||||
NamedCommStats res;
|
||||
if (parentPlayer != &globalPlayer)
|
||||
res = globalPlayer.comm_stats;
|
||||
for (auto& player : players)
|
||||
res += player->comm_stats;
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -110,22 +110,31 @@ void BaseMachine::time()
|
||||
void BaseMachine::start(int n)
|
||||
{
|
||||
cout << "Starting timer " << n << " at " << timer[n].elapsed()
|
||||
<< " (" << timer[n].mb_sent() << " MB)"
|
||||
<< " after " << timer[n].idle() << endl;
|
||||
timer[n].start();
|
||||
timer[n].start(total_comm());
|
||||
}
|
||||
|
||||
void BaseMachine::stop(int n)
|
||||
{
|
||||
timer[n].stop();
|
||||
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl;
|
||||
timer[n].stop(total_comm());
|
||||
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " ("
|
||||
<< timer[n].mb_sent() << " MB)" << endl;
|
||||
}
|
||||
|
||||
void BaseMachine::print_timers()
|
||||
{
|
||||
cerr << "The following timing is ";
|
||||
if (OnlineOptions::singleton.live_prep)
|
||||
cerr << "in";
|
||||
else
|
||||
cerr << "ex";
|
||||
cerr << "clusive preprocessing." << endl;
|
||||
cerr << "Time = " << timer[0].elapsed() << " seconds " << endl;
|
||||
timer.erase(0);
|
||||
for (map<int,Timer>::iterator it = timer.begin(); it != timer.end(); it++)
|
||||
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl;
|
||||
for (auto it = timer.begin(); it != timer.end(); it++)
|
||||
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds ("
|
||||
<< it->second.mb_sent() << " MB)" << endl;
|
||||
}
|
||||
|
||||
string BaseMachine::memory_filename(const string& type_short, int my_number)
|
||||
@@ -170,3 +179,18 @@ bigint BaseMachine::prime_from_schedule(string progname)
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
NamedCommStats BaseMachine::total_comm()
|
||||
{
|
||||
NamedCommStats res;
|
||||
for (auto& queue : queues)
|
||||
res += queue->get_comm_stats();
|
||||
return res;
|
||||
}
|
||||
|
||||
void BaseMachine::set_thread_comm(const NamedCommStats& stats)
|
||||
{
|
||||
auto queue = queues.at(BaseMachine::thread_num);
|
||||
assert(queue);
|
||||
queue->set_comm_stats(stats);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define PROCESSOR_BASEMACHINE_H_
|
||||
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
#include "OT/OTTripleSetup.h"
|
||||
#include "ThreadJob.h"
|
||||
#include "ThreadQueues.h"
|
||||
@@ -22,7 +23,7 @@ class BaseMachine
|
||||
protected:
|
||||
static BaseMachine* singleton;
|
||||
|
||||
std::map<int,Timer> timer;
|
||||
std::map<int,TimerWithComm> timer;
|
||||
|
||||
string compiler;
|
||||
string domain;
|
||||
@@ -66,12 +67,18 @@ public:
|
||||
|
||||
virtual void reqbl(int) {}
|
||||
|
||||
OTTripleSetup fresh_ot_setup();
|
||||
static OTTripleSetup fresh_ot_setup(Player& P);
|
||||
|
||||
NamedCommStats total_comm();
|
||||
void set_thread_comm(const NamedCommStats& stats);
|
||||
};
|
||||
|
||||
inline OTTripleSetup BaseMachine::fresh_ot_setup()
|
||||
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
{
|
||||
return ot_setups.at(thread_num).get_fresh();
|
||||
if (singleton and size_t(thread_num) < s().ot_setups.size())
|
||||
return s().ot_setups.at(thread_num).get_fresh();
|
||||
else
|
||||
return OTTripleSetup(P, true);
|
||||
}
|
||||
|
||||
#endif /* PROCESSOR_BASEMACHINE_H_ */
|
||||
|
||||
@@ -38,7 +38,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
|
||||
|
||||
int size_in_bytes = T::size() * buffer.size();
|
||||
int n_read = 0;
|
||||
char * read_buffer = new char[size_in_bytes];
|
||||
char read_buffer[size_in_bytes];
|
||||
inf.seekg(start_posn * T::size());
|
||||
do
|
||||
{
|
||||
|
||||
@@ -89,6 +89,7 @@ template<class sint, class sgf2n> class Processor;
|
||||
template<class sint, class sgf2n> class Data_Files;
|
||||
template<class sint, class sgf2n> class Machine;
|
||||
template<class T> class SubProcessor;
|
||||
template<class T> class NoFilePrep;
|
||||
|
||||
/**
|
||||
* Abstract base class for preprocessing
|
||||
@@ -125,6 +126,7 @@ public:
|
||||
template<class U, class V>
|
||||
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
|
||||
SubProcessor<T>* proc);
|
||||
template<int = 0>
|
||||
static Preprocessing<T>* get_new(bool live_prep, const Names& N,
|
||||
DataPositions& usage);
|
||||
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
|
||||
@@ -133,22 +135,21 @@ public:
|
||||
Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {}
|
||||
virtual ~Preprocessing() {}
|
||||
|
||||
virtual void set_protocol(typename T::Protocol& protocol) = 0;
|
||||
virtual void set_protocol(typename T::Protocol&) {};
|
||||
virtual void set_proc(SubProcessor<T>* proc) { (void) proc; }
|
||||
|
||||
virtual void seekg(DataPositions& pos) { (void) pos; }
|
||||
virtual void prune() {}
|
||||
virtual void purge() {}
|
||||
|
||||
virtual size_t data_sent() { return comm_stats().sent; }
|
||||
virtual NamedCommStats comm_stats() { return {}; }
|
||||
|
||||
virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0;
|
||||
virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0;
|
||||
virtual void get_one_no_count(Dtype dtype, T& a) = 0;
|
||||
virtual void get_input_no_count(T& a, typename T::open_type& x, int i) = 0;
|
||||
virtual void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs,
|
||||
int vector_size) = 0;
|
||||
virtual void get_three_no_count(Dtype, T&, T&, T&)
|
||||
{ throw not_implemented(); }
|
||||
virtual void get_two_no_count(Dtype, T&, T&) { throw not_implemented(); }
|
||||
virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); }
|
||||
virtual void get_input_no_count(T&, typename T::open_type&, int)
|
||||
{ throw not_implemented() ; }
|
||||
virtual void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
void get(Dtype dtype, T* a);
|
||||
void get_three(Dtype dtype, T& a, T& b, T& c);
|
||||
@@ -191,6 +192,9 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
{
|
||||
template<class U> friend class Sub_Data_Files;
|
||||
|
||||
typedef typename conditional<T::LivePrep::use_part,
|
||||
Sub_Data_Files<typename T::part_type>, NoFilePrep<typename T::part_type>>::type part_type;
|
||||
|
||||
static int tuple_length(int dtype);
|
||||
|
||||
BufferOwner<T, T> buffers[N_DTYPE];
|
||||
@@ -205,7 +209,7 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
const string prep_data_dir;
|
||||
int thread_num;
|
||||
|
||||
Sub_Data_Files<typename T::part_type>* part;
|
||||
part_type* part;
|
||||
|
||||
void buffer_edabits_with_queues(bool strict, int n_bits)
|
||||
{ buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
|
||||
@@ -274,7 +278,7 @@ public:
|
||||
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
void get_dabit_no_count(T& a, typename T::bit_type& b);
|
||||
|
||||
Preprocessing<typename T::part_type>& get_part();
|
||||
part_type& get_part();
|
||||
};
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -307,8 +311,6 @@ class Data_Files
|
||||
}
|
||||
|
||||
void reset_usage() { usage.reset(); skipped.reset(); }
|
||||
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
template<class T> inline
|
||||
@@ -418,6 +420,7 @@ T Preprocessing<T>::get_bit()
|
||||
template<class T>
|
||||
T Preprocessing<T>::get_random()
|
||||
{
|
||||
assert(not usage.inputs.empty());
|
||||
return get_random_from_inputs(usage.inputs.size());
|
||||
}
|
||||
|
||||
@@ -429,10 +432,4 @@ inline void Data_Files<sint, sgf2n>::purge()
|
||||
DataFb.purge();
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
NamedCommStats Data_Files<sint, sgf2n>::comm_stats()
|
||||
{
|
||||
return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Processor/Processor.h"
|
||||
#include "Processor/NoFilePrep.h"
|
||||
#include "Protocols/dabit.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "GC/BitPrepFiles.h"
|
||||
@@ -30,6 +31,7 @@ Preprocessing<T>* Preprocessing<T>::get_new(
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<int>
|
||||
Preprocessing<T>* Preprocessing<T>::get_new(
|
||||
bool live_prep, const Names& N,
|
||||
DataPositions& usage)
|
||||
@@ -156,17 +158,7 @@ Data_Files<sint, sgf2n>::Data_Files(const Names& N) :
|
||||
template<class sint, class sgf2n>
|
||||
Data_Files<sint, sgf2n>::~Data_Files()
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
if (DataFp.data_sent())
|
||||
cerr << "Sent for " << sint::type_string() << " preprocessing threads: " <<
|
||||
DataFp.data_sent() * 1e-6 << " MB" << endl;
|
||||
#endif
|
||||
delete &DataFp;
|
||||
#ifdef VERBOSE
|
||||
if (DataF2.data_sent())
|
||||
cerr << "Sent for " << sgf2n::type_string() << " preprocessing threads: " <<
|
||||
DataF2.data_sent() * 1e-6 << " MB" << endl;
|
||||
#endif
|
||||
delete &DataF2;
|
||||
delete &DataFb;
|
||||
}
|
||||
@@ -264,6 +256,8 @@ void Sub_Data_Files<T>::purge()
|
||||
for (auto it : extended)
|
||||
it.second.purge();
|
||||
dabit_buffer.purge();
|
||||
if (part != 0)
|
||||
part->purge();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -329,10 +323,10 @@ void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Preprocessing<typename T::part_type>& Sub_Data_Files<T>::get_part()
|
||||
typename Sub_Data_Files<T>::part_type& Sub_Data_Files<T>::get_part()
|
||||
{
|
||||
if (part == 0)
|
||||
part = new Sub_Data_Files<typename T::part_type>(my_num, num_players,
|
||||
part = new part_type(my_num, num_players,
|
||||
get_prep_sub_dir<typename T::part_type>(num_players), this->usage,
|
||||
thread_num);
|
||||
return *part;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user