mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Optimized matrix multiplication in Hemi.
This commit is contained in:
@@ -79,7 +79,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
auto& MC = this->MC;
|
||||
|
||||
this->_id = online_opts.playerno + 1;
|
||||
Server* server = Server::start_networking(N, online_opts.playerno, nparties,
|
||||
Server::start_networking(N, online_opts.playerno, nparties,
|
||||
network_opts.hostname, network_opts.portnum_base);
|
||||
if (T::dishonest_majority)
|
||||
P = new PlainPlayer(N, 0);
|
||||
@@ -159,8 +159,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
MC->Check(*P);
|
||||
data_sent = P->comm_stats.total_data() + prep->data_sent();
|
||||
|
||||
if (server)
|
||||
delete server;
|
||||
this->machine.write_memory(this->N.my_num());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -25,6 +25,9 @@ static void throw_bad_ip(const char* ip) {
|
||||
throw std::invalid_argument( "bad ip" );
|
||||
}
|
||||
|
||||
namespace BIU
|
||||
{
|
||||
|
||||
Client::Client(endpoint_t* endpoints, int numservers, ClientUpdatable* updatable, unsigned int max_message_size)
|
||||
:_max_msg_sz(max_message_size),
|
||||
_numservers(numservers),
|
||||
@@ -205,3 +208,5 @@ void Client::_send_blocking(SendBuffer& msg, int id) {
|
||||
fflush(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -28,6 +28,8 @@ public:
|
||||
|
||||
|
||||
|
||||
namespace BIU
|
||||
{
|
||||
|
||||
class Client {
|
||||
public:
|
||||
@@ -61,4 +63,6 @@ private:
|
||||
boost::thread_group threads;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif /* NETWORK_INC_CLIENT_H_ */
|
||||
|
||||
@@ -35,7 +35,7 @@ Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num
|
||||
_ready_nodes = new bool[_numparties](); //initialized to false
|
||||
_clients_connected = new bool[_numparties]();
|
||||
_server = new BIU::Server(_port, _numparties-1, this, max_message_size);
|
||||
_client = new Client(_endpoints, _numparties-1, this, max_message_size);
|
||||
_client = new BIU::Client(_endpoints, _numparties-1, this, max_message_size);
|
||||
}
|
||||
|
||||
Node::~Node() {
|
||||
|
||||
@@ -69,7 +69,7 @@ private:
|
||||
int _numparties;
|
||||
|
||||
endpoint_t* _endpoints;
|
||||
Client* _client;
|
||||
BIU::Client* _client;
|
||||
BIU::Server* _server;
|
||||
bool* _ready_nodes;
|
||||
volatile bool _connected_to_servers;
|
||||
|
||||
12
CHANGELOG.md
12
CHANGELOG.md
@@ -1,5 +1,17 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.2.7 (Sep 17, 2021)
|
||||
|
||||
- Optimized matrix multiplication in Hemi
|
||||
- Improved client communication
|
||||
- Private integer division as per `Veugen and Abspoel
|
||||
<https://doi.org/10.2478/popets-2021-0073>`
|
||||
- Compiler option to translate some Python control flow instructions
|
||||
to run-time instructions
|
||||
- Functionality to break out of run-time loops
|
||||
- Run-time range check of data structure accesses
|
||||
- Improved documentation of network infrastructure
|
||||
|
||||
## 0.2.6 (Aug 6, 2021)
|
||||
|
||||
- [ATLAS](https://eprint.iacr.org/2021/833)
|
||||
|
||||
@@ -52,6 +52,7 @@ opcodes = dict(
|
||||
CONVCBIT2S = 0x249,
|
||||
XORCBI = 0x210,
|
||||
BITDECC = 0x211,
|
||||
NOTCB = 0x212,
|
||||
CONVCINT = 0x213,
|
||||
REVEAL = 0x214,
|
||||
STMSDCI = 0x215,
|
||||
@@ -190,6 +191,16 @@ class nots(BinaryVectorInstruction):
|
||||
code = opcodes['NOTS']
|
||||
arg_format = ['int','sbw','sb']
|
||||
|
||||
class notcb(BinaryVectorInstruction):
|
||||
""" Bitwise NOT of secret register vector.
|
||||
|
||||
:param: number of bits
|
||||
:param: result (cbit)
|
||||
:param: operand (cbit)
|
||||
"""
|
||||
code = opcodes['NOTCB']
|
||||
arg_format = ['int','cbw','cb']
|
||||
|
||||
class addcb(NonVectorInstruction):
|
||||
""" Integer addition two single clear bit registers.
|
||||
|
||||
@@ -617,4 +628,4 @@ class cond_print_strb(base.IOInstruction):
|
||||
arg_format = ['cb', 'int']
|
||||
|
||||
def __init__(self, cond, val):
|
||||
super(cond_print_str, self).__init__(cond, self.str_to_int(val))
|
||||
super(cond_print_strb, self).__init__(cond, self.str_to_int(val))
|
||||
|
||||
@@ -169,6 +169,33 @@ class bits(Tape.Register, _structure, _bit):
|
||||
(str(other), repr(other), type(other), type(self)))
|
||||
def long_one(self):
|
||||
return 2**self.n - 1 if self.n != None else None
|
||||
def is_long_one(self, other):
|
||||
return util.is_all_ones(other, self.n) or \
|
||||
(other is None and self.n == None)
|
||||
def res_type(self, other):
|
||||
if self.n == None and other.n == None:
|
||||
n = None
|
||||
else:
|
||||
n = max(self.n, other.n)
|
||||
return self.get_type(n)
|
||||
@read_mem_value
|
||||
def __and__(self, other):
|
||||
if util.is_zero(other):
|
||||
return 0
|
||||
elif self.is_long_one(other):
|
||||
return self
|
||||
else:
|
||||
return self._and(other)
|
||||
@read_mem_value
|
||||
def __xor__(self, other):
|
||||
if util.is_zero(other):
|
||||
return self
|
||||
elif self.is_long_one(other):
|
||||
return ~self
|
||||
else:
|
||||
return self._xor(other)
|
||||
__rand__ = __and__
|
||||
__rxor__ = __xor__
|
||||
def __repr__(self):
|
||||
if self.n != None:
|
||||
suffix = '%d' % self.n
|
||||
@@ -245,19 +272,20 @@ class cbits(bits):
|
||||
self.clear_op(other, inst.addcb, inst.addcbi, operator.add)
|
||||
__sub__ = lambda self, other: \
|
||||
self.clear_op(-other, inst.addcb, inst.addcbi, operator.add)
|
||||
def __xor__(self, other):
|
||||
def _xor(self, other):
|
||||
if isinstance(other, (sbits, sbitvec)):
|
||||
return NotImplemented
|
||||
elif isinstance(other, cbits):
|
||||
res = cbits.get_type(max(self.n, other.n))()
|
||||
res = self.res_type(other)()
|
||||
assert res.size == self.size
|
||||
assert res.size == other.size
|
||||
inst.xorcb(res.n, res, self, other)
|
||||
return res
|
||||
else:
|
||||
return self.clear_op(other, None, inst.xorcbi, operator.xor)
|
||||
def _and(self, other):
|
||||
return NotImplemented
|
||||
__radd__ = __add__
|
||||
__rxor__ = __xor__
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, cbits):
|
||||
return NotImplemented
|
||||
@@ -278,7 +306,9 @@ class cbits(bits):
|
||||
inst.shlcbi(res, self, other)
|
||||
return res
|
||||
def __invert__(self):
|
||||
return self ^ self.long_one()
|
||||
res = type(self)()
|
||||
inst.notcb(self.n, res, self)
|
||||
return res
|
||||
def print_reg(self, desc=''):
|
||||
inst.print_regb(self, desc)
|
||||
def print_reg_plain(self):
|
||||
@@ -287,6 +317,8 @@ class cbits(bits):
|
||||
def print_if(self, string):
|
||||
inst.cond_print_strb(self, string)
|
||||
def output_if(self, cond):
|
||||
if Program.prog.options.binary:
|
||||
raise CompilerError('conditional output not supported')
|
||||
cint(self).output_if(cond)
|
||||
def reveal(self):
|
||||
return self
|
||||
@@ -423,8 +455,7 @@ class sbits(bits):
|
||||
__radd__ = __add__
|
||||
__sub__ = __add__
|
||||
__rsub__ = __add__
|
||||
__xor__ = __add__
|
||||
__rxor__ = __add__
|
||||
_xor = __add__
|
||||
@read_mem_value
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, int):
|
||||
@@ -440,13 +471,7 @@ class sbits(bits):
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
__rmul__ = __mul__
|
||||
@read_mem_value
|
||||
def __and__(self, other):
|
||||
if util.is_zero(other):
|
||||
return 0
|
||||
elif util.is_all_ones(other, self.n) or \
|
||||
(other is None and self.n == None):
|
||||
return self
|
||||
def _and(self, other):
|
||||
res = self.new(n=self.n)
|
||||
if not isinstance(other, sbits):
|
||||
other = cbits.get_type(self.n).conv(other)
|
||||
@@ -456,7 +481,6 @@ class sbits(bits):
|
||||
assert(self.n == other.n)
|
||||
inst.ands(self.n, res, self, other)
|
||||
return res
|
||||
__rand__ = __and__
|
||||
def xor_int(self, other):
|
||||
if other == 0:
|
||||
return self
|
||||
@@ -551,12 +575,6 @@ class sbits(bits):
|
||||
@staticmethod
|
||||
def ripple_carry_adder(*args, **kwargs):
|
||||
return sbitint.ripple_carry_adder(*args, **kwargs)
|
||||
def to_sint(self, n_bits):
|
||||
""" Convert the :py:obj:`n_bits` least significant bits to
|
||||
:py:obj:`~Compiler.types.sint`. """
|
||||
bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0]
|
||||
bits = sint(bits, size=n_bits)
|
||||
return sint.bit_compose(bits)
|
||||
|
||||
class sbitvec(_vec):
|
||||
""" Vector of registers of secret bits, effectively a matrix of secret bits.
|
||||
|
||||
@@ -591,7 +591,7 @@ class RegintOptimizer:
|
||||
elif isinstance(inst, IndirectMemoryInstruction):
|
||||
if inst.args[1] in self.cache:
|
||||
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
|
||||
elif isinstance(inst, convint_class):
|
||||
elif type(inst) == convint_class:
|
||||
if inst.args[1] in self.cache:
|
||||
res = self.cache[inst.args[1]]
|
||||
self.cache[inst.args[0]] = res
|
||||
|
||||
@@ -2,6 +2,7 @@ from Compiler.program import Program
|
||||
from .GC import types as GC_types
|
||||
|
||||
import sys
|
||||
import re, tempfile, os
|
||||
|
||||
|
||||
def run(args, options):
|
||||
@@ -21,11 +22,62 @@ def run(args, options):
|
||||
del VARS[i]
|
||||
|
||||
print('Compiling file', prog.infile)
|
||||
|
||||
f = open(prog.infile, 'rb')
|
||||
|
||||
changed = False
|
||||
if options.flow_optimization:
|
||||
output = []
|
||||
if_stack = []
|
||||
for line in open(prog.infile):
|
||||
if if_stack and not re.match(if_stack[-1][0], line):
|
||||
if_stack.pop()
|
||||
m = re.match(
|
||||
'(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):',
|
||||
line)
|
||||
if m:
|
||||
output.append('%s@for_range_opt(%s)\n' % (m.group(1),
|
||||
m.group(3)))
|
||||
output.append('%sdef _(%s):\n' % (m.group(1), m.group(2)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match('(\s*)if(\W.*):', line)
|
||||
if m:
|
||||
if_stack.append((m.group(1), len(output)))
|
||||
output.append('%s@if_(%s)\n' % (m.group(1), m.group(2)))
|
||||
output.append('%sdef _():\n' % (m.group(1)))
|
||||
changed = True
|
||||
continue
|
||||
m = re.match('(\s*)elif\s+', line)
|
||||
if m:
|
||||
raise CompilerError('elif not supported')
|
||||
if if_stack:
|
||||
m = re.match('%selse:' % if_stack[-1][0], line)
|
||||
if m:
|
||||
start = if_stack[-1][1]
|
||||
ws = if_stack[-1][0]
|
||||
output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws,
|
||||
output[start])
|
||||
output.append('%s@else_\n' % ws)
|
||||
output.append('%sdef _():\n' % ws)
|
||||
continue
|
||||
output.append(line)
|
||||
if changed:
|
||||
infile = tempfile.NamedTemporaryFile('w+', delete=False)
|
||||
for line in output:
|
||||
infile.write(line)
|
||||
infile.seek(0)
|
||||
else:
|
||||
infile = open(prog.infile)
|
||||
else:
|
||||
infile = open(prog.infile)
|
||||
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, 'Compiler')
|
||||
# create the tapes
|
||||
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
|
||||
exec(compile(infile.read(), infile.name, 'exec'), VARS)
|
||||
|
||||
if changed and not options.debug:
|
||||
os.unlink(infile.name)
|
||||
|
||||
prog.finalize()
|
||||
|
||||
|
||||
@@ -562,7 +562,7 @@ def SDiv(a, b, l, kappa, round_nearest=False):
|
||||
y = a * w
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True)
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
for i in range(theta-1):
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
@@ -642,7 +642,7 @@ def BitDecFull(a, maybe_mixed=False):
|
||||
b, bbits = sint.get_edabit(logp, True, size=a.size)
|
||||
if logp != bit_length:
|
||||
from .GC.types import sbits
|
||||
bbits += [sbits.get_type(a.size)(0)]
|
||||
bbits += [0]
|
||||
else:
|
||||
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
|
||||
b = sint.bit_compose(bbits)
|
||||
|
||||
@@ -409,6 +409,18 @@ class use_edabit(base.Instruction):
|
||||
code = base.opcodes['USE_EDABIT']
|
||||
arg_format = ['int','int','int']
|
||||
|
||||
class use_matmul(base.Instruction):
|
||||
""" Matrix multiplication usage. Used for multithreading of
|
||||
preprocessing.
|
||||
|
||||
:param: number of left-hand rows (int)
|
||||
:param: number of left-hand columns/right-hand rows (int)
|
||||
:param: number of right-hand columns (int)
|
||||
:param: number (int, -1 for unknown)
|
||||
"""
|
||||
code = base.opcodes['USE_MATMUL']
|
||||
arg_format = ['int','int','int','int']
|
||||
|
||||
class run_tape(base.Instruction):
|
||||
""" Start tape/bytecode file in another thread.
|
||||
|
||||
@@ -432,7 +444,7 @@ class join_tape(base.Instruction):
|
||||
class crash(base.IOInstruction):
|
||||
""" Crash runtime. """
|
||||
code = base.opcodes['CRASH']
|
||||
arg_format = []
|
||||
arg_format = ['ci']
|
||||
|
||||
class start_grind(base.IOInstruction):
|
||||
code = base.opcodes['STARTGRIND']
|
||||
@@ -1559,51 +1571,49 @@ class pubinput(base.PublicFileIOInstruction):
|
||||
code = base.opcodes['PUBINPUT']
|
||||
arg_format = ['cw']
|
||||
|
||||
@base.vectorize
|
||||
class readsocketc(base.IOInstruction):
|
||||
""" Read a variable number of clear values in internal representation
|
||||
from socket for a specified client id and store them in clear registers.
|
||||
|
||||
:param: number of arguments to follow / number of inputs minus one (int)
|
||||
:param: client id (regint)
|
||||
:param: vector size (int)
|
||||
:param: destination (cint)
|
||||
:param: (repeat destination)...
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['READSOCKETC']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('cw'))
|
||||
arg_format = tools.chain(['ci','int'], itertools.repeat('cw'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
@base.vectorize
|
||||
class readsockets(base.IOInstruction):
|
||||
"""Read a variable number of secret shares + MACs from socket for a client id and store in registers"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['READSOCKETS']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('sw'))
|
||||
arg_format = tools.chain(['ci','int'], itertools.repeat('sw'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
@base.vectorize
|
||||
class readsocketint(base.IOInstruction):
|
||||
""" Read a variable number of 32-bit integers from socket for a
|
||||
specified client id and store them in clear integer registers.
|
||||
|
||||
:param: number of arguments to follow / number of inputs minus one (int)
|
||||
:param: client id (regint)
|
||||
:param: vector size (int)
|
||||
:param: destination (regint)
|
||||
:param: (repeat destination)...
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['READSOCKETINT']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
|
||||
arg_format = tools.chain(['ci','int'], itertools.repeat('ciw'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
@base.vectorize
|
||||
class writesocketc(base.IOInstruction):
|
||||
"""
|
||||
Write a variable number of clear GF(p) values from registers into socket
|
||||
@@ -1611,29 +1621,28 @@ class writesocketc(base.IOInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['WRITESOCKETC']
|
||||
arg_format = tools.chain(['ci', 'int'], itertools.repeat('c'))
|
||||
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('c'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
@base.vectorize
|
||||
class writesocketshare(base.IOInstruction):
|
||||
""" Write a variable number of shares (without MACs) from secret
|
||||
registers into socket for a specified client id.
|
||||
|
||||
:param: client id (regint)
|
||||
:param: message type (must be 0)
|
||||
:param: vector size (int)
|
||||
:param: source (sint)
|
||||
:param: (repeat source)...
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['WRITESOCKETSHARE']
|
||||
arg_format = tools.chain(['ci', 'int'], itertools.repeat('s'))
|
||||
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
@base.vectorize
|
||||
class writesocketint(base.IOInstruction):
|
||||
"""
|
||||
Write a variable number of 32-bit ints from registers into socket
|
||||
@@ -1641,7 +1650,7 @@ class writesocketint(base.IOInstruction):
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['WRITESOCKETINT']
|
||||
arg_format = tools.chain(['ci', 'int'], itertools.repeat('ci'))
|
||||
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('ci'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
@@ -2266,6 +2275,10 @@ class matmulsm(matmul_base):
|
||||
for i in range(2):
|
||||
assert args[8 + i].size == args[4 + i]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
super(matmulsm, self).add_usage(req_node)
|
||||
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
|
||||
|
||||
class conv2ds(base.DataInstruction):
|
||||
""" Secret 2D convolution.
|
||||
|
||||
@@ -2301,6 +2314,12 @@ class conv2ds(base.DataInstruction):
|
||||
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
|
||||
self.args[11] * self.args[14]
|
||||
|
||||
def add_usage(self, req_node):
|
||||
super(conv2ds, self).add_usage(req_node)
|
||||
args = self.args
|
||||
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
|
||||
args[14] * args[3] * args[4])), 1)
|
||||
|
||||
@base.vectorize
|
||||
class trunc_pr(base.VarArgsInstruction):
|
||||
""" Probabilistic truncation if supported by the protocol.
|
||||
|
||||
@@ -63,6 +63,7 @@ opcodes = dict(
|
||||
THRESHOLD = 0xE3,
|
||||
PLAYERID = 0xE4,
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -456,7 +457,7 @@ def cisc(function):
|
||||
reset_global_vector_size()
|
||||
program.curr_tape = old_tape
|
||||
for x, bl in tape.req_bit_length.items():
|
||||
old_tape.require_bit_length(bl, x)
|
||||
old_tape.require_bit_length(bl - 1, x)
|
||||
from Compiler.allocator import Merger
|
||||
merger = Merger(block, program.options,
|
||||
tuple(program.to_merge))
|
||||
@@ -542,7 +543,13 @@ def cisc(function):
|
||||
MergeCISC.__name__ = function.__name__
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if program.options.cisc:
|
||||
same_sizes = True
|
||||
for arg in args:
|
||||
try:
|
||||
same_sizes &= arg.size == args[0].size
|
||||
except:
|
||||
pass
|
||||
if program.options.cisc and same_sizes:
|
||||
return MergeCISC(*args, **kwargs)
|
||||
else:
|
||||
return function(*args, **kwargs)
|
||||
|
||||
@@ -143,7 +143,6 @@ def print_str_if(cond, ss, *args):
|
||||
assert len(subs) == len(args) + 1
|
||||
if isinstance(cond, localint):
|
||||
cond = cond._v
|
||||
cond = cint.conv(cond)
|
||||
for i, s in enumerate(subs):
|
||||
if i != 0:
|
||||
val = args[i - 1]
|
||||
@@ -203,6 +202,27 @@ def runtime_error(msg='', *args):
|
||||
print_ln(msg, *args)
|
||||
crash()
|
||||
|
||||
def runtime_error_if(condition, msg='', *args):
|
||||
""" Conditionally print an error message and abort the runtime.
|
||||
|
||||
:param condition: regint/cint/int/cbit
|
||||
:param msg: message
|
||||
:param args: list of public values to fit ``%s`` in the message
|
||||
|
||||
"""
|
||||
print_ln_if(condition, msg, *args)
|
||||
crash(condition)
|
||||
|
||||
def crash(condition=None):
|
||||
""" Crash virtual machine.
|
||||
|
||||
:param condition: crash if true (default: true)
|
||||
|
||||
"""
|
||||
if condition == None:
|
||||
condition = regint(1)
|
||||
instructions.crash(regint.conv(condition))
|
||||
|
||||
def public_input():
|
||||
""" Public input read from ``Programs/Public-Input/<progname>``. """
|
||||
res = cint()
|
||||
@@ -1209,6 +1229,8 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
return loop_body(base + i)
|
||||
prog = get_program()
|
||||
thread_args = []
|
||||
if prog.curr_tape == prog.tapes[0]:
|
||||
prog.n_running_threads = n_threads
|
||||
if not util.is_zero(thread_rounds):
|
||||
tape = prog.new_tape(f, (0,), 'multithread')
|
||||
for i in range(n_threads - remainder):
|
||||
@@ -1225,6 +1247,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
if len(mem_state):
|
||||
args[i][1] = mem_state.address
|
||||
thread_args.append((tape1, i))
|
||||
prog.n_running_threads = None
|
||||
threads = prog.run_tapes(thread_args)
|
||||
for thread in threads:
|
||||
prog.join_tape(thread)
|
||||
@@ -1397,6 +1420,7 @@ def do_while(loop_fn, g=None):
|
||||
# possibly unknown loop count
|
||||
get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \
|
||||
name='begin-loop')
|
||||
get_tape().loop_breaks.append([])
|
||||
loop_block = instructions.program.curr_block
|
||||
condition = _run_and_link(loop_fn, g)
|
||||
if callable(condition):
|
||||
@@ -1404,8 +1428,15 @@ def do_while(loop_fn, g=None):
|
||||
branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False)
|
||||
instructions.program.curr_block.set_exit(branch, loop_block)
|
||||
get_tape().close_scope(scope, parent_node, 'end-loop')
|
||||
for loop_break in get_tape().loop_breaks.pop():
|
||||
loop_break.set_exit(jmp(0, add_to_prog=False), get_block())
|
||||
return loop_fn
|
||||
|
||||
def break_loop():
|
||||
""" Break out of loop. """
|
||||
get_tape().loop_breaks[-1].append(get_block())
|
||||
break_point('break')
|
||||
|
||||
def if_then(condition):
|
||||
class State: pass
|
||||
state = State()
|
||||
@@ -1483,10 +1514,18 @@ def if_(condition):
|
||||
def _():
|
||||
...
|
||||
"""
|
||||
try:
|
||||
condition = bool(condition)
|
||||
except:
|
||||
pass
|
||||
def decorator(body):
|
||||
if_then(condition)
|
||||
_run_and_link(body)
|
||||
end_if()
|
||||
if isinstance(condition, bool):
|
||||
if condition:
|
||||
_run_and_link(body)
|
||||
else:
|
||||
if_then(condition)
|
||||
_run_and_link(body)
|
||||
end_if()
|
||||
return decorator
|
||||
|
||||
def if_e(condition):
|
||||
@@ -1506,15 +1545,30 @@ def if_e(condition):
|
||||
def _():
|
||||
...
|
||||
"""
|
||||
try:
|
||||
condition = bool(condition)
|
||||
except:
|
||||
pass
|
||||
def decorator(body):
|
||||
if_then(condition)
|
||||
_run_and_link(body)
|
||||
if isinstance(condition, bool):
|
||||
get_tape().if_states.append(condition)
|
||||
if condition:
|
||||
_run_and_link(body)
|
||||
else:
|
||||
if_then(condition)
|
||||
_run_and_link(body)
|
||||
return decorator
|
||||
|
||||
def else_(body):
|
||||
else_then()
|
||||
_run_and_link(body)
|
||||
end_if()
|
||||
if_states = get_tape().if_states
|
||||
if isinstance(if_states[-1], bool):
|
||||
if not if_states[-1]:
|
||||
_run_and_link(body)
|
||||
if_states.pop()
|
||||
else:
|
||||
else_then()
|
||||
_run_and_link(body)
|
||||
end_if()
|
||||
|
||||
def and_(*terms):
|
||||
res = regint(0)
|
||||
|
||||
@@ -292,7 +292,8 @@ class Output(NoVariableLayer):
|
||||
self.l.write(sum(lse) * \
|
||||
self.divisor(N, 1))
|
||||
|
||||
def eval(self, size, base=0):
|
||||
def eval(self, size, base=0, top=False):
|
||||
assert not top
|
||||
if self.approx:
|
||||
return approx_sigmoid(self.X.get_vector(base, size), self.approx)
|
||||
else:
|
||||
@@ -474,8 +475,14 @@ class MultiOutput(MultiOutputBase):
|
||||
self.true_X[i] = true_X
|
||||
self.l.write(sum(tmp.get_vector(0, N)) / N)
|
||||
|
||||
def eval(self, N):
|
||||
def eval(self, N, top=False):
|
||||
d_out = self.X.sizes[1]
|
||||
if top:
|
||||
res = sint.Array(N)
|
||||
@for_range_opt_multithread(self.n_threads, N)
|
||||
def _(i):
|
||||
res[i] = argmax(self.X[i])
|
||||
return res
|
||||
res = sfix.Matrix(N, d_out)
|
||||
if self.approx:
|
||||
@for_range_opt_multithread(self.n_threads, N)
|
||||
@@ -1832,7 +1839,7 @@ class Optimizer:
|
||||
|
||||
@_no_mem_warnings
|
||||
def forward(self, N=None, batch=None, keep_intermediate=True,
|
||||
model_from=None, training=False):
|
||||
model_from=None, training=False, run_last=True):
|
||||
""" Compute graph.
|
||||
|
||||
:param N: batch size (used if batch not given)
|
||||
@@ -1851,7 +1858,9 @@ class Optimizer:
|
||||
break_point()
|
||||
if self.time_layers:
|
||||
start_timer(100 + i)
|
||||
layer.forward(batch=self.batch_for(layer, batch), training=training)
|
||||
if i != len(self.layers) - 1 or run_last:
|
||||
layer.forward(batch=self.batch_for(layer, batch),
|
||||
training=training)
|
||||
if self.time_layers:
|
||||
stop_timer(100 + i)
|
||||
break_point()
|
||||
@@ -1862,19 +1871,23 @@ class Optimizer:
|
||||
theta.delete()
|
||||
|
||||
@_no_mem_warnings
|
||||
def eval(self, data, batch_size=None):
|
||||
def eval(self, data, batch_size=None, top=False):
|
||||
""" Compute evaluation after training.
|
||||
|
||||
:param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample)
|
||||
:param top: return top prediction instead of probability distribution
|
||||
"""
|
||||
if isinstance(self.layers[-1].Y, Array):
|
||||
res = sfix.Array(len(data))
|
||||
if isinstance(self.layers[-1].Y, Array) or top:
|
||||
if top:
|
||||
res = sint.Array(len(data))
|
||||
else:
|
||||
res = sfix.Array(len(data))
|
||||
else:
|
||||
res = sfix.Matrix(len(data), self.layers[-1].d_out)
|
||||
def f(start, batch_size, batch):
|
||||
batch.assign_vector(regint.inc(batch_size, start))
|
||||
self.forward(batch=batch)
|
||||
part = self.layers[-1].eval(batch_size)
|
||||
self.forward(batch=batch, run_last=not top)
|
||||
part = self.layers[-1].eval(batch_size, top=top)
|
||||
res.assign_part_vector(part.get_vector(), start)
|
||||
self.run_in_batches(f, data, batch_size or len(self.layers[1].X))
|
||||
return res
|
||||
@@ -2487,24 +2500,29 @@ class keras:
|
||||
batch_size = min(batch_size, self.batch_size)
|
||||
return self.opt.eval(x, batch_size=batch_size)
|
||||
|
||||
def solve_linear(A, b, n_iterations, debug=False):
|
||||
def solve_linear(A, b, n_iterations, progress=False):
|
||||
""" Iterative linear solution approximation. """
|
||||
assert len(b) == A.sizes[0]
|
||||
x = sfix.Array(A.sizes[1])
|
||||
x.assign_vector(sfix.get_random(-1, 1, size=len(x)))
|
||||
At = A.transpose()
|
||||
AtA = sfix.Matrix(len(x), len(x))
|
||||
AtA[:] = A.direct_trans_mul(A)
|
||||
v = sfix.Array(A.sizes[1])
|
||||
v.assign_all(0)
|
||||
r = Array.create_from(A.transpose() * b - AtA * x)
|
||||
Av = sfix.Array(len(x))
|
||||
@for_range(n_iterations)
|
||||
def _(i):
|
||||
r = At * (b - A * x)
|
||||
tmp = A * r
|
||||
tmp = sfix.dot_product(tmp, tmp)
|
||||
alpha = (tmp == 0).if_else(0, sfix.dot_product(r, r) / tmp)
|
||||
x.assign(x + alpha * r)
|
||||
if debug:
|
||||
print_ln('%s r=%s tmp=%s r*r=%s tmp*tmp=%s alpha=%s x=%s alpha*r=%s', i,
|
||||
list(r.reveal()), list(tmp.reveal()),
|
||||
sfix.dot_product(r, r).reveal(), sfix.dot_product(tmp, tmp).reveal(),
|
||||
alpha.reveal(), x.reveal_list(), list((alpha * r).reveal()))
|
||||
v[:] = r - sfix.dot_product(r, Av) / sfix.dot_product(v, Av) * v
|
||||
Av[:] = AtA * v
|
||||
v_norm = sfix.dot_product(v, Av)
|
||||
vr = sfix.dot_product(v, r)
|
||||
alpha = (v_norm == 0).if_else(0, vr / v_norm)
|
||||
x[:] = x + alpha * v
|
||||
r[:] = r - alpha * Av
|
||||
if progress:
|
||||
print_ln('%s alpha=%s vr=%s v_norm=%s', i, alpha.reveal(),
|
||||
vr.reveal(), v_norm.reveal())
|
||||
return x
|
||||
|
||||
def mr(A, n_iterations):
|
||||
|
||||
@@ -266,7 +266,7 @@ def tan(x):
|
||||
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def exp2_fx(a, zero_output=False):
|
||||
def exp2_fx(a, zero_output=False, as19=False):
|
||||
"""
|
||||
Power of two for fixed-point numbers.
|
||||
|
||||
@@ -287,7 +287,7 @@ def exp2_fx(a, zero_output=False):
|
||||
g = a._new(whole_exp.TruncMul(e.v, 2 * a.k, n_shift,
|
||||
nearest=a.round_nearest), a.k, a.f)
|
||||
return g
|
||||
if types.program.options.ring:
|
||||
if types.program.options.ring and not as19:
|
||||
sint = types.sint
|
||||
intbitint = types.intbitint
|
||||
# how many bits to use from integer part
|
||||
|
||||
@@ -947,7 +947,7 @@ class RefBucket(object):
|
||||
child.output()
|
||||
|
||||
def random_block(length, value_type):
|
||||
return sum(value_type.bit_type.get_random_bit() << i for i in range(length))
|
||||
return bit_compose(value_type.bit_type.get_random_bit() for i in range(length))
|
||||
|
||||
class List(EndRecursiveEviction):
|
||||
""" Debugging only. List which accepts secret values as indices
|
||||
|
||||
@@ -150,6 +150,8 @@ class Program(object):
|
||||
self._linear_rounds = False
|
||||
self.warn_about_mem = [True]
|
||||
self.relevant_opts = set()
|
||||
self.n_running_threads = None
|
||||
self.input_files = {}
|
||||
Program.prog = self
|
||||
from . import instructions_base, instructions, types, comparison
|
||||
instructions.program = self
|
||||
@@ -370,8 +372,6 @@ class Program(object):
|
||||
""" Allocate memory from the top """
|
||||
if not isinstance(size, int):
|
||||
raise CompilerError('size must be known at compile time')
|
||||
if not (creator_tape or self.curr_tape).singular:
|
||||
raise CompilerError('cannot allocate memory outside main thread')
|
||||
if size == 0:
|
||||
return
|
||||
if isinstance(mem_type, type):
|
||||
@@ -383,6 +383,14 @@ class Program(object):
|
||||
mem_type = mem_type.reg_type
|
||||
elif reg_type is not None:
|
||||
self.types[mem_type] = reg_type
|
||||
single_size = None
|
||||
if not (creator_tape or self.curr_tape).singular:
|
||||
if self.n_running_threads:
|
||||
single_size = size
|
||||
size *= self.n_running_threads
|
||||
else:
|
||||
raise CompilerError('cannot allocate memory '
|
||||
'outside main thread')
|
||||
blocks = self.free_mem_blocks[mem_type]
|
||||
addr = blocks.pop(size)
|
||||
if addr is not None:
|
||||
@@ -393,7 +401,13 @@ class Program(object):
|
||||
if len(str(addr)) != len(str(addr + size)) and self.verbose:
|
||||
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
|
||||
self.allocated_mem_blocks[addr,mem_type] = size
|
||||
return addr
|
||||
if single_size:
|
||||
from .library import get_thread_number, runtime_error_if
|
||||
tn = get_thread_number()
|
||||
runtime_error_if(tn > self.n_running_threads, 'malloc')
|
||||
return addr + single_size * (tn - 1)
|
||||
else:
|
||||
return addr
|
||||
|
||||
def free(self, addr, mem_type):
|
||||
""" Free memory """
|
||||
@@ -424,7 +438,8 @@ class Program(object):
|
||||
from . import library
|
||||
self.curr_tape.start_new_basicblock(None, 'memory-usage')
|
||||
# reset register counter to 0
|
||||
self.curr_tape.init_registers()
|
||||
if not self.options.noreallocate:
|
||||
self.curr_tape.init_registers()
|
||||
for mem_type,size in sorted(self.allocated_mem.items()):
|
||||
if size:
|
||||
#print "Memory of type '%s' of size %d" % (mem_type, size)
|
||||
@@ -586,6 +601,7 @@ class Tape:
|
||||
self.functions = []
|
||||
self.singular = True
|
||||
self.free_threads = set()
|
||||
self.loop_breaks = []
|
||||
|
||||
class BasicBlock(object):
|
||||
def __init__(self, parent, name, scope, exit_condition=None):
|
||||
@@ -827,8 +843,7 @@ class Tape:
|
||||
block = left.popleft()
|
||||
alloc(block)
|
||||
for child in block.children:
|
||||
if child.instructions:
|
||||
left.append(child)
|
||||
left.append(child)
|
||||
for i,block in enumerate(reversed(self.basicblocks)):
|
||||
if len(block.instructions) > 1000000:
|
||||
print('Allocating %s, %d/%d' % \
|
||||
@@ -878,6 +893,10 @@ class Tape:
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.use_edabit(True, req[1], num, \
|
||||
add_to_prog=False))
|
||||
elif req[0] == 'matmul':
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.use_matmul(*req[1], num, \
|
||||
add_to_prog=False))
|
||||
|
||||
if not self.is_empty():
|
||||
# bit length requirement
|
||||
@@ -1012,6 +1031,9 @@ class Tape:
|
||||
else:
|
||||
eda = 'loose edabits'
|
||||
res += ['%s %s of length %d' % (n, eda, req[1])]
|
||||
elif domain == 'matmul':
|
||||
res += ['%s matrix multiplications (%dx%d * %dx%d)' %
|
||||
(n, req[1][0], req[1][1], req[1][1], req[1][2])]
|
||||
elif req[0] != 'all':
|
||||
res += ['%s %s %ss' % (n, domain, req[1])]
|
||||
if self['all','round']:
|
||||
@@ -1077,7 +1099,10 @@ class Tape:
|
||||
def require_bit_length(self, bit_length, t='p'):
|
||||
if t == 'p':
|
||||
if self.program.prime:
|
||||
assert bit_length < self.program.prime.bit_length() - 1
|
||||
if (bit_length >= self.program.prime.bit_length() - 1):
|
||||
raise CompilerError(
|
||||
'required bit length %d too much for %d' % \
|
||||
(bit_length, self.program.prime))
|
||||
self.req_bit_length[t] = max(bit_length + 1, \
|
||||
self.req_bit_length[t])
|
||||
else:
|
||||
@@ -1150,7 +1175,9 @@ class Tape:
|
||||
def _new_by_number(self, i, size=1):
|
||||
return Tape.Register(self.reg_type, self.program, size=size, i=i)
|
||||
|
||||
def get_vector(self, base, size):
|
||||
def get_vector(self, base=0, size=None):
|
||||
if size == None:
|
||||
size = self.size
|
||||
if base == 0 and size == self.size:
|
||||
return self
|
||||
if size == 1:
|
||||
@@ -1206,7 +1233,8 @@ class Tape:
|
||||
self.reg_type == RegType.ClearInt
|
||||
|
||||
def __bool__(self):
|
||||
raise CompilerError('cannot derive truth value from register')
|
||||
raise CompilerError('Cannot derive truth value from register, '
|
||||
"consider using 'compile.py -l'")
|
||||
|
||||
def __str__(self):
|
||||
return self.reg_type + str(self.i)
|
||||
|
||||
@@ -508,7 +508,7 @@ class _structure(object):
|
||||
a = sfix.Tensor([10, 10])
|
||||
"""
|
||||
if len(shape) == 1:
|
||||
return Array(size[0], cls)
|
||||
return Array(shape[0], cls)
|
||||
else:
|
||||
return MultiArray(shape, cls)
|
||||
|
||||
@@ -522,6 +522,75 @@ class _structure(object):
|
||||
def mem_size():
|
||||
return 1
|
||||
|
||||
class _secret_structure(_structure):
|
||||
@classmethod
|
||||
def input_tensor_from(cls, player, shape):
|
||||
""" Input tensor secretly from player.
|
||||
|
||||
:param player: int/regint/cint
|
||||
:param shape: tensor shape
|
||||
|
||||
"""
|
||||
res = cls.Tensor(shape)
|
||||
res.input_from(player)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def input_tensor_from_client(cls, client_id, shape):
|
||||
""" Input tensor secretly from client.
|
||||
|
||||
:param client_id: client identifier (public)
|
||||
:param shape: tensor shape
|
||||
|
||||
"""
|
||||
res = cls.Tensor(shape)
|
||||
res.assign_vector(cls.receive_from_client(1, client_id,
|
||||
size=res.total_size())[0])
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def input_tensor_via(cls, player, content):
|
||||
"""
|
||||
Input tensor-like data via a player. This overwrites the input
|
||||
file for the relevant player. The following returns an
|
||||
:py:class:`sint` matrix of dimension 2 by 2::
|
||||
|
||||
M = [[1, 2], [3, 4]]
|
||||
sint.input_tensor_via(0, M)
|
||||
|
||||
Make sure to copy ``Player-Data/Input-P<player>-0`` if running
|
||||
on another host.
|
||||
|
||||
"""
|
||||
if program.curr_tape != program.tapes[0]:
|
||||
raise CompilerError('only available in main thread')
|
||||
shape = []
|
||||
tmp = content
|
||||
while True:
|
||||
try:
|
||||
shape.append(len(tmp))
|
||||
tmp = tmp[0]
|
||||
except:
|
||||
break
|
||||
if not program.input_files.get(player, None):
|
||||
program.input_files[player] = open(
|
||||
'Player-Data/Input-P%d-0' % player, 'w')
|
||||
f = program.input_files[player]
|
||||
def traverse(content, level):
|
||||
assert len(content) == shape[level]
|
||||
if level == len(shape) - 1:
|
||||
for x in content:
|
||||
f.write(' ')
|
||||
f.write(str(x))
|
||||
else:
|
||||
for x in content:
|
||||
traverse(x, level + 1)
|
||||
traverse(content, 0)
|
||||
f.write('\n')
|
||||
res = cls.Tensor(shape)
|
||||
res.input_from(player)
|
||||
return res
|
||||
|
||||
class _vec(object):
|
||||
def link(self, other):
|
||||
assert len(self.v) == len(other.v)
|
||||
@@ -835,23 +904,26 @@ class cint(_clear, _int):
|
||||
|
||||
:param client_id: Client id (regint)
|
||||
:param n: number of values (default 1)
|
||||
:param size: vector size (default 1)
|
||||
:returns: cint (if n=1) or list of cint
|
||||
"""
|
||||
res = [cls() for i in range(n)]
|
||||
readsocketc(client_id, *res)
|
||||
readsocketc(client_id, get_global_vector_size(), *res)
|
||||
if n == 1:
|
||||
return res[0]
|
||||
else:
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
@classmethod
|
||||
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
|
||||
""" Send a list of clear values to a client.
|
||||
|
||||
:param client_id: Client id (regint)
|
||||
:param values: list of cint
|
||||
"""
|
||||
writesocketc(client_id, message_type, *values)
|
||||
for value in values:
|
||||
assert(value.size == values[0].size)
|
||||
writesocketc(client_id, message_type, values[0].size, *values)
|
||||
|
||||
@vectorized_classmethod
|
||||
def load_mem(cls, address, mem_type=None):
|
||||
@@ -1089,7 +1161,7 @@ class cint(_clear, _int):
|
||||
cond_print_str(self, string)
|
||||
|
||||
def output_if(self, cond):
|
||||
cond_print_plain(cond, self, cint(0))
|
||||
cond_print_plain(self.conv(cond), self, cint(0))
|
||||
|
||||
|
||||
class cgf2n(_clear, _gf2n):
|
||||
@@ -1298,23 +1370,26 @@ class regint(_register, _int):
|
||||
|
||||
:param client_id: Client id (regint)
|
||||
:param n: number of values (default 1)
|
||||
:param size: vector size (default 1)
|
||||
:returns: regint (if n=1) or list of regint
|
||||
"""
|
||||
res = [cls() for i in range(n)]
|
||||
readsocketint(client_id, *res)
|
||||
readsocketint(client_id, get_global_vector_size(), *res)
|
||||
if n == 1:
|
||||
return res[0]
|
||||
else:
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
@classmethod
|
||||
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
|
||||
""" Send a list of clear integers to a client.
|
||||
|
||||
:param client_id: Client id (regint)
|
||||
:param values: list of regint
|
||||
"""
|
||||
writesocketint(client_id, message_type, *values)
|
||||
for value in values:
|
||||
assert(value.size == values[0].size)
|
||||
writesocketint(client_id, message_type, values[0].size, *values)
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, val=None, size=None):
|
||||
@@ -1388,6 +1463,8 @@ class regint(_register, _int):
|
||||
""" Clear integer division (rounding to floor).
|
||||
|
||||
:param other: regint/cint/int """
|
||||
if util.is_constant(other) and other >= 2 ** 64:
|
||||
return 0
|
||||
return self.int_op(other, divint)
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
@@ -1546,10 +1623,17 @@ class regint(_register, _int):
|
||||
""" Output string if value is non-zero.
|
||||
|
||||
:param string: Python string """
|
||||
cint(self).print_if(string)
|
||||
self._condition().print_if(string)
|
||||
|
||||
def output_if(self, cond):
|
||||
cint(self).output_if(cond)
|
||||
self._condition().output_if(cond)
|
||||
|
||||
def _condition(self):
|
||||
if program.options.binary:
|
||||
from GC.types import cbits
|
||||
return cbits.get_type(64)(self)
|
||||
else:
|
||||
return cint(self)
|
||||
|
||||
def binary_output(self, player=None):
|
||||
""" Write 64-bit signed integer to
|
||||
@@ -1597,6 +1681,9 @@ class personal(object):
|
||||
def binary_output(self):
|
||||
self._v.binary_output(self.player)
|
||||
|
||||
def bit_decompose(self, length):
|
||||
return [personal(self.player, x) for x in self._v.bit_decompose(length)]
|
||||
|
||||
def _san(self, other):
|
||||
if isinstance(other, personal):
|
||||
assert self.player == other.player
|
||||
@@ -1689,7 +1776,7 @@ class longint:
|
||||
res += x.bit_decompose(64)
|
||||
return res[:bit_length]
|
||||
|
||||
class _secret(_register):
|
||||
class _secret(_register, _secret_structure):
|
||||
__slots__ = []
|
||||
|
||||
mov = staticmethod(set_instruction_type(movs))
|
||||
@@ -2022,7 +2109,7 @@ class sint(_secret, _int):
|
||||
the bit length.
|
||||
|
||||
:param val: initialization (sint/cint/regint/int/cgf2n or list
|
||||
thereof or sbits/sbitvec)
|
||||
thereof or sbits/sbitvec/sfix)
|
||||
:param size: vector size (int), defaults to 1 or size of list
|
||||
|
||||
"""
|
||||
@@ -2130,33 +2217,62 @@ class sint(_secret, _int):
|
||||
rawinput(player, res)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@vectorized_classmethod
|
||||
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
|
||||
""" Securely obtain shares of values input by a client.
|
||||
|
||||
:param n: number of inputs (int)
|
||||
:param client_id: regint
|
||||
:param size: vector size (default 1)
|
||||
|
||||
"""
|
||||
# send shares of a triple to client
|
||||
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
|
||||
sint.write_shares_to_socket(client_id, triples, message_type)
|
||||
|
||||
received = cint.read_from_socket(client_id, n)
|
||||
received = util.tuplify(cint.read_from_socket(client_id, n))
|
||||
y = [0] * n
|
||||
for i in range(n):
|
||||
y[i] = received[i] - triples[i * 3]
|
||||
return y
|
||||
|
||||
@classmethod
|
||||
def reveal_to_clients(cls, clients, values):
|
||||
""" Reveal securely to clients.
|
||||
|
||||
:param clients: client ids (list or array)
|
||||
:param values: list of sint to reveal
|
||||
|
||||
"""
|
||||
set_global_vector_size(values[0].size)
|
||||
to_send = []
|
||||
|
||||
for value in values:
|
||||
assert(value.size == values[0].size)
|
||||
r = sint.get_random()
|
||||
to_send += [value, r, value * r]
|
||||
|
||||
if isinstance(clients, Array):
|
||||
n_clients = clients.length
|
||||
else:
|
||||
n_clients = len(clients)
|
||||
|
||||
@library.for_range(n_clients)
|
||||
def loop_body(i):
|
||||
sint.write_shares_to_socket(clients[i], to_send)
|
||||
reset_global_vector_size()
|
||||
|
||||
@vectorized_classmethod
|
||||
def read_from_socket(cls, client_id, n=1):
|
||||
""" Receive secret-shared value(s) from client.
|
||||
|
||||
:param client_id: Client id (regint)
|
||||
:param n: number of values (default 1)
|
||||
:param size: vector size of values (default 1)
|
||||
:returns: sint (if n=1) or list of sint
|
||||
"""
|
||||
res = [cls() for i in range(n)]
|
||||
readsockets(client_id, *res)
|
||||
readsockets(client_id, get_global_vector_size(), *res)
|
||||
if n == 1:
|
||||
return res[0]
|
||||
else:
|
||||
@@ -2165,9 +2281,9 @@ class sint(_secret, _int):
|
||||
@vectorize
|
||||
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
|
||||
""" Send only share to socket """
|
||||
writesocketshare(client_id, message_type, self)
|
||||
writesocketshare(client_id, message_type, self.size, self)
|
||||
|
||||
@vectorized_classmethod
|
||||
@classmethod
|
||||
def write_shares_to_socket(cls, client_id, values,
|
||||
message_type=ClientMessageType.NoType):
|
||||
""" Send shares of a list of values to a specified client socket.
|
||||
@@ -2175,7 +2291,7 @@ class sint(_secret, _int):
|
||||
:param client_id: regint
|
||||
:param values: list of sint
|
||||
"""
|
||||
writesocketshare(client_id, message_type, *values)
|
||||
writesocketshare(client_id, message_type, values[0].size, *values)
|
||||
|
||||
@classmethod
|
||||
def read_from_file(cls, start, n_items):
|
||||
@@ -2227,6 +2343,9 @@ class sint(_secret, _int):
|
||||
size = val._v.size
|
||||
super(sint, self).__init__('s', size=size)
|
||||
inputpersonal(size, val.player, self, self.clear_type.conv(val._v))
|
||||
elif isinstance(val, _fix):
|
||||
super(sint, self).__init__('s', size=val.v.size)
|
||||
self.load_other(val.v.round(val.k, val.f))
|
||||
else:
|
||||
super(sint, self).__init__('s', val=val, size=size)
|
||||
|
||||
@@ -2498,8 +2617,18 @@ class sint(_secret, _int):
|
||||
|
||||
def private_division(self, divisor, active=True, dividend_length=None,
|
||||
divisor_length=None):
|
||||
assert active == False
|
||||
""" Private integer division as per `Veugen and Abspoel
|
||||
<https://doi.org/10.2478/popets-2021-0073>`_
|
||||
|
||||
:param divisor: public (cint/regint) or personal value thereof
|
||||
:param active: whether to check on the party knowing the
|
||||
divisor (active security)
|
||||
:param dividend_length: bit length of the dividend (default:
|
||||
global bit length)
|
||||
:param dividend_length: bit length of the divisor (default:
|
||||
global bit length)
|
||||
|
||||
"""
|
||||
d = divisor
|
||||
l = divisor_length or program.bit_length
|
||||
m = dividend_length or program.bit_length
|
||||
@@ -2515,11 +2644,28 @@ class sint(_secret, _int):
|
||||
r_prime = sint.get_random_int(m + sigma)
|
||||
r_pprime = sint.get_random_int(l + sigma)
|
||||
|
||||
h = (r + (r_prime << (l + sigma))) * sint(d)
|
||||
z = ((self << (l + sigma)) + h + r_pprime).reveal_to(0)
|
||||
d_shared = sint(d)
|
||||
h = (r + (r_prime << (l + sigma))) * d_shared
|
||||
z_shared = ((self << (l + sigma)) + h + r_pprime)
|
||||
z = z_shared.reveal_to(0)
|
||||
|
||||
y = sint(z // (d << (l + sigma)))
|
||||
y_prime = sint((z // d) % (2 ** (l + sigma)))
|
||||
if active:
|
||||
z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)]
|
||||
check = [(x * (1 - x)).reveal() == 0 for x in z_prime]
|
||||
z_pp = [sint(x) for x in (z % d).bit_decompose(l)]
|
||||
check += [(x * (1 - x)).reveal() == 0 for x in z_pp]
|
||||
library.runtime_error_if(sum(check) != len(check),
|
||||
'private division')
|
||||
z_pp = sint.bit_compose(z_pp)
|
||||
beta1 = z_pp.less_than(d_shared, l)
|
||||
beta2 = z_shared - sint.bit_compose(z_prime) * d_shared - z_pp
|
||||
library.runtime_error_if(beta1.reveal() != 1, 'private div')
|
||||
library.runtime_error_if(beta2.reveal() != 0, 'private div')
|
||||
y_prime = sint.bit_compose(z_prime[:l + sigma])
|
||||
y = sint.bit_compose(z_prime[l + sigma:])
|
||||
else:
|
||||
y = sint(z // (d << (l + sigma)))
|
||||
y_prime = sint((z // d) % (2 ** (l + sigma)))
|
||||
|
||||
b = r.greater_than(y_prime, l + sigma)
|
||||
w = y - b - r_prime
|
||||
@@ -3320,15 +3466,16 @@ class cfix(_number, _structure):
|
||||
|
||||
:param client_id: Client id (regint)
|
||||
:param n: number of values (default 1)
|
||||
:param: vector size (int)
|
||||
:returns: cfix (if n=1) or list of cfix
|
||||
"""
|
||||
cint_input = cint.read_from_socket(client_id, n)
|
||||
cint_inputs = cint.read_from_socket(client_id, n)
|
||||
if n == 1:
|
||||
return cfix._new(cint_inputs)
|
||||
else:
|
||||
return list(map(cfix, cint_inputs))
|
||||
|
||||
@vectorized_classmethod
|
||||
return list(map(cfix._new, cint_inputs))
|
||||
|
||||
@classmethod
|
||||
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
|
||||
""" Send a list of clear fixed-point values to a client
|
||||
(represented as clear integers).
|
||||
@@ -3336,10 +3483,12 @@ class cfix(_number, _structure):
|
||||
:param client_id: Client id (regint)
|
||||
:param values: list of cint
|
||||
"""
|
||||
for value in values:
|
||||
assert(value.size == values[0].size)
|
||||
def cfix_to_cint(fix_val):
|
||||
return cint(fix_val.v)
|
||||
cint_values = list(map(cfix_to_cint, values))
|
||||
writesocketc(client_id, message_type, *cint_values)
|
||||
writesocketc(client_id, message_type, values[0].size, *cint_values)
|
||||
|
||||
@staticmethod
|
||||
def malloc(size, creator_tape=None):
|
||||
@@ -3402,6 +3551,9 @@ class cfix(_number, _structure):
|
||||
def __len__(self):
|
||||
return len(self.v)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._new(self.v[index], k=self.k, f=self.f)
|
||||
|
||||
@vectorize
|
||||
def load_int(self, v):
|
||||
self.v = cint(v) * (2 ** self.f)
|
||||
@@ -3601,7 +3753,7 @@ class cfix(_number, _structure):
|
||||
cint(0), cint(0), cint(0))
|
||||
|
||||
def output_if(self, cond):
|
||||
cond_print_plain(cond, self.v, cint(-self.f))
|
||||
cond_print_plain(cint.conv(cond), self.v, cint(-self.f))
|
||||
|
||||
@vectorize
|
||||
def binary_output(self, player=None):
|
||||
@@ -3614,7 +3766,7 @@ class cfix(_number, _structure):
|
||||
player = -1
|
||||
floatoutput(player, self.v, cint(-self.f), cint(0), cint(0))
|
||||
|
||||
class _single(_number, _structure):
|
||||
class _single(_number, _secret_structure):
|
||||
""" Representation as single integer preserving the order """
|
||||
""" E.g. fixed-point numbers """
|
||||
__slots__ = ['v']
|
||||
@@ -3623,7 +3775,7 @@ class _single(_number, _structure):
|
||||
""" Whether to round deterministically to nearest instead of
|
||||
probabilistically, e.g. after fixed-point multiplication. """
|
||||
|
||||
@classmethod
|
||||
@vectorized_classmethod
|
||||
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
|
||||
"""
|
||||
Securely obtain shares of values input by a client. Assumes client
|
||||
@@ -3631,12 +3783,22 @@ class _single(_number, _structure):
|
||||
|
||||
:param n: number of inputs (int)
|
||||
:param client_id: regint
|
||||
|
||||
:param size: vector size (default 1)
|
||||
"""
|
||||
sint_inputs = cls.int_type.receive_from_client(n, client_id,
|
||||
message_type)
|
||||
return list(map(cls._new, sint_inputs))
|
||||
|
||||
@classmethod
|
||||
def reveal_to_clients(cls, clients, values):
|
||||
""" Reveal securely to clients.
|
||||
|
||||
:param clients: client ids (list or array)
|
||||
:param values: list of values of this class
|
||||
|
||||
"""
|
||||
cls.int_type.reveal_to_clients(clients, [x.v for x in values])
|
||||
|
||||
@vectorized_classmethod
|
||||
def write_shares_to_socket(cls, client_id, values,
|
||||
message_type=ClientMessageType.NoType):
|
||||
@@ -3926,6 +4088,9 @@ class _fix(_single):
|
||||
self.v = (1-2*_v.s)*a
|
||||
elif isinstance(_v, type(self)):
|
||||
self.v = _v.v
|
||||
elif isinstance(_v, cfix):
|
||||
assert _v.f <= self.f
|
||||
self.v = self.int_type(_v.v << (self.f - _v.f))
|
||||
elif isinstance(_v, (MemValue, MemFix)):
|
||||
#this is a memvalue object
|
||||
self.v = type(self)(_v.read()).v
|
||||
@@ -4049,9 +4214,9 @@ class _fix(_single):
|
||||
return revealed_fix._new(val)
|
||||
|
||||
class sfix(_fix):
|
||||
""" Secret fixed-point number represented as secret integer.
|
||||
This uses integer operations internally, see :py:class:`sint` for security
|
||||
considerations.
|
||||
""" Secret fixed-point number represented as secret integer, by
|
||||
multiplying with ``2^f`` and then rounding. See :py:class:`sint`
|
||||
for security considerations of the underlying integer operations.
|
||||
|
||||
It supports basic arithmetic (``+, -, *, /``), returning
|
||||
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
|
||||
@@ -4391,7 +4556,7 @@ class squant_params(object):
|
||||
shifted = under.if_else(0, shifted)
|
||||
return squant._new(shifted, params=self)
|
||||
|
||||
class sfloat(_number, _structure):
|
||||
class sfloat(_number, _secret_structure):
|
||||
"""
|
||||
Secret floating-point number.
|
||||
Represents :math:`(1 - 2s) \cdot (1 - z)\cdot v \cdot 2^p`.
|
||||
@@ -4705,6 +4870,7 @@ class sfloat(_number, _structure):
|
||||
return -self + other
|
||||
__rsub__.__doc__ = __sub__.__doc__
|
||||
|
||||
@vectorize
|
||||
def __truediv__(self, other):
|
||||
""" Secret floating-point division.
|
||||
|
||||
@@ -4857,21 +5023,36 @@ def _get_type(t):
|
||||
else:
|
||||
return t
|
||||
|
||||
class Array(object):
|
||||
class _vectorizable:
|
||||
def reveal_to_clients(self, clients):
|
||||
""" Reveal contents to list of clients.
|
||||
|
||||
:param clients: list or array of client identifiers
|
||||
|
||||
"""
|
||||
self.value_type.reveal_to_clients(clients, [self.get_vector()])
|
||||
|
||||
class Array(_vectorizable):
|
||||
"""
|
||||
Array accessible by public index. That is, ``a[i]`` works for an
|
||||
array ``a`` and ``i`` being a :py:class:`regint`,
|
||||
:py:class:`cint`, or a Python integer. ``a[start:stop:step]``
|
||||
works as well, and so does iteration over an array.
|
||||
|
||||
Arrays support a number of element-wise operations if the
|
||||
underlying basic type does so. These are ``+, -, *, **, /``. The
|
||||
return type of these is a vector, which can be assigned to an
|
||||
array of a compatible type using :py:func:`assign`.
|
||||
:py:class:`cint`, or a Python integer.
|
||||
|
||||
:param length: compile-time integer (int) or :py:obj:`None` for unknown length
|
||||
:param value_type: basic type
|
||||
:param address: if given (regint/int), the array will not be allocated
|
||||
|
||||
You can convert between arrays and register vectors by using slice
|
||||
indexing. This allows for element-wise operations as long as
|
||||
supported by the basic type. The following adds 10 secret integers
|
||||
from the first two parties::
|
||||
|
||||
a = sint.Array(10)
|
||||
a.input_from(0)
|
||||
b = sint.Array(10)
|
||||
b.input_from(1)
|
||||
a[:] += b[:]
|
||||
|
||||
"""
|
||||
@classmethod
|
||||
def create_from(cls, l):
|
||||
@@ -4900,6 +5081,7 @@ class Array(object):
|
||||
self.debug = debug
|
||||
self.creator_tape = program.curr_tape
|
||||
self.sink = None
|
||||
self.check_indices = True
|
||||
if alloc:
|
||||
self.alloc()
|
||||
|
||||
@@ -4914,11 +5096,17 @@ class Array(object):
|
||||
|
||||
def get_address(self, index):
|
||||
key = str(index)
|
||||
if isinstance(index, int) and self.length is not None:
|
||||
index += self.length * (index < 0)
|
||||
if index >= self.length or index < 0:
|
||||
raise IndexError('index %s, length %s' % \
|
||||
(str(index), str(self.length)))
|
||||
if self.length is not None:
|
||||
from .GC.types import cbits
|
||||
if isinstance(index, int):
|
||||
index += self.length * (index < 0)
|
||||
if index >= self.length or index < 0:
|
||||
raise IndexError('index %s, length %s' % \
|
||||
(str(index), str(self.length)))
|
||||
elif self.check_indices and not isinstance(index, cbits):
|
||||
library.runtime_error_if(regint.conv(index) >= self.length,
|
||||
'overflow: %s/%s',
|
||||
index, self.length)
|
||||
if (program.curr_block, key) not in self.address_cache:
|
||||
n = self.value_type.n_elements()
|
||||
length = self.length
|
||||
@@ -4937,21 +5125,24 @@ class Array(object):
|
||||
def get_slice(self, index):
|
||||
if index.stop is None and self.length is None:
|
||||
raise CompilerError('Cannot slice array of unknown length')
|
||||
return index.start or 0, index.stop or self.length, index.step or 1
|
||||
if index.step == 0:
|
||||
raise CompilerError('slice step cannot be zero')
|
||||
return index.start or 0, \
|
||||
min(index.stop or self.length, self.length), index.step or 1
|
||||
|
||||
def __getitem__(self, index):
|
||||
""" Reading from array.
|
||||
|
||||
:param index: public (regint/cint/int/slice)
|
||||
:return: array if slice is given, basic type otherwise"""
|
||||
:return: vector if slice is given, basic type otherwise"""
|
||||
if isinstance(index, slice):
|
||||
start, stop, step = self.get_slice(index)
|
||||
res_length = (stop - start - 1) // step + 1
|
||||
res = Array(res_length, self.value_type)
|
||||
@library.for_range(res_length)
|
||||
def f(i):
|
||||
res[i] = self[start+i*step]
|
||||
return res
|
||||
if step == 1:
|
||||
return self.get_vector(start, stop - start)
|
||||
else:
|
||||
res_length = (stop - start - 1) // step + 1
|
||||
addresses = regint.inc(res_length, start, step)
|
||||
return self.get_vector(addresses, res_length)
|
||||
return self._load(self.get_address(index))
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
@@ -4961,15 +5152,21 @@ class Array(object):
|
||||
:param value: convertible for relevant basic type """
|
||||
if isinstance(index, slice):
|
||||
start, stop, step = self.get_slice(index)
|
||||
value = Array.create_from(value)
|
||||
source_index = MemValue(0)
|
||||
@library.for_range(start, stop, step)
|
||||
def f(i):
|
||||
self[i] = value[source_index]
|
||||
source_index.iadd(1)
|
||||
return
|
||||
if step == 1:
|
||||
return self.assign(value, start)
|
||||
else:
|
||||
res_length = (stop - start - 1) // step + 1
|
||||
addresses = regint.inc(res_length, start, step)
|
||||
return self.assign(value, addresses)
|
||||
self._store(value, self.get_address(index))
|
||||
|
||||
def get_sub(self, start, stop=None):
|
||||
if stop is None:
|
||||
stop = start
|
||||
start = 0
|
||||
return Array(stop - start, self.value_type,
|
||||
address=self.address + start)
|
||||
|
||||
def maybe_get(self, condition, index):
|
||||
""" Return entry if condition is true.
|
||||
|
||||
@@ -4989,7 +5186,7 @@ class Array(object):
|
||||
self.sink = self.value_type.Array(
|
||||
1, address=self.value_type.malloc(1, creator_tape=program.tapes[0]))
|
||||
addresses = (condition.if_else(x, y) for x, y in
|
||||
zip(util.tuplify(self.get_address(index)),
|
||||
zip(util.tuplify(self.get_address(condition * index)),
|
||||
util.tuplify(self.sink.get_address(0))))
|
||||
self._store(value, util.untuplify(tuple(addresses)))
|
||||
|
||||
@@ -5034,10 +5231,10 @@ class Array(object):
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
other.store_in_mem(self.get_address(base))
|
||||
self.value_type.conv(other).store_in_mem(self.get_address(base))
|
||||
if len(self) != None and util.is_constant(base):
|
||||
assert len(self) >= other.size + base
|
||||
except AttributeError:
|
||||
except (AttributeError, CompilerError):
|
||||
if isinstance(other, Array):
|
||||
@library.for_range_opt(len(other))
|
||||
def _(i):
|
||||
@@ -5071,7 +5268,7 @@ class Array(object):
|
||||
|
||||
:param base: starting point (regint/cint/int)
|
||||
:param size: length (compile-time int) """
|
||||
size = size or self.length
|
||||
size = size or self.length - base
|
||||
return self.value_type.load_mem(self.get_address(base), size=size)
|
||||
|
||||
get_part_vector = get_vector
|
||||
@@ -5249,7 +5446,7 @@ sint.dynamic_array = Array
|
||||
sgf2n.dynamic_array = Array
|
||||
|
||||
|
||||
class SubMultiArray(object):
|
||||
class SubMultiArray(_vectorizable):
|
||||
""" Multidimensional array functionality. Don't construct this
|
||||
directly, use :py:class:`MultiArray` instead. """
|
||||
def __init__(self, sizes, value_type, address, index, debug=None):
|
||||
@@ -5261,6 +5458,7 @@ class SubMultiArray(object):
|
||||
self.address = None
|
||||
self.sub_cache = {}
|
||||
self.debug = debug
|
||||
self.check_indices = True
|
||||
if debug:
|
||||
library.print_ln_if(self.address + reduce(operator.mul, self.sizes) * self.value_type.n_elements() > program.allocated_mem[self.value_type.reg_type], 'AOF%d:' % len(self.sizes) + self.debug)
|
||||
|
||||
@@ -5271,11 +5469,17 @@ class SubMultiArray(object):
|
||||
:return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise"""
|
||||
if util.is_constant(index) and index >= self.sizes[0]:
|
||||
raise StopIteration
|
||||
if isinstance(index, slice) and index == slice(None):
|
||||
return self.get_vector()
|
||||
key = program.curr_block, str(index)
|
||||
if key not in self.sub_cache:
|
||||
if self.debug:
|
||||
library.print_ln_if(index >= self.sizes[0], \
|
||||
'OF%d:' % len(self.sizes) + self.debug)
|
||||
if util.is_constant(index) and \
|
||||
(index >= self.sizes[0] or index < 0):
|
||||
raise CompilerError('index out of range')
|
||||
elif self.check_indices:
|
||||
library.runtime_error_if(index >= self.sizes[0],
|
||||
'overflow: %s/%s',
|
||||
index, self.sizes)
|
||||
if len(self.sizes) == 2:
|
||||
self.sub_cache[key] = \
|
||||
Array(self.sizes[1], self.value_type, \
|
||||
@@ -5293,6 +5497,8 @@ class SubMultiArray(object):
|
||||
|
||||
:param index: public (regint/cint/int)
|
||||
:param other: container of matching size and type """
|
||||
if isinstance(index, slice) and index == slice(None):
|
||||
return self.assign(other)
|
||||
self[index].assign(other)
|
||||
|
||||
def __len__(self):
|
||||
@@ -5526,13 +5732,18 @@ class SubMultiArray(object):
|
||||
self.assign_vector(self.get_vector() + other.get_vector())
|
||||
|
||||
def __mul__(self, other):
|
||||
# legacy function
|
||||
return self.mul(other)
|
||||
|
||||
def mul(self, other, res_params=None):
|
||||
# legacy function
|
||||
return self.dot(other, res_params)
|
||||
|
||||
def dot(self, other, res_params=None):
|
||||
""" Matrix-matrix and matrix-vector multiplication.
|
||||
|
||||
:param self: two-dimensional
|
||||
:param other: Matrix or Array of matching size and type """
|
||||
return self.mul(other)
|
||||
|
||||
def mul(self, other, res_params=None):
|
||||
assert len(self.sizes) == 2
|
||||
if isinstance(other, Array):
|
||||
assert len(other) == self.sizes[1]
|
||||
@@ -5762,11 +5973,16 @@ class SubMultiArray(object):
|
||||
assert len(self.sizes) == 2
|
||||
res = Matrix(self.sizes[1], self.sizes[0], self.value_type)
|
||||
library.break_point()
|
||||
@library.for_range_opt(self.sizes[1])
|
||||
def _(i):
|
||||
if self.value_type.n_elements() == 1:
|
||||
@library.for_range_opt(self.sizes[0])
|
||||
def _(j):
|
||||
res[i][j] = self[j][i]
|
||||
res.set_column(j, self[j][:])
|
||||
else:
|
||||
@library.for_range_opt(self.sizes[1])
|
||||
def _(i):
|
||||
@library.for_range_opt(self.sizes[0])
|
||||
def _(j):
|
||||
res[i][j] = self[j][i]
|
||||
library.break_point()
|
||||
return res
|
||||
|
||||
@@ -5809,14 +6025,22 @@ class MultiArray(SubMultiArray):
|
||||
"""
|
||||
Multidimensional array. The access operator (``a[i]``) allows to a
|
||||
multi-dimensional array of dimension one less or a simple array
|
||||
for a two-dimensional array. Element-wise addition and subtraction
|
||||
is supported, returning a vector, which can be assigned using
|
||||
:py:func:`assign`. Matrix-vector and matrix-matrix multiplication
|
||||
is supported as well.
|
||||
for a two-dimensional array.
|
||||
|
||||
:param sizes: shape (compile-time list of integers)
|
||||
:param value_type: basic type of entries
|
||||
|
||||
You can convert between arrays and register vectors by using slice
|
||||
indexing. This allows for element-wise operations as long as
|
||||
supported by the basic type. The following has the first two parties
|
||||
input a 10x10 secret integer matrix followed by storing the
|
||||
element-wise multiplications in the same data structure::
|
||||
|
||||
a = sint.Tensor([3, 10, 10])
|
||||
a[0].input_from(0)
|
||||
a[1].input_from(1)
|
||||
a[2][:] = a[0][:] * a[1][:]
|
||||
|
||||
"""
|
||||
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
|
||||
if isinstance(address, Array):
|
||||
|
||||
@@ -35,7 +35,7 @@ int main(int argc, const char** argv)
|
||||
int n_tuples = 1000;
|
||||
if (not opt.lastArgs.empty())
|
||||
n_tuples = atoi(opt.lastArgs[0]->c_str());
|
||||
PlainPlayer P(N);
|
||||
PlainPlayer P(N, "ecdsa");
|
||||
P256Element::init();
|
||||
|
||||
P256Element::Scalar keyp;
|
||||
|
||||
@@ -43,7 +43,7 @@ void run(int argc, const char** argv)
|
||||
int n_tuples = 1000;
|
||||
if (not opt.lastArgs.empty())
|
||||
n_tuples = atoi(opt.lastArgs[0]->c_str());
|
||||
CryptoPlayer P(N);
|
||||
CryptoPlayer P(N, "ecdsa");
|
||||
P256Element::init();
|
||||
typedef T<P256Element::Scalar> pShare;
|
||||
OnlineOptions::singleton.batch_size = 1;
|
||||
|
||||
@@ -88,7 +88,7 @@ void run(int argc, const char** argv)
|
||||
int n_tuples = 1000;
|
||||
if (not opt.lastArgs.empty())
|
||||
n_tuples = atoi(opt.lastArgs[0]->c_str());
|
||||
PlainPlayer P(N);
|
||||
PlainPlayer P(N, "ecdsa");
|
||||
P256Element::init();
|
||||
P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false);
|
||||
|
||||
|
||||
31
ExternalIO/Client.h
Normal file
31
ExternalIO/Client.h
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* Client.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef EXTERNALIO_CLIENT_H_
|
||||
#define EXTERNALIO_CLIENT_H_
|
||||
|
||||
#include "Networking/ssl_sockets.h"
|
||||
|
||||
class Client
|
||||
{
|
||||
vector<int> plain_sockets;
|
||||
ssl_ctx ctx;
|
||||
ssl_service io_service;
|
||||
|
||||
public:
|
||||
vector<ssl_socket*> sockets;
|
||||
octetStream specification;
|
||||
|
||||
Client(const vector<string>& hostnames, int port_base, int my_client_id);
|
||||
~Client();
|
||||
|
||||
template<class T>
|
||||
void send_private_inputs(const vector<T>& values);
|
||||
|
||||
template<class T>
|
||||
vector<T> receive_outputs(int n);
|
||||
};
|
||||
|
||||
#endif /* EXTERNALIO_CLIENT_H_ */
|
||||
126
ExternalIO/Client.hpp
Normal file
126
ExternalIO/Client.hpp
Normal file
@@ -0,0 +1,126 @@
|
||||
/*
|
||||
* Client.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Client.h"
|
||||
|
||||
inline
|
||||
Client::Client(const vector<string>& hostnames, int port_base,
|
||||
int my_client_id) :
|
||||
ctx("C" + to_string(my_client_id))
|
||||
{
|
||||
bigint::init_thread();
|
||||
|
||||
// Setup connections from this client to each party socket
|
||||
int nparties = hostnames.size();
|
||||
plain_sockets.resize(nparties);
|
||||
sockets.resize(nparties);
|
||||
for (int i = 0; i < nparties; i++)
|
||||
{
|
||||
set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i);
|
||||
octetStream(to_string(my_client_id)).Send(plain_sockets[i]);
|
||||
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
|
||||
"P" + to_string(i), "C" + to_string(my_client_id), true);
|
||||
if (i == 0)
|
||||
specification.Receive(sockets[0]);
|
||||
}
|
||||
}
|
||||
|
||||
inline
|
||||
Client::~Client()
|
||||
{
|
||||
for (auto& socket : sockets)
|
||||
{
|
||||
delete socket;
|
||||
}
|
||||
}
|
||||
|
||||
// Send the private inputs masked with a random value.
|
||||
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
|
||||
// Add the private input value to triple[0] and send to each spdz engine.
|
||||
template<class T>
|
||||
void Client::send_private_inputs(const vector<T>& values)
|
||||
{
|
||||
int num_inputs = values.size();
|
||||
octetStream os;
|
||||
vector< vector<T> > triples(num_inputs, vector<T>(3));
|
||||
vector<T> triple_shares(3);
|
||||
|
||||
// Receive num_inputs triples from SPDZ
|
||||
for (size_t j = 0; j < sockets.size(); j++)
|
||||
{
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[j]);
|
||||
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "received " << os.get_length() << " from " << j << endl;
|
||||
#endif
|
||||
|
||||
for (int j = 0; j < num_inputs; j++)
|
||||
{
|
||||
for (int k = 0; k < 3; k++)
|
||||
{
|
||||
triple_shares[k].unpack(os);
|
||||
triples[j][k] += triple_shares[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check triple relations (is a party cheating?)
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
|
||||
{
|
||||
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
throw mac_fail();
|
||||
}
|
||||
}
|
||||
// Send inputs + triple[0], so SPDZ can compute shares of each value
|
||||
os.reset_write_head();
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
T y = values[i] + triples[i][0];
|
||||
y.pack(os);
|
||||
}
|
||||
|
||||
for (auto& socket : sockets)
|
||||
os.Send(socket);
|
||||
}
|
||||
|
||||
// Receive shares of the result and sum together.
|
||||
// Also receive authenticating values.
|
||||
template<class T>
|
||||
vector<T> Client::receive_outputs(int n)
|
||||
{
|
||||
vector<T> triples(3 * n);
|
||||
octetStream os;
|
||||
for (auto& socket : sockets)
|
||||
{
|
||||
os.reset_write_head();
|
||||
os.Receive(socket);
|
||||
#ifdef VERBOSE_COMM
|
||||
cout << "received " << os.get_length() << endl;
|
||||
#endif
|
||||
for (int j = 0; j < 3 * n; j++)
|
||||
{
|
||||
T value;
|
||||
value.unpack(os);
|
||||
triples[j] += value;
|
||||
}
|
||||
}
|
||||
|
||||
vector<T> output_values;
|
||||
for (int i = 0; i < 3 * n; i += 3)
|
||||
{
|
||||
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
throw mac_fail();
|
||||
}
|
||||
output_values.push_back(triples[i]);
|
||||
}
|
||||
|
||||
return output_values;
|
||||
}
|
||||
@@ -39,111 +39,33 @@
|
||||
#include "Protocols/fake-stuff.h"
|
||||
|
||||
#include "Math/gfp.hpp"
|
||||
#include "Client.hpp"
|
||||
|
||||
#include <sodium.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
// Send the private inputs masked with a random value.
|
||||
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
|
||||
// Add the private input value to triple[0] and send to each spdz engine.
|
||||
template<class T>
|
||||
void send_private_inputs(const vector<T>& values, vector<ssl_socket*>& sockets, int nparties)
|
||||
{
|
||||
int num_inputs = values.size();
|
||||
octetStream os;
|
||||
vector< vector<T> > triples(num_inputs, vector<T>(3));
|
||||
vector<T> triple_shares(3);
|
||||
|
||||
// Receive num_inputs triples from SPDZ
|
||||
for (int j = 0; j < nparties; j++)
|
||||
{
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[j]);
|
||||
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "received " << os.get_length() << " from " << j << endl;
|
||||
#endif
|
||||
|
||||
for (int j = 0; j < num_inputs; j++)
|
||||
{
|
||||
for (int k = 0; k < 3; k++)
|
||||
{
|
||||
triple_shares[k].unpack(os);
|
||||
triples[j][k] += triple_shares[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check triple relations (is a party cheating?)
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
|
||||
{
|
||||
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
throw mac_fail();
|
||||
}
|
||||
}
|
||||
// Send inputs + triple[0], so SPDZ can compute shares of each value
|
||||
os.reset_write_head();
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
T y = values[i] + triples[i][0];
|
||||
y.pack(os);
|
||||
}
|
||||
for (int j = 0; j < nparties; j++)
|
||||
os.Send(sockets[j]);
|
||||
}
|
||||
|
||||
// Receive shares of the result and sum together.
|
||||
// Also receive authenticating values.
|
||||
template<class T>
|
||||
T receive_result(vector<ssl_socket*>& sockets, int nparties)
|
||||
{
|
||||
vector<T> output_values(3);
|
||||
octetStream os;
|
||||
for (int i = 0; i < nparties; i++)
|
||||
{
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[i]);
|
||||
for (unsigned int j = 0; j < 3; j++)
|
||||
{
|
||||
T value;
|
||||
value.unpack(os);
|
||||
output_values[j] += value;
|
||||
}
|
||||
}
|
||||
|
||||
if (T(output_values[0] * output_values[1]) != output_values[2])
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
throw mac_fail();
|
||||
}
|
||||
return output_values[0];
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void one_run(T salary_value, vector<ssl_socket*>& sockets, int nparties)
|
||||
void one_run(T salary_value, Client& client)
|
||||
{
|
||||
// Run the computation
|
||||
send_private_inputs<T>({salary_value}, sockets, nparties);
|
||||
client.send_private_inputs<T>({salary_value});
|
||||
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
|
||||
|
||||
// Get the result back (client_id of winning client)
|
||||
T result = receive_result<T>(sockets, nparties);
|
||||
T result = client.receive_outputs<T>(1)[0];
|
||||
|
||||
cout << "Winning client id is : " << result << endl;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void run(double salary_value, vector<ssl_socket*>& sockets, int nparties)
|
||||
void run(double salary_value, Client& client)
|
||||
{
|
||||
// sint
|
||||
one_run<T>(long(round(salary_value)), sockets, nparties);
|
||||
one_run<T>(long(round(salary_value)), client);
|
||||
// sfix with f = 16
|
||||
one_run<T>(long(round(salary_value * exp2(16))), sockets, nparties);
|
||||
one_run<T>(long(round(salary_value * exp2(16))), client);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
@@ -165,7 +87,7 @@ int main(int argc, char** argv)
|
||||
nparties = atoi(argv[2]);
|
||||
salary_value = atof(argv[3]);
|
||||
finish = atoi(argv[4]);
|
||||
vector<const char*> hostnames(nparties, "localhost");
|
||||
vector<string> hostnames(nparties, "localhost");
|
||||
|
||||
if (argc > 5)
|
||||
{
|
||||
@@ -185,19 +107,11 @@ int main(int argc, char** argv)
|
||||
bigint::init_thread();
|
||||
|
||||
// Setup connections from this client to each party socket
|
||||
vector<int> plain_sockets(nparties);
|
||||
vector<ssl_socket*> sockets(nparties);
|
||||
ssl_ctx ctx("C" + to_string(my_client_id));
|
||||
ssl_service io_service;
|
||||
octetStream specification;
|
||||
Client client(hostnames, port_base, my_client_id);
|
||||
auto& specification = client.specification;
|
||||
auto& sockets = client.sockets;
|
||||
for (int i = 0; i < nparties; i++)
|
||||
{
|
||||
set_up_client_socket(plain_sockets[i], hostnames[i], port_base + i);
|
||||
send(plain_sockets[i], (octet*) &my_client_id, sizeof(int));
|
||||
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
|
||||
"P" + to_string(i), "C" + to_string(my_client_id), true);
|
||||
if (i == 0)
|
||||
specification.Receive(sockets[0]);
|
||||
octetStream os;
|
||||
os.store(finish);
|
||||
os.Send(sockets[i]);
|
||||
@@ -211,7 +125,7 @@ int main(int argc, char** argv)
|
||||
{
|
||||
gfp::init_field(specification.get<bigint>());
|
||||
cerr << "using prime " << gfp::pr() << endl;
|
||||
run<gfp>(salary_value, sockets, nparties);
|
||||
run<gfp>(salary_value, client);
|
||||
break;
|
||||
}
|
||||
case 'R':
|
||||
@@ -220,13 +134,13 @@ int main(int argc, char** argv)
|
||||
switch (R)
|
||||
{
|
||||
case 64:
|
||||
run<Z2<64>>(salary_value, sockets, nparties);
|
||||
run<Z2<64>>(salary_value, client);
|
||||
break;
|
||||
case 104:
|
||||
run<Z2<104>>(salary_value, sockets, nparties);
|
||||
run<Z2<104>>(salary_value, client);
|
||||
break;
|
||||
case 128:
|
||||
run<Z2<128>>(salary_value, sockets, nparties);
|
||||
run<Z2<128>>(salary_value, client);
|
||||
break;
|
||||
default:
|
||||
cerr << R << "-bit ring not implemented";
|
||||
@@ -239,8 +153,5 @@ int main(int argc, char** argv)
|
||||
exit(1);
|
||||
}
|
||||
|
||||
for (int i = 0; i < nparties; i++)
|
||||
delete sockets[i];
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -63,6 +63,16 @@ public:
|
||||
return *this;
|
||||
}
|
||||
|
||||
AddableVector<T> operator-(const AddableVector<T>& y) const
|
||||
{
|
||||
if (this->size() != y.size())
|
||||
throw out_of_range("vector length mismatch");
|
||||
AddableVector<T> res(y.size());
|
||||
for (unsigned int i = 0; i < this->size(); i++)
|
||||
res[i] = (*this)[i] - y[i];
|
||||
return res;
|
||||
}
|
||||
|
||||
void mul(const AddableVector<T>& x, const AddableVector<T>& y)
|
||||
{
|
||||
if (x.size() != y.size())
|
||||
|
||||
71
FHE/Diagonalizer.cpp
Normal file
71
FHE/Diagonalizer.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
/*
|
||||
* Diagonalizer.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Diagonalizer.h"
|
||||
|
||||
Diagonalizer::Diagonalizer(const MatrixVector& matrices,
|
||||
const FFT_Data& FTD, const FHE_PK& pk) :
|
||||
FTD(FTD)
|
||||
{
|
||||
assert(not matrices.empty());
|
||||
for (auto& matrix : matrices)
|
||||
{
|
||||
assert(matrix.n_cols == matrices[0].n_cols);
|
||||
assert(matrix.n_rows == matrices[0].n_rows);
|
||||
}
|
||||
|
||||
n_rows = matrices[0].n_rows;
|
||||
n_cols = matrices[0].n_cols;
|
||||
assert(n_rows * matrices.size() <= size_t(FTD.num_slots()));
|
||||
for (size_t i = 0; i < n_cols; i++)
|
||||
{
|
||||
Plaintext_<FFT_Data> plaintext(FTD, Evaluation);
|
||||
for (size_t k = 0; k < matrices.size(); k++)
|
||||
{
|
||||
for (size_t j = 0; j < n_rows; j++)
|
||||
{
|
||||
auto entry = matrices.at(k)[{j, (j + i) % n_cols}];
|
||||
plaintext.set_element(k * n_rows + j, entry);
|
||||
}
|
||||
}
|
||||
ciphertexts.push_back(pk.encrypt(plaintext));
|
||||
}
|
||||
}
|
||||
|
||||
Plaintext_<FFT_Data> Diagonalizer::get_plaintext(
|
||||
const MatrixVector& matrices, int left_col,
|
||||
int right_col)
|
||||
{
|
||||
Plaintext_<FFT_Data> plaintext(FTD, Evaluation);
|
||||
for (size_t k = 0; k < matrices.size(); k++)
|
||||
for (size_t j = 0; j < n_rows; j++)
|
||||
plaintext.set_element(k * n_rows + j,
|
||||
matrices.at(k)[{(left_col + j) % n_cols, right_col}]);
|
||||
return plaintext;
|
||||
}
|
||||
|
||||
Diagonalizer::MatrixVector Diagonalizer::decrypt(
|
||||
const vector<Ciphertext>& products, int n_matrices, FHE_SK& sk)
|
||||
{
|
||||
vector<Plaintext_<FFT_Data>> plaintexts;
|
||||
for (auto& x : products)
|
||||
plaintexts.push_back(sk.decrypt(x, FTD));
|
||||
return dediag(plaintexts, n_matrices);
|
||||
}
|
||||
|
||||
Diagonalizer::MatrixVector Diagonalizer::dediag(
|
||||
const vector<Plaintext_<FFT_Data>>& products, int n_matrices)
|
||||
{
|
||||
int n_cols_out = products.size();
|
||||
MatrixVector res(n_matrices, {int(n_rows), n_cols_out});
|
||||
for (int i = 0; i < n_cols_out; i++)
|
||||
{
|
||||
auto& c = products.at(i);
|
||||
for (int j = 0; j < n_matrices; j++)
|
||||
for (size_t k = 0; k < n_rows; k++)
|
||||
res.at(j)[{k, i}] = c.element(j * n_rows + k);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
36
FHE/Diagonalizer.h
Normal file
36
FHE/Diagonalizer.h
Normal file
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Diagonalizer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef FHE_DIAGONALIZER_H_
|
||||
#define FHE_DIAGONALIZER_H_
|
||||
|
||||
#include "Math/gfpvar.h"
|
||||
#include "Ciphertext.h"
|
||||
#include "Protocols/ShareMatrix.h"
|
||||
|
||||
class Diagonalizer
|
||||
{
|
||||
const FFT_Data& FTD;
|
||||
|
||||
size_t n_rows, n_cols;
|
||||
|
||||
public:
|
||||
typedef AddableVector<ValueMatrix<gfpvar>> MatrixVector;
|
||||
|
||||
vector<Ciphertext> ciphertexts;
|
||||
|
||||
Diagonalizer(const MatrixVector& matrices,
|
||||
const FFT_Data& FTD, const FHE_PK& pk);
|
||||
|
||||
Plaintext_<FFT_Data> get_plaintext(const MatrixVector& matrices,
|
||||
int left_col, int right_col);
|
||||
|
||||
MatrixVector decrypt(const vector<Ciphertext>&, int n_matrices, FHE_SK& sk);
|
||||
|
||||
MatrixVector dediag(const vector<Plaintext_<FFT_Data>>& plaintexts,
|
||||
int n_matrices);
|
||||
};
|
||||
|
||||
#endif /* FHE_DIAGONALIZER_H_ */
|
||||
@@ -185,7 +185,8 @@ void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
|
||||
int start = queues.distribute(job, n / 2);
|
||||
for (int i = start; i < n / 2; i++)
|
||||
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
|
||||
queues.wrap_up(job);
|
||||
if (start > 0)
|
||||
queues.wrap_up(job);
|
||||
}
|
||||
else
|
||||
for (int i = 0; i < n / 2; i++)
|
||||
|
||||
@@ -359,11 +359,19 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
|
||||
|
||||
|
||||
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<P2Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess) const;
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
|
||||
|
||||
template void FHE_SK::decrypt(Plaintext_<FFT_Data>&, const Ciphertext& c) const;
|
||||
template void FHE_SK::decrypt(Plaintext_<P2Data>&, const Ciphertext& c) const;
|
||||
|
||||
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
|
||||
const FFT_Data& FieldD);
|
||||
template Plaintext_<P2Data> FHE_SK::decrypt(const Ciphertext& c,
|
||||
|
||||
@@ -44,6 +44,20 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
|
||||
template <class FD>
|
||||
void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
|
||||
const Ciphertext& enc_a, const Rq_Element& b, OT_ROLE role)
|
||||
{
|
||||
if (role & SENDER)
|
||||
{
|
||||
timers["Ciphertext multiplication"].start();
|
||||
C.mul(enc_a, b);
|
||||
timers["Ciphertext multiplication"].stop();
|
||||
}
|
||||
|
||||
add(res, C, role);
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
|
||||
OT_ROLE role, int n_summands)
|
||||
{
|
||||
o.reset_write_head();
|
||||
|
||||
@@ -51,9 +65,6 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
timers["Ciphertext multiplication"].start();
|
||||
C.mul(enc_a, b);
|
||||
timers["Ciphertext multiplication"].stop();
|
||||
timers["Mask randomization"].start();
|
||||
product_share.randomize(G);
|
||||
bigint B = 6 * machine.setup<FD>().params.get_R();
|
||||
@@ -63,12 +74,13 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
|
||||
B *= NonInteractiveProof::slack(machine.sec,
|
||||
machine.setup<FD>().params.phi_m());
|
||||
B <<= machine.extra_slack;
|
||||
B *= n_summands;
|
||||
rc.generateUniform(G, 0, B, B);
|
||||
timers["Mask randomization"].stop();
|
||||
timers["Encryption"].start();
|
||||
other_pk.encrypt(mask, product_share, rc);
|
||||
timers["Encryption"].stop();
|
||||
mask += C;
|
||||
mask += c;
|
||||
mask.pack(o);
|
||||
res -= product_share;
|
||||
}
|
||||
|
||||
@@ -47,6 +47,8 @@ public:
|
||||
const Plaintext_<FD>& b);
|
||||
void multiply_and_add(Plaintext_<FD>& res, const Ciphertext& C,
|
||||
const Rq_Element& b, OT_ROLE role = BOTH);
|
||||
void add(Plaintext_<FD>& res, const Ciphertext& C, OT_ROLE role = BOTH,
|
||||
int n_summands = 1);
|
||||
void multiply_alpha_and_add(Plaintext_<FD>& res, const Rq_Element& b,
|
||||
OT_ROLE role = BOTH);
|
||||
int get_offset() { return P.get_offset(); }
|
||||
|
||||
@@ -17,7 +17,7 @@ PairwiseMachine::PairwiseMachine(Player& P) :
|
||||
}
|
||||
|
||||
PairwiseMachine::PairwiseMachine(int argc, const char** argv) :
|
||||
MachineBase(argc, argv), P(*new PlainPlayer(N, 0xffff << 16)),
|
||||
MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")),
|
||||
other_pks(N.num_players(), {setup_p.params, 0}),
|
||||
pk(other_pks[N.my_num()]), sk(pk)
|
||||
{
|
||||
|
||||
@@ -31,7 +31,7 @@ public:
|
||||
map<string, Timer> timers;
|
||||
|
||||
GeneratorBase(int thread_num, const Names& N, Player* player = 0) :
|
||||
player(player ? 0 : new PlainPlayer(N, thread_num << 16)),
|
||||
player(player ? 0 : new PlainPlayer(N, to_string(thread_num))),
|
||||
thread_num(thread_num),
|
||||
P(player ? *player : *this->player), thread(0), total(0)
|
||||
{
|
||||
|
||||
@@ -216,7 +216,7 @@ void MultiplicativeMachine::generate_setup(int slack)
|
||||
template <class FD>
|
||||
void MultiplicativeMachine::fake_keys(int slack)
|
||||
{
|
||||
PlainPlayer P(N, -1 * N.num_players() * N.num_players());
|
||||
PlainPlayer P(N, "fake");
|
||||
octetStream os;
|
||||
PartSetup<FD>& part_setup = setup.part<FD>();
|
||||
if (P.my_num() == 0)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "AtlasSecret.h"
|
||||
#include "TinyMC.h"
|
||||
|
||||
#include "Protocols/Shamir.hpp"
|
||||
#include "Protocols/ShamirMC.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Secret.hpp"
|
||||
|
||||
@@ -92,9 +92,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
size_t data_sent()
|
||||
NamedCommStats comm_stats()
|
||||
{
|
||||
return part_prep.data_sent();
|
||||
return part_prep.comm_stats();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -108,7 +108,8 @@ public:
|
||||
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
|
||||
static FakeSecret input(int from, word input, int n_bits);
|
||||
|
||||
static FakeSecret constant(clear value, int = 0, mac_key_type = {}) { return value; }
|
||||
static FakeSecret constant(clear value, int = 0, mac_key_type = {}, int = -1)
|
||||
{ return value; }
|
||||
|
||||
FakeSecret() {}
|
||||
template <class T>
|
||||
|
||||
@@ -68,6 +68,7 @@ enum
|
||||
CLEAR_WRITE = 0x210,
|
||||
XORCBI = 0x210,
|
||||
BITDECC = 0x211,
|
||||
NOTCB = 0x212,
|
||||
CONVCINT = 0x213,
|
||||
REVEAL = 0x214,
|
||||
STMSDCI = 0x215,
|
||||
|
||||
@@ -84,12 +84,6 @@ public:
|
||||
return new HashMaliciousRepMC<U>;
|
||||
}
|
||||
|
||||
static U constant(const BitVec& other, int my_num, const BitVec& alphai)
|
||||
{
|
||||
(void) my_num, (void) alphai;
|
||||
return other;
|
||||
}
|
||||
|
||||
MalRepSecretBase() {}
|
||||
template<class T>
|
||||
MalRepSecretBase(const T& other) : super(other) {}
|
||||
|
||||
@@ -143,7 +143,7 @@ public:
|
||||
|
||||
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
|
||||
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
|
||||
static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; }
|
||||
|
||||
NoShare() {}
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ public:
|
||||
void xors(const vector<int>& args, size_t start, size_t end);
|
||||
void xorc(const ::BaseInstruction& instruction);
|
||||
void nots(const ::BaseInstruction& instruction);
|
||||
void notcb(const ::BaseInstruction& instruction);
|
||||
void andm(const ::BaseInstruction& instruction);
|
||||
void and_(const vector<int>& args, bool repeat);
|
||||
void andrs(const vector<int>& args) { and_(args, true); }
|
||||
|
||||
@@ -257,6 +257,19 @@ void Processor<T>::nots(const ::BaseInstruction& instruction)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::notcb(const ::BaseInstruction& instruction)
|
||||
{
|
||||
int total = instruction.get_n();
|
||||
int unit = Clear::N_BITS;
|
||||
for (int i = 0; i < DIV_CEIL(total, unit); i++)
|
||||
{
|
||||
int n = min(unit, total - i * unit);
|
||||
C[instruction.get_r(0) + i] =
|
||||
Clear(~C[instruction.get_r(1) + i].get()).mask(n);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::andm(const ::BaseInstruction& instruction)
|
||||
{
|
||||
|
||||
@@ -30,7 +30,7 @@ public:
|
||||
static MC* new_mc(typename super::mac_key_type) { return new MC; }
|
||||
|
||||
static This constant(const typename super::clear& constant, int my_num,
|
||||
typename super::mac_key_type = {})
|
||||
typename super::mac_key_type = {}, int = -1)
|
||||
{
|
||||
return Rep4Share<typename super::clear>::constant(constant, my_num);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "Protocols/ReplicatedPrep.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Replicated.hpp"
|
||||
#include "OT/NPartyTripleGenerator.hpp"
|
||||
|
||||
namespace GC
|
||||
@@ -65,12 +66,12 @@ void SemiPrep::buffer_bits()
|
||||
}
|
||||
}
|
||||
|
||||
size_t SemiPrep::data_sent()
|
||||
NamedCommStats SemiPrep::comm_stats()
|
||||
{
|
||||
if (triple_generator)
|
||||
return triple_generator->data_sent();
|
||||
return triple_generator->comm_stats();
|
||||
else
|
||||
return 0;
|
||||
return {};
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -53,7 +53,7 @@ public:
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
size_t data_sent();
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -102,16 +102,16 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
if (not this->machine.use_encryption and not T::dishonest_majority)
|
||||
insecure("unencrypted communication");
|
||||
|
||||
Server* server = network_opts.start_networking(this->N, my_num);
|
||||
network_opts.start_networking(this->N, my_num);
|
||||
|
||||
if (online_opts.live_prep)
|
||||
if (T::needs_ot)
|
||||
{
|
||||
Player* P;
|
||||
if (this->machine.use_encryption)
|
||||
P = new CryptoPlayer(this->N, 0xFFFF);
|
||||
P = new CryptoPlayer(this->N, "shareparty");
|
||||
else
|
||||
P = new PlainPlayer(this->N, 0xFFFF);
|
||||
P = new PlainPlayer(this->N, "shareparty");
|
||||
for (int i = 0; i < this->machine.nthreads; i++)
|
||||
this->machine.ot_setups.push_back({*P, true});
|
||||
delete P;
|
||||
@@ -133,9 +133,6 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
this->run();
|
||||
|
||||
this->machine.write_memory(this->N.my_num());
|
||||
|
||||
if (server)
|
||||
delete server;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -171,8 +171,8 @@ class ReplicatedSecret : public RepSecretBase<U, 2>
|
||||
public:
|
||||
typedef ReplicatedBase Protocol;
|
||||
|
||||
static ReplicatedSecret constant(const typename super::clear& value, int my_num,
|
||||
typename super::mac_key_type)
|
||||
static ReplicatedSecret constant(const typename super::clear& value,
|
||||
int my_num, typename super::mac_key_type, int = -1)
|
||||
{
|
||||
ReplicatedSecret res;
|
||||
if (my_num < 2)
|
||||
|
||||
@@ -58,8 +58,8 @@ public:
|
||||
void pre_run();
|
||||
void post_run() { ShareThread<T>::post_run(); }
|
||||
|
||||
size_t data_sent()
|
||||
{ return Thread<T>::data_sent() + this->DataF.data_sent(); }
|
||||
NamedCommStats comm_stats()
|
||||
{ return Thread<T>::comm_stats() + this->DataF.comm_stats(); }
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -56,7 +56,7 @@ public:
|
||||
void join_tape();
|
||||
void finish();
|
||||
|
||||
virtual size_t data_sent();
|
||||
virtual NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -51,10 +51,11 @@ void Thread<T>::run()
|
||||
singleton = this;
|
||||
BaseMachine::s().thread_num = thread_num;
|
||||
secure_prng.ReSeed();
|
||||
string id = "T" + to_string(thread_num);
|
||||
if (machine.use_encryption)
|
||||
P = new CryptoPlayer(N, thread_num << 16);
|
||||
P = new CryptoPlayer(N, id);
|
||||
else
|
||||
P = new PlainPlayer(N, thread_num << 16);
|
||||
P = new PlainPlayer(N, id);
|
||||
processor.open_input_file(N.my_num(), thread_num,
|
||||
master.opts.cmd_private_input_file);
|
||||
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
|
||||
@@ -98,10 +99,10 @@ void Thread<T>::finish()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t GC::Thread<T>::data_sent()
|
||||
NamedCommStats Thread<T>::comm_stats()
|
||||
{
|
||||
assert(P);
|
||||
return P->comm_stats.total_data();
|
||||
return P->comm_stats;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -58,7 +58,7 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
|
||||
template<class T>
|
||||
void ThreadMaster<T>::run()
|
||||
{
|
||||
P = new PlainPlayer(N, 0xff << 24);
|
||||
P = new PlainPlayer(N, "main");
|
||||
|
||||
machine.load_schedule(progname);
|
||||
|
||||
@@ -87,12 +87,10 @@ void ThreadMaster<T>::run()
|
||||
|
||||
NamedCommStats stats = P->comm_stats;
|
||||
ExecutionStats exe_stats;
|
||||
size_t data_sent = P->comm_stats.total_data();
|
||||
for (auto thread : threads)
|
||||
{
|
||||
stats += thread->P->comm_stats;
|
||||
exe_stats += thread->processor.stats;
|
||||
data_sent += thread->data_sent();
|
||||
delete thread;
|
||||
}
|
||||
|
||||
@@ -102,7 +100,7 @@ void ThreadMaster<T>::run()
|
||||
stats.print();
|
||||
|
||||
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
|
||||
cerr << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
|
||||
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -48,7 +48,7 @@ public:
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol);
|
||||
|
||||
size_t data_sent();
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -92,13 +92,13 @@ void GC::TinierSharePrep<T>::buffer_bits()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
size_t TinierSharePrep<T>::data_sent()
|
||||
NamedCommStats TinierSharePrep<T>::comm_stats()
|
||||
{
|
||||
size_t res = 0;
|
||||
NamedCommStats res;
|
||||
if (triple_generator)
|
||||
res += triple_generator->data_sent();
|
||||
res += triple_generator->comm_stats();
|
||||
if (real_triple_generator)
|
||||
res += real_triple_generator->data_sent();
|
||||
res += real_triple_generator->comm_stats();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -70,11 +70,14 @@ public:
|
||||
T::reveal_inst(processor, args);
|
||||
}
|
||||
|
||||
static This constant(BitVec other, int my_num, mac_key_type alphai)
|
||||
static This constant(BitVec other, int my_num, mac_key_type alphai,
|
||||
int n_bits = -1)
|
||||
{
|
||||
if (n_bits < 0)
|
||||
n_bits = other.length();
|
||||
This res;
|
||||
res.resize_regs(other.length());
|
||||
for (int i = 0; i < other.length(); i++)
|
||||
res.resize_regs(n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
res.get_reg(i) = part_type::constant(other.get_bit(i), my_num, alphai);
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
X(XORCB, processor.xorc(instruction)) \
|
||||
X(XORCBI, C0.xor_(PC1, IMM)) \
|
||||
X(NOTS, processor.nots(INST)) \
|
||||
X(NOTCB, processor.notcb(INST)) \
|
||||
X(ANDRS, T::andrs(PROC, EXTRA)) \
|
||||
X(ANDS, T::ands(PROC, EXTRA)) \
|
||||
X(ADDCB, C0 = PC1 + PC2) \
|
||||
@@ -140,6 +141,7 @@
|
||||
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
|
||||
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
|
||||
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
|
||||
X(CRASH, if (I0.get()) throw crash_requested()) \
|
||||
|
||||
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS
|
||||
|
||||
|
||||
@@ -226,7 +226,7 @@ OTMachine::OTMachine(int argc, const char** argv)
|
||||
N.push_back(new Names(my_num, portnum_base + 1000 * N.size(), names));
|
||||
}
|
||||
|
||||
P = new RealTwoPartyPlayer(*N[0], 1 - my_num, 500);
|
||||
P = new RealTwoPartyPlayer(*N[0], 1 - my_num, "machine");
|
||||
|
||||
timeval baseOTstart, baseOTend;
|
||||
gettimeofday(&baseOTstart, NULL);
|
||||
@@ -319,7 +319,8 @@ void OTMachine::run()
|
||||
}
|
||||
// now setup resources for each thread
|
||||
// round robin with the names
|
||||
players[i] = new RealTwoPartyPlayer(*N[i%N.size()], 1 - my_num, (i+1) * 1000);
|
||||
players[i] = new RealTwoPartyPlayer(*N[i % N.size()], 1 - my_num,
|
||||
"thread" + to_string(i));
|
||||
tinfos[i].thread_num = i+1;
|
||||
tinfos[i].other_player_num = 1 - my_num;
|
||||
tinfos[i].nOTs = nOTs;
|
||||
|
||||
@@ -138,7 +138,10 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
|
||||
opt.get("-S")->getInt(z2s);
|
||||
|
||||
// doesn't work with Montgomery multiplication
|
||||
gfpvar1::init_field(prime, false);
|
||||
if (prime)
|
||||
gfpvar1::init_field(prime, false);
|
||||
else
|
||||
gfpvar1::init_default(128, false);
|
||||
gf2n_long::init_field(128);
|
||||
gf2n_short::init_field(40);
|
||||
|
||||
@@ -175,7 +178,7 @@ void TripleMachine::run()
|
||||
nConnections = 2;
|
||||
}
|
||||
// do the base OTs
|
||||
PlainPlayer P(N[0], 0xF000);
|
||||
PlainPlayer P(N[0], "base");
|
||||
OTTripleSetup setup(P, true);
|
||||
|
||||
vector<GeneratorThread*> generators(nthreads);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
*/
|
||||
|
||||
#include "Protocols/HemiShare.h"
|
||||
#include "Protocols/HemiOptions.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "FHE/P2Data.h"
|
||||
@@ -22,6 +23,7 @@
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/Hemi.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/SemiHonestRepPrep.h"
|
||||
#include "Math/gfp.hpp"
|
||||
@@ -29,6 +31,7 @@
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
HemiOptions::singleton = {opt, argc, argv};
|
||||
DishonestMajorityFieldMachine<HemiShare, HemiShare, gf2n_short>(argc, argv,
|
||||
opt);
|
||||
}
|
||||
|
||||
60
Makefile
60
Makefile
@@ -9,7 +9,7 @@ NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp))
|
||||
|
||||
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp))
|
||||
|
||||
FHEOFFLINE = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o
|
||||
FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o
|
||||
|
||||
GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR)
|
||||
GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
@@ -17,32 +17,38 @@ GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp))
|
||||
OT_EXE = ot.x ot-offline.x
|
||||
|
||||
COMMON = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o Processor/DataPositions.o Processor/ThreadQueues.o Processor/ThreadQueue.o
|
||||
COMMONOBJS = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o Processor/DataPositions.o Processor/ThreadQueues.o Processor/ThreadQueue.o
|
||||
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT)
|
||||
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o
|
||||
BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp))
|
||||
VM = $(PROCESSOR) $(COMMON) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
VM = $(MINI_OT) $(SHAREDLIB)
|
||||
COMMON = $(SHAREDLIB)
|
||||
|
||||
|
||||
LIB = libSPDZ.a
|
||||
SHAREDLIB = libSPDZ.so
|
||||
FHEOFFLINE = libFHE.so
|
||||
LIBRELEASE = librelease.a
|
||||
|
||||
ifeq ($(AVX_OT), 0)
|
||||
VM += ECDSA/P256Element.o
|
||||
OT += ECDSA/P256Element.o
|
||||
MINI_OT += ECDSA/P256Element.o
|
||||
else
|
||||
LIBSIMPLEOT = SimpleOT/libsimpleot.a
|
||||
endif
|
||||
|
||||
# used for dependency generation
|
||||
OBJS = $(BMR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp Utils/*.cpp))
|
||||
OBJS = $(BMR) $(FHEOBJS) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp Utils/*.cpp))
|
||||
DEPS := $(wildcard */*.d */*/*.d)
|
||||
|
||||
# never delete
|
||||
.SECONDARY: $(OBJS)
|
||||
.SECONDARY: $(OBJS) $(patsubst %.cpp,%.o,$(wildcard */*.cpp))
|
||||
|
||||
|
||||
all: arithmetic binary gen_input online offline externalIO bmr ecdsa doc
|
||||
all: arithmetic binary gen_input online offline externalIO bmr ecdsa
|
||||
vm: arithmetic binary
|
||||
|
||||
.PHONY: doc
|
||||
@@ -112,13 +118,22 @@ sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x
|
||||
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) Fake-ECDSA.x
|
||||
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
|
||||
|
||||
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMON) $(OT) $(GC)
|
||||
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(OT) $(GC)
|
||||
$(AR) -csr $@ $^
|
||||
|
||||
CFLAGS += -fPIC
|
||||
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)
|
||||
|
||||
$(SHAREDLIB): $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o
|
||||
$(CXX) $(CFLAGS) -shared -o $@ $^ $(LDLIBS)
|
||||
|
||||
$(FHEOFFLINE): $(FHEOBJS) $(SHAREDLIB)
|
||||
$(CXX) $(CFLAGS) -shared -o $@ $^ $(LDLIBS)
|
||||
|
||||
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
|
||||
|
||||
static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VM) $(OT) $(LIBSIMPLEOT)
|
||||
static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
|
||||
|
||||
static-dir:
|
||||
@@ -141,7 +156,7 @@ gc-emulate.x: $(VM) GC/FakeSecret.o GC/square64.o
|
||||
bmr-%.x: $(BMR) $(VM) Machines/bmr-%.cpp $(LIBSIMPLEOT)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(BOOST) $(LDLIBS)
|
||||
|
||||
%-bmr-party.x: Machines/%-bmr-party.o $(BMR) $(VM) $(LIBSIMPLEOT)
|
||||
%-bmr-party.x: Machines/%-bmr-party.o $(BMR) $(SHAREDLIB) $(MINI_OT)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(BOOST) $(LDLIBS)
|
||||
|
||||
bmr-clean:
|
||||
@@ -170,15 +185,15 @@ default-prime-length.x: Utils/default-prime-length.o
|
||||
secure.x: Utils/secure.o
|
||||
$(CXX) -o $@ $(CFLAGS) $^
|
||||
|
||||
Fake-Offline.x: Utils/Fake-Offline.o $(VM)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
|
||||
|
||||
%.x: Utils/%.o $(COMMON)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
|
||||
|
||||
%.x: Machines/%.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
%.x: Machines/%.o $(MINI_OT) $(SHAREDLIB)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
|
||||
|
||||
%gear-party.x: Machines/%gear-party.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) -lntl
|
||||
|
||||
%-ecdsa-party.x: ECDSA/%-ecdsa-party.o ECDSA/P256Element.o $(VM)
|
||||
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
|
||||
|
||||
@@ -188,7 +203,7 @@ replicated-field-party.x: GC/square64.o
|
||||
brain-party.x: GC/square64.o
|
||||
malicious-rep-bin-party.x: GC/square64.o
|
||||
ps-rep-bin-party.x: GC/PostSacriBin.o
|
||||
semi-bin-party.x: $(VM) $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
tiny-party.x: $(OT)
|
||||
tinier-party.x: $(OT)
|
||||
spdz2k-party.x: $(OT) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
|
||||
@@ -202,12 +217,12 @@ chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
|
||||
lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
|
||||
highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
|
||||
atlas-party.x: GC/AtlasSecret.o
|
||||
static/hemi-party.x: $(FHEOFFLINE)
|
||||
static/soho-party.x: $(FHEOFFLINE)
|
||||
static/cowgear-party.x: $(FHEOFFLINE)
|
||||
static/chaigear-party.x: $(FHEOFFLINE)
|
||||
static/lowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
|
||||
static/highgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
|
||||
static/hemi-party.x: $(FHEOBJS)
|
||||
static/soho-party.x: $(FHEOBJS)
|
||||
static/cowgear-party.x: $(FHEOBJS)
|
||||
static/chaigear-party.x: $(FHEOBJS)
|
||||
static/lowgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
|
||||
static/highgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
|
||||
mascot-party.x: Machines/SPDZ.o $(OT)
|
||||
static/mascot-party.x: Machines/SPDZ.o
|
||||
Player-Online.x: Machines/SPDZ.o $(OT)
|
||||
@@ -226,7 +241,6 @@ real-bmr-party.x: $(OT)
|
||||
paper-example.x: $(VM) $(OT) $(FHEOFFLINE)
|
||||
mascot-offline.x: $(VM) $(OT)
|
||||
cowgear-offline.x: $(OT) $(FHEOFFLINE)
|
||||
Fake-Offline.x: $(VM)
|
||||
static/rep-bmr-party.x: $(BMR)
|
||||
static/mal-rep-bmr-party.x: $(BMR)
|
||||
static/shamir-bmr-party.x: $(BMR)
|
||||
@@ -269,7 +283,7 @@ mpir: mpir-setup
|
||||
./configure --enable-cxx --prefix=$(CURDIR)/local
|
||||
$(MAKE) -C mpir install
|
||||
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
|
||||
-echo MY_LDLIBS += -Wl,-rpath -Wl,./local/lib -L./local/lib >> CONFIG.mine
|
||||
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
|
||||
|
||||
mac-setup:
|
||||
brew install openssl boost libsodium mpir yasm ntl
|
||||
@@ -281,4 +295,4 @@ simde/simde:
|
||||
git submodule update --init simde
|
||||
|
||||
clean:
|
||||
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x
|
||||
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so
|
||||
|
||||
@@ -52,21 +52,19 @@ public:
|
||||
BitVec_& operator-=(const BitVec_& other) { *this ^= other; return *this; }
|
||||
|
||||
BitVec_ extend_bit() const { return -(this->a & 1); }
|
||||
BitVec_ mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; }
|
||||
|
||||
void extend_bit(BitVec_& res, int) const { res = extend_bit(); }
|
||||
void mask(BitVec_& res, int n) const { res = mask(n); }
|
||||
|
||||
void add(octetStream& os) { *this += os.get<BitVec_>(); }
|
||||
|
||||
void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; }
|
||||
|
||||
void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = mask(n); }
|
||||
void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = this->mask(n); }
|
||||
|
||||
void pack(octetStream& os) const { os.store_int<sizeof(T)>(this->a); }
|
||||
void unpack(octetStream& os) { this->a = os.get_int<sizeof(T)>(); }
|
||||
|
||||
void pack(octetStream& os, int n) const { os.store_int(mask(n).a, DIV_CEIL(n, 8)); }
|
||||
void pack(octetStream& os, int n) const { os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); }
|
||||
void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); }
|
||||
|
||||
static BitVec_ unpack_new(octetStream& os, int n = n_bits)
|
||||
|
||||
@@ -85,6 +85,9 @@ public:
|
||||
T& operator^=(const IntBase& other) { return a ^= other.a; }
|
||||
T& operator&=(const IntBase& other) { return a &= other.a; }
|
||||
|
||||
IntBase mask(int n) const { return n < N_BITS ? *this & ((1L << n) - 1) : *this; }
|
||||
void mask(IntBase& res, int n) const { res = mask(n); }
|
||||
|
||||
friend ostream& operator<<(ostream& s, const IntBase& x) { x.output(s, true); return s; }
|
||||
friend istream& operator>>(istream& s, IntBase& x) { x.input(s, true); return s; }
|
||||
|
||||
|
||||
@@ -30,7 +30,8 @@ void Z2<K>::reqbl(int n)
|
||||
}
|
||||
else if (n > 0)
|
||||
{
|
||||
throw Processor_Error("Program compiled for fields not rings");
|
||||
throw Processor_Error("Program compiled for fields not rings, "
|
||||
"run compile.py with '-R " + to_string(K) + "'");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ inline void Zp_Data::Add<1>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y
|
||||
#else
|
||||
*ans = *x + *y;
|
||||
asm goto ("jc %l[sub]" :::: sub);
|
||||
if (*ans >= *prA)
|
||||
if (mpn_cmp(ans, prA, 1) >= 0)
|
||||
sub:
|
||||
*ans -= *prA;
|
||||
#endif
|
||||
@@ -251,13 +251,17 @@ inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
|
||||
break;
|
||||
CASE(1)
|
||||
CASE(2)
|
||||
#if MAX_MOD_SZ >= 5
|
||||
#if MAX_MOD_SZ >= 4
|
||||
CASE(3)
|
||||
CASE(4)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 5
|
||||
CASE(5)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 10
|
||||
#if MAX_MOD_SZ >= 6
|
||||
CASE(6)
|
||||
#endif
|
||||
#if MAX_MOD_SZ >= 10
|
||||
CASE(7)
|
||||
CASE(8)
|
||||
CASE(9)
|
||||
|
||||
@@ -166,7 +166,8 @@ void gfp_<X, L>::reqbl(int n)
|
||||
}
|
||||
else if ((int)n < 0)
|
||||
{
|
||||
throw Processor_Error("Program compiled for rings not fields");
|
||||
throw Processor_Error("Program compiled for rings not fields, "
|
||||
"run compile.py without -R");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,9 +26,9 @@ void ssl_error(string side, string pronoun, string other, string server)
|
||||
<< "with `Scripts/setup-ssl.sh` expire after a month." << endl;
|
||||
}
|
||||
|
||||
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
|
||||
MultiPlayer<ssl_socket*>(Nms, id_base), plaintext_player(Nms, id_base),
|
||||
other_player(Nms, id_base + Nms.num_players()),
|
||||
CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
|
||||
MultiPlayer<ssl_socket*>(Nms), plaintext_player(Nms, id_base),
|
||||
other_player(Nms, id_base + "recv"),
|
||||
ctx("P" + to_string(my_num()))
|
||||
{
|
||||
sockets.resize(num_players());
|
||||
@@ -57,6 +57,11 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
|
||||
}
|
||||
}
|
||||
|
||||
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
|
||||
CryptoPlayer(Nms, to_string(id_base))
|
||||
{
|
||||
}
|
||||
|
||||
CryptoPlayer::~CryptoPlayer()
|
||||
{
|
||||
close_client_socket(plaintext_player.socket(my_num()));
|
||||
@@ -124,13 +129,6 @@ void CryptoPlayer::pass_around_no_stats(const octetStream& to_send,
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void MultiPlayer<ssl_socket*>::setup_sockets(const vector<string>& names,
|
||||
const vector<int>& ports, int id_base, ServerSocket& server)
|
||||
{
|
||||
(void)names, (void)ports, (void)id_base, (void)server;
|
||||
}
|
||||
|
||||
void CryptoPlayer::send_receive_all_no_stats(const vector<vector<bool>>& channels,
|
||||
const vector<octetStream>& to_send,
|
||||
vector<octetStream>& to_receive) const
|
||||
|
||||
@@ -12,6 +12,12 @@
|
||||
#include <boost/asio/ssl.hpp>
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
/**
|
||||
* Encrypted multi-party communication.
|
||||
* Uses OpenSSL and certificates issued to "P<player_no>".
|
||||
* Sending and receiving is done in separate threads to allow
|
||||
* for bidirectional communication.
|
||||
*/
|
||||
class CryptoPlayer : public MultiPlayer<ssl_socket*>
|
||||
{
|
||||
PlainPlayer plaintext_player, other_player;
|
||||
@@ -24,7 +30,14 @@ class CryptoPlayer : public MultiPlayer<ssl_socket*>
|
||||
vector<Receiver<ssl_socket*>*> receivers;
|
||||
|
||||
public:
|
||||
CryptoPlayer(const Names& Nms, int id_base=0);
|
||||
/**
|
||||
* Start a new set of encrypted connections.
|
||||
* @param Nms network setup
|
||||
* @param id unique identifier
|
||||
*/
|
||||
CryptoPlayer(const Names& Nms, const string& id);
|
||||
// legacy interface
|
||||
CryptoPlayer(const Names& Nms, int id_base = 0);
|
||||
~CryptoPlayer();
|
||||
|
||||
bool is_encrypted() { return true; }
|
||||
|
||||
@@ -22,24 +22,20 @@ void Names::init(int player,int pnb,int my_port,const char* servername)
|
||||
setup_server();
|
||||
}
|
||||
|
||||
void Names::init(int player,int pnb,vector<string> Nms)
|
||||
Names::Names(int player, int nplayers, const string& servername, int pnb,
|
||||
int my_port) :
|
||||
Names()
|
||||
{
|
||||
vector<octet*> names;
|
||||
for (auto& name : Nms)
|
||||
names.push_back((octet*)name.c_str());
|
||||
init(player, pnb, names);
|
||||
Server::start_networking(*this, player, nplayers, servername, pnb, my_port);
|
||||
}
|
||||
|
||||
void Names::init(int player,int pnb,vector<octet*> Nms)
|
||||
void Names::init(int player,int pnb,vector<string> Nms)
|
||||
{
|
||||
player_no=player;
|
||||
portnum_base=pnb;
|
||||
nplayers=Nms.size();
|
||||
names.resize(nplayers);
|
||||
names=Nms;
|
||||
setup_ports();
|
||||
for (int i=0; i<nplayers; i++) {
|
||||
names[i]=(char*)Nms[i];
|
||||
}
|
||||
setup_server();
|
||||
}
|
||||
|
||||
@@ -112,7 +108,7 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
cerr << usage;
|
||||
exit(1);
|
||||
}
|
||||
global_server = network_opts.start_networking(*this, player_no);
|
||||
network_opts.start_networking(*this, player_no);
|
||||
}
|
||||
|
||||
void Names::setup_ports()
|
||||
@@ -130,7 +126,7 @@ void Names::setup_names(const char *servername, int my_port)
|
||||
int socket_num;
|
||||
int pn = portnum_base - 1;
|
||||
set_up_client_socket(socket_num, servername, pn);
|
||||
send(socket_num, (octet*)&player_no, sizeof(player_no));
|
||||
octetStream("P" + to_string(player_no)).Send(socket_num);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
|
||||
#endif
|
||||
@@ -191,16 +187,12 @@ Names::Names(const Names& other)
|
||||
names = other.names;
|
||||
ports = other.ports;
|
||||
server = 0;
|
||||
global_server = 0;
|
||||
}
|
||||
|
||||
|
||||
Names::~Names()
|
||||
{
|
||||
if (server != 0)
|
||||
delete server;
|
||||
if (global_server != 0)
|
||||
delete global_server;
|
||||
}
|
||||
|
||||
|
||||
@@ -213,18 +205,27 @@ Player::Player(const Names& Nms) :
|
||||
|
||||
|
||||
template<class T>
|
||||
MultiPlayer<T>::MultiPlayer(const Names& Nms, int id) :
|
||||
MultiPlayer<T>::MultiPlayer(const Names& Nms) :
|
||||
Player(Nms), send_to_self_socket(0)
|
||||
{
|
||||
if (Nms.num_players() > 1)
|
||||
setup_sockets(Nms.names, Nms.ports, id, *Nms.server);
|
||||
else
|
||||
sockets.resize(Nms.num_players());
|
||||
sockets.resize(Nms.num_players());
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
MultiPlayer<int>::~MultiPlayer()
|
||||
PlainPlayer::PlainPlayer(const Names& Nms, const string& id) :
|
||||
MultiPlayer<int>(Nms)
|
||||
{
|
||||
if (Nms.num_players() > 1)
|
||||
setup_sockets(Nms.names, Nms.ports, id, *Nms.server);
|
||||
}
|
||||
|
||||
|
||||
PlainPlayer::PlainPlayer(const Names& Nms, int id_base) :
|
||||
PlainPlayer(Nms, to_string(id_base))
|
||||
{
|
||||
}
|
||||
|
||||
PlainPlayer::~PlainPlayer()
|
||||
{
|
||||
if (num_players() > 1)
|
||||
{
|
||||
@@ -258,33 +259,38 @@ PlayerBase::~PlayerBase()
|
||||
// Set up nmachines client and server sockets to send data back and fro
|
||||
// A machine is a server between it and player i if i<=my_number
|
||||
// Can also communicate with myself, but only with send_to and receive_from
|
||||
template<>
|
||||
void MultiPlayer<int>::setup_sockets(const vector<string>& names,const vector<int>& ports,int id_base,ServerSocket& server)
|
||||
void PlainPlayer::setup_sockets(const vector<string>& names,
|
||||
const vector<int>& ports, const string& id_base, ServerSocket& server)
|
||||
{
|
||||
sockets.resize(nplayers);
|
||||
// Set up the client side
|
||||
for (int i=player_no; i<nplayers; i++) {
|
||||
int pn=id_base+player_no;
|
||||
auto pn=id_base+"P"+to_string(player_no);
|
||||
if (i==player_no) {
|
||||
const char* localhost = "127.0.0.1";
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "Setting up send to self socket to %s:%d with id 0x%x\n",localhost,ports[i],pn);
|
||||
fprintf(stderr,
|
||||
"Setting up send to self socket to %s:%d with id %s\n",
|
||||
localhost, ports[i], pn.c_str());
|
||||
#endif
|
||||
set_up_client_socket(sockets[i],localhost,ports[i]);
|
||||
} else {
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "Setting up client to %s:%d with id 0x%x\n",names[i].c_str(),ports[i],pn);
|
||||
fprintf(stderr, "Setting up client to %s:%d with id %s\n",
|
||||
names[i].c_str(), ports[i], pn.c_str());
|
||||
#endif
|
||||
set_up_client_socket(sockets[i],names[i].c_str(),ports[i]);
|
||||
}
|
||||
send(sockets[i], (unsigned char*)&pn, sizeof(pn));
|
||||
octetStream(pn).Send(sockets[i]);
|
||||
}
|
||||
send_to_self_socket = sockets[player_no];
|
||||
// Setting up the server side
|
||||
for (int i=0; i<=player_no; i++) {
|
||||
int id=id_base+i;
|
||||
auto id=id_base+"P"+to_string(i);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "As a server, waiting for client with id 0x%x to connect.\n",id);
|
||||
fprintf(stderr,
|
||||
"As a server, waiting for client with id %s to connect.\n",
|
||||
id.c_str());
|
||||
#endif
|
||||
sockets[i] = server.get_connection_socket(id);
|
||||
}
|
||||
@@ -539,11 +545,10 @@ void Player::send_receive_all(const vector<vector<bool>>& channels,
|
||||
send_receive_all_no_stats(channels, to_send, to_receive);
|
||||
}
|
||||
|
||||
void Player::partial_broadcast(const vector<bool>& senders,
|
||||
vector<octetStream>& os) const
|
||||
void Player::partial_broadcast(const vector<bool>&,
|
||||
const vector<bool>&, vector<octetStream>& os) const
|
||||
{
|
||||
partial_broadcast(senders, vector<bool>(num_players(), senders[my_num()]),
|
||||
os);
|
||||
unchecked_broadcast(os);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -570,7 +575,8 @@ void MultiPlayer<T>::send_receive_all_no_stats(
|
||||
}
|
||||
|
||||
|
||||
ThreadPlayer::ThreadPlayer(const Names& Nms, int id_base) : PlainPlayer(Nms, id_base)
|
||||
ThreadPlayer::ThreadPlayer(const Names& Nms, const string& id_base) :
|
||||
PlainPlayer(Nms, id_base)
|
||||
{
|
||||
for (int i = 0; i < Nms.num_players(); i++)
|
||||
{
|
||||
@@ -625,35 +631,41 @@ void ThreadPlayer::send_all(const octetStream& o) const
|
||||
}
|
||||
|
||||
|
||||
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player, int id) :
|
||||
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player, const string& id) :
|
||||
TwoPartyPlayer(Nms.my_num()), other_player(other_player)
|
||||
{
|
||||
is_server = Nms.my_num() > other_player;
|
||||
setup_sockets(other_player, Nms, Nms.ports[other_player], id);
|
||||
}
|
||||
|
||||
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player,
|
||||
int id_base) : RealTwoPartyPlayer(Nms, other_player, to_string(id_base))
|
||||
{
|
||||
}
|
||||
|
||||
RealTwoPartyPlayer::~RealTwoPartyPlayer()
|
||||
{
|
||||
close_client_socket(socket);
|
||||
}
|
||||
|
||||
void RealTwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, int id)
|
||||
void RealTwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, string id)
|
||||
{
|
||||
id += 0xF << 28;
|
||||
id += "2";
|
||||
const char *hostname = nms.names[other_player].c_str();
|
||||
ServerSocket *server = nms.server;
|
||||
if (is_server) {
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "Setting up server with id %d\n",id);
|
||||
fprintf(stderr, "Setting up server with id %s\n", id.c_str());
|
||||
#endif
|
||||
socket = server->get_connection_socket(id);
|
||||
}
|
||||
else {
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, portNum, id);
|
||||
fprintf(stderr, "Setting up client to %s:%d with id %s\n", hostname,
|
||||
portNum, id.c_str());
|
||||
#endif
|
||||
set_up_client_socket(socket, hostname, portNum);
|
||||
::send(socket, (unsigned char*)&id, sizeof(id));
|
||||
octetStream(id).Send(socket);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,6 +751,10 @@ void TwoPartyPlayer::Broadcast_Receive(vector<octetStream>& o) const
|
||||
o[1 - my_num()] = os[1];
|
||||
}
|
||||
|
||||
NamedCommStats::NamedCommStats() : sent(0)
|
||||
{
|
||||
}
|
||||
|
||||
CommStats& CommStats::operator +=(const CommStats& other)
|
||||
{
|
||||
data += other.data;
|
||||
@@ -749,6 +765,7 @@ CommStats& CommStats::operator +=(const CommStats& other)
|
||||
|
||||
NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other)
|
||||
{
|
||||
sent += other.sent;
|
||||
for (auto it = other.begin(); it != other.end(); it++)
|
||||
(*this)[it->first] += it->second;
|
||||
return *this;
|
||||
@@ -772,6 +789,7 @@ CommStats& CommStats::operator -=(const CommStats& other)
|
||||
NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const
|
||||
{
|
||||
NamedCommStats res = *this;
|
||||
res.sent = sent - other.sent;
|
||||
for (auto it = other.begin(); it != other.end(); it++)
|
||||
res[it->first] -= it->second;
|
||||
return res;
|
||||
|
||||
@@ -27,15 +27,22 @@ template<class T> class MultiPlayer;
|
||||
class Server;
|
||||
class ServerSocket;
|
||||
|
||||
/* Class to get the names off the server */
|
||||
/**
|
||||
* Network setup (hostnames and port numbers)
|
||||
*/
|
||||
class Names
|
||||
{
|
||||
friend class Player;
|
||||
friend class PlainPlayer;
|
||||
friend class RealTwoPartyPlayer;
|
||||
|
||||
vector<string> names;
|
||||
vector<int> ports;
|
||||
int nplayers;
|
||||
int portnum_base;
|
||||
int player_no;
|
||||
Server* global_server;
|
||||
|
||||
ServerSocket* server;
|
||||
|
||||
int default_port(int playerno) { return portnum_base + playerno; }
|
||||
void setup_ports();
|
||||
@@ -48,28 +55,63 @@ class Names
|
||||
|
||||
static const int DEFAULT_PORT = -1;
|
||||
|
||||
mutable ServerSocket* server;
|
||||
|
||||
/**
|
||||
* Initialize with central server
|
||||
* @param player my number
|
||||
* @param pnb base port number (server listens one below)
|
||||
* @param my_port my port number (`DEFAULT_PORT` for default,
|
||||
* which is base port number plus player number)
|
||||
* @param servername location of server
|
||||
*/
|
||||
void init(int player,int pnb,int my_port,const char* servername);
|
||||
Names(int player,int pnb,int my_port,const char* servername) : Names()
|
||||
{ init(player,pnb,my_port,servername); }
|
||||
// Set up names when we KNOW who we are going to be using before hand
|
||||
void init(int player,int pnb,vector<octet*> Nms);
|
||||
Names(int player,int pnb,vector<octet*> Nms) : Names()
|
||||
{ init(player,pnb,Nms); }
|
||||
|
||||
/**
|
||||
* Initialize with central server running on player 0
|
||||
* @param player my number
|
||||
* @param nplayers number of players
|
||||
* @param servername location of player 0
|
||||
* @param pnb base port number
|
||||
* @param my_port my port number (`DEFAULT_PORT` for default,
|
||||
* which is base port number plus player number)
|
||||
*/
|
||||
Names(int player, int nplayers, const string& servername, int pnb,
|
||||
int my_port = DEFAULT_PORT);
|
||||
|
||||
/**
|
||||
* Initialize without central server
|
||||
* @param player my number
|
||||
* @param pnb base port number
|
||||
* @param Nms locations of all parties
|
||||
*/
|
||||
void init(int player,int pnb,vector<string> Nms);
|
||||
Names(int player,int pnb,vector<string> Nms) : Names()
|
||||
{ init(player,pnb,Nms); }
|
||||
// nplayers = 0 for taking it from hostsfile
|
||||
|
||||
/**
|
||||
* Initialize from file. One party per line, format ``<hostname>[:<port>]``
|
||||
* @param player my number
|
||||
* @param pnb base port number
|
||||
* @param hostsfile filename
|
||||
* @param players number of players (0 to take from file)
|
||||
*/
|
||||
void init(int player, int pnb, const string& hostsfile, int players = 0);
|
||||
Names(int player, int pnb, const string& hostsfile) : Names()
|
||||
{ init(player, pnb, hostsfile); }
|
||||
|
||||
// initialize from command-line options
|
||||
/**
|
||||
* Initialize from command-line options
|
||||
* @param opt option parser instance
|
||||
* @param argc number of command-line arguments
|
||||
* @param argv command-line arguments
|
||||
* @param default_nplayers default number of players
|
||||
* (used if not given in arguments)
|
||||
*/
|
||||
Names(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
int default_nplayers = 2);
|
||||
|
||||
Names() : nplayers(1), portnum_base(-1), player_no(0), global_server(0), server(0) { ; }
|
||||
Names() : nplayers(1), portnum_base(-1), player_no(0), server(0) { ; }
|
||||
Names(const Names& other);
|
||||
~Names();
|
||||
|
||||
@@ -77,11 +119,6 @@ class Names
|
||||
int my_num() const { return player_no; }
|
||||
const string get_name(int i) const { return names[i]; }
|
||||
int get_portnum_base() const { return portnum_base; }
|
||||
|
||||
friend class PlayerBase;
|
||||
friend class Player;
|
||||
template<class T> friend class MultiPlayer;
|
||||
friend class RealTwoPartyPlayer;
|
||||
};
|
||||
|
||||
|
||||
@@ -108,6 +145,10 @@ struct CommStats
|
||||
class NamedCommStats : public map<string, CommStats>
|
||||
{
|
||||
public:
|
||||
size_t sent;
|
||||
|
||||
NamedCommStats();
|
||||
|
||||
NamedCommStats& operator+=(const NamedCommStats& other);
|
||||
NamedCommStats operator+(const NamedCommStats& other) const;
|
||||
NamedCommStats operator-(const NamedCommStats& other) const;
|
||||
@@ -123,17 +164,20 @@ public:
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
* Abstract class for two- and multi-player communication
|
||||
*/
|
||||
class PlayerBase
|
||||
{
|
||||
protected:
|
||||
int player_no;
|
||||
|
||||
public:
|
||||
mutable size_t sent;
|
||||
size_t& sent;
|
||||
mutable Timer timer;
|
||||
mutable NamedCommStats comm_stats;
|
||||
|
||||
PlayerBase(int player_no) : player_no(player_no), sent(0) {}
|
||||
PlayerBase(int player_no) : player_no(player_no), sent(comm_stats.sent) {}
|
||||
virtual ~PlayerBase();
|
||||
|
||||
int my_real_num() const { return player_no; }
|
||||
@@ -146,6 +190,11 @@ public:
|
||||
{ Broadcast_Receive(o); }
|
||||
};
|
||||
|
||||
/**
|
||||
* Abstract class for multi-player communication.
|
||||
* ``*_no_stats`` functions are called by their equivalents
|
||||
* after accounting for communications statistics.
|
||||
*/
|
||||
class Player : public PlayerBase
|
||||
{
|
||||
protected:
|
||||
@@ -159,7 +208,13 @@ public:
|
||||
Player(const Names& Nms);
|
||||
virtual ~Player();
|
||||
|
||||
/**
|
||||
* Get number of players
|
||||
*/
|
||||
int num_players() const { return nplayers; }
|
||||
/**
|
||||
* Get my player number
|
||||
*/
|
||||
int my_num() const { return player_no; }
|
||||
|
||||
int get_offset(int other_player) const { return positive_modulo(other_player - my_num(), num_players()); }
|
||||
@@ -173,61 +228,115 @@ public:
|
||||
// The following functions generally update the statistics
|
||||
// and then call the *_no_stats equivalent specified by a subclass.
|
||||
|
||||
// send the same to all other players
|
||||
/**
|
||||
* Send the same to all other players
|
||||
*/
|
||||
virtual void send_all(const octetStream& o) const;
|
||||
// send to a specific player
|
||||
/**
|
||||
* Send to a specific player
|
||||
*/
|
||||
void send_to(int player,const octetStream& o) const;
|
||||
virtual void send_to_no_stats(int player,const octetStream& o) const = 0;
|
||||
// receive from all other players
|
||||
/**
|
||||
* Receive from all other players.
|
||||
* Information from player 0 at ``os[0]`` etc.
|
||||
*/
|
||||
void receive_all(vector<octetStream>& os) const;
|
||||
// receive from a specific player
|
||||
/**
|
||||
* Receive from a specific player
|
||||
*/
|
||||
void receive_player(int i,octetStream& o) const;
|
||||
virtual void receive_player_no_stats(int i,octetStream& o) const = 0;
|
||||
virtual void receive_player(int i,FlexBuffer& buffer) const;
|
||||
|
||||
// Communication relative to my number
|
||||
// send to all other players by offset
|
||||
/**
|
||||
* Send to all other players by offset.
|
||||
* ``o[0]`` gets sent to the next player etc.
|
||||
*/
|
||||
void send_relative(const vector<octetStream>& o) const;
|
||||
// send to other player specified by offset
|
||||
/*
|
||||
* Send to other player specified by offset.
|
||||
* 1 stands for the next player etc.
|
||||
*/
|
||||
void send_relative(int offset, const octetStream& o) const;
|
||||
// receive from all other players by offset
|
||||
/**
|
||||
* Receive from all other players by offset.
|
||||
* ``o[0]`` will contain data from the next player etc.
|
||||
*/
|
||||
void receive_relative(vector<octetStream>& o) const;
|
||||
// receive from other palyer specified by offset
|
||||
/**
|
||||
* Receive from other player specified by offset.
|
||||
* 1 stands for the next player etc.
|
||||
*/
|
||||
void receive_relative(int offset, octetStream& o) const;
|
||||
|
||||
// exchange data with minimal memory usage
|
||||
// exchange information with one other party
|
||||
/**
|
||||
* Exchange information with one other party,
|
||||
* reusing the buffer if possible.
|
||||
*/
|
||||
void exchange(int other, const octetStream& to_send, octetStream& ot_receive) const;
|
||||
virtual void exchange_no_stats(int other, const octetStream& to_send, octetStream& ot_receive) const = 0;
|
||||
/**
|
||||
* Exchange information with one other party, reusing the buffer.
|
||||
*/
|
||||
void exchange(int other, octetStream& o) const;
|
||||
// exchange with one other partiy specified by offset
|
||||
/**
|
||||
* Exchange information with one other party specified by offset,
|
||||
* reusing the buffer if possible.
|
||||
*/
|
||||
void exchange_relative(int offset, octetStream& o) const;
|
||||
// send information to party while receiving from another by offset
|
||||
/**
|
||||
* Send information to a party while receiving from another by offset,
|
||||
* The default is to send to the next party while receiving from the previous.
|
||||
* The buffer is reused.
|
||||
*/
|
||||
void pass_around(octetStream& o, int offset = 1) const { pass_around(o, o, offset); }
|
||||
/**
|
||||
* Send information to a party while receiving from another by offset.
|
||||
* The default is to send to the next party while receiving from the previous.
|
||||
*/
|
||||
void pass_around(octetStream& to_send, octetStream& to_receive, int offset) const;
|
||||
virtual void pass_around_no_stats(const octetStream& to_send,
|
||||
octetStream& to_receive, int offset) const = 0;
|
||||
|
||||
/* Broadcast and Receive data to/from all players
|
||||
* - Assumes o[player_no] contains the thing broadcast by me
|
||||
/**
|
||||
* Broadcast and receive data to/from all players.
|
||||
* Assumes o[player_no] contains the data to be broadcast by me.
|
||||
*/
|
||||
virtual void unchecked_broadcast(vector<octetStream>& o) const;
|
||||
// broadcast with eventual verification
|
||||
/**
|
||||
* Broadcast and receive data to/from all players with eventual verification.
|
||||
* Assumes o[player_no] contains the data to be broadcast by me.
|
||||
*/
|
||||
virtual void Broadcast_Receive(vector<octetStream>& o) const;
|
||||
virtual void Broadcast_Receive_no_stats(vector<octetStream>& o) const = 0;
|
||||
|
||||
/* Run Protocol To Verify Broadcast Is Correct
|
||||
* - Resets the blk_SHA_CTX at the same time
|
||||
/**
|
||||
* Run protocol to verify broadcast is correct
|
||||
*/
|
||||
virtual void Check_Broadcast() const;
|
||||
|
||||
// send something different to all
|
||||
/**
|
||||
* Send something different to each player.
|
||||
*/
|
||||
void send_receive_all(const vector<octetStream>& to_send,
|
||||
vector<octetStream>& to_receive) const;
|
||||
// specified senders only send something different to all
|
||||
/**
|
||||
* Specified senders only send something different to each player.
|
||||
* @param senders set whether a player sends or not,
|
||||
* must be equal on all players
|
||||
* @param to_send data to send by player number
|
||||
* @param to_receive received data by player number
|
||||
*/
|
||||
void send_receive_all(const vector<bool>& senders,
|
||||
const vector<octetStream>& to_send, vector<octetStream>& to_receive) const;
|
||||
// send something different only one specified channels
|
||||
/**
|
||||
* Send something different only one specified channels.
|
||||
* @param channels ``channel[i][j]`` indicates whether party ``i`` sends
|
||||
* to party ``j``
|
||||
* @param to_send data to send by player number
|
||||
* @param to_receive received data by player number
|
||||
*/
|
||||
void send_receive_all(const vector<vector<bool>>& channels,
|
||||
const vector<octetStream>& to_send,
|
||||
vector<octetStream>& to_receive) const;
|
||||
@@ -235,11 +344,15 @@ public:
|
||||
const vector<octetStream>& to_send,
|
||||
vector<octetStream>& to_receive) const = 0;
|
||||
|
||||
// specified senders broadcast information
|
||||
/**
|
||||
* Specified senders broadcast information to specified receivers.
|
||||
* @param senders specify which parties do send
|
||||
* @param receivers specify which parties do send
|
||||
* @param os data to send at ``os[my_number]``, received data elsewhere
|
||||
*/
|
||||
virtual void partial_broadcast(const vector<bool>& senders,
|
||||
const vector<bool>& receivers,
|
||||
vector<octetStream>& os) const;
|
||||
virtual void partial_broadcast(const vector<bool>&, const vector<bool>&,
|
||||
vector<octetStream>& os) const { unchecked_broadcast(os); }
|
||||
|
||||
// dummy functions for compatibility
|
||||
virtual void request_receive(int i, octetStream& o) const { (void)i; (void)o; }
|
||||
@@ -247,6 +360,11 @@ public:
|
||||
{ receive_player(i, o); }
|
||||
};
|
||||
|
||||
/**
|
||||
* Multi-player communication helper class.
|
||||
* ``T = int`` for unencrypted BSD sockets and
|
||||
* ``T = ssl_socket*`` for Boost SSL streams.
|
||||
*/
|
||||
template<class T>
|
||||
class MultiPlayer : public Player
|
||||
{
|
||||
@@ -254,16 +372,12 @@ protected:
|
||||
vector<T> sockets;
|
||||
T send_to_self_socket;
|
||||
|
||||
void setup_sockets(const vector<string>& names,const vector<int>& ports,int id_base,ServerSocket& server);
|
||||
|
||||
T socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; }
|
||||
|
||||
friend class CryptoPlayer;
|
||||
|
||||
public:
|
||||
// The offset is used for the multi-threaded call, to ensure different
|
||||
// portnum bases in each thread
|
||||
MultiPlayer(const Names& Nms,int id_base=0);
|
||||
MultiPlayer(const Names& Nms);
|
||||
|
||||
virtual ~MultiPlayer();
|
||||
|
||||
@@ -296,16 +410,34 @@ public:
|
||||
vector<octetStream>& to_receive) const;
|
||||
};
|
||||
|
||||
typedef MultiPlayer<int> PlainPlayer;
|
||||
/**
|
||||
* Plaintext multi-player communication
|
||||
*/
|
||||
class PlainPlayer : public MultiPlayer<int>
|
||||
{
|
||||
void setup_sockets(const vector<string>& names, const vector<int>& ports,
|
||||
const string& id_base, ServerSocket& server);
|
||||
|
||||
public:
|
||||
/**
|
||||
* Start a new set of unencrypted connections.
|
||||
* @param Nms network setup
|
||||
* @param id unique identifier
|
||||
*/
|
||||
PlainPlayer(const Names& Nms, const string& id);
|
||||
// legacy interface
|
||||
PlainPlayer(const Names& Nms, int id_base = 0);
|
||||
~PlainPlayer();
|
||||
};
|
||||
|
||||
|
||||
class ThreadPlayer : public MultiPlayer<int>
|
||||
class ThreadPlayer : public PlainPlayer
|
||||
{
|
||||
public:
|
||||
mutable vector<Receiver<int>*> receivers;
|
||||
mutable vector<Sender<int>*> senders;
|
||||
|
||||
ThreadPlayer(const Names& Nms,int id_base=0);
|
||||
ThreadPlayer(const Names& Nms, const string& id_base);
|
||||
virtual ~ThreadPlayer();
|
||||
|
||||
void request_receive(int i, octetStream& o) const;
|
||||
@@ -335,14 +467,16 @@ class RealTwoPartyPlayer : public TwoPartyPlayer
|
||||
{
|
||||
private:
|
||||
// setup sockets for comm. with only one other player
|
||||
void setup_sockets(int other_player, const Names &nms, int portNum, int id);
|
||||
void setup_sockets(int other_player, const Names &nms, int portNum, string id);
|
||||
|
||||
int socket;
|
||||
bool is_server;
|
||||
int other_player;
|
||||
|
||||
public:
|
||||
RealTwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0);
|
||||
RealTwoPartyPlayer(const Names& Nms, int other_player, const string& id);
|
||||
// legacy
|
||||
RealTwoPartyPlayer(const Names& Nms, int other_player, int id_base = 0);
|
||||
~RealTwoPartyPlayer();
|
||||
|
||||
void send(octetStream& o) const;
|
||||
|
||||
@@ -116,7 +116,7 @@ void Server::start()
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Waiting for player " << i << endl;
|
||||
#endif
|
||||
socket_num[i] = server.get_connection_socket(i);
|
||||
socket_num[i] = server.get_connection_socket("P" + to_string(i));
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Connected to player " << i << endl;
|
||||
#endif
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <Networking/sockets.h>
|
||||
#include "Tools/Exceptions.h"
|
||||
#include "Tools/time-func.h"
|
||||
#include "Tools/octetStream.h"
|
||||
|
||||
#include <netinet/ip.h>
|
||||
#include <netinet/tcp.h>
|
||||
@@ -53,7 +54,10 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum), thread(0)
|
||||
while (fl!=0 and timer.elapsed() < 600)
|
||||
{ fl=::bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr));
|
||||
if (fl != 0)
|
||||
{ cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl;
|
||||
{
|
||||
cerr << "Binding to socket on " << my_name << ":" << Portnum
|
||||
<< " failed (" << strerror(errno)
|
||||
<< "), trying again in a second ..." << endl;
|
||||
sleep(1);
|
||||
}
|
||||
#ifdef DEBUG_NETWORKING
|
||||
@@ -109,11 +113,11 @@ ServerSocket::~ServerSocket()
|
||||
void ServerSocket::wait_for_client_id(int socket, struct sockaddr dest)
|
||||
{
|
||||
(void) dest;
|
||||
int client_id;
|
||||
try
|
||||
{
|
||||
receive(socket, (unsigned char*) &client_id, sizeof(client_id));
|
||||
process_connection(socket, client_id);
|
||||
octetStream client_id;
|
||||
client_id.Receive(socket);
|
||||
process_connection(socket, client_id.str());
|
||||
}
|
||||
catch (closed_connection&)
|
||||
{
|
||||
@@ -138,9 +142,13 @@ void ServerSocket::accept_clients()
|
||||
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
|
||||
if (consocket<0) { error("set_up_socket:accept"); }
|
||||
|
||||
int client_id;
|
||||
if (receive_all_or_nothing(consocket, (unsigned char*)&client_id, sizeof(client_id)))
|
||||
process_connection(consocket, client_id);
|
||||
octetStream client_id;
|
||||
char buf[1];
|
||||
if (recv(consocket, buf, 1, MSG_PEEK | MSG_DONTWAIT) > 0)
|
||||
{
|
||||
client_id.Receive(consocket);
|
||||
process_connection(consocket, client_id.str());
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifdef DEBUG_NETWORKING
|
||||
@@ -162,7 +170,7 @@ void ServerSocket::accept_clients()
|
||||
}
|
||||
}
|
||||
|
||||
void ServerSocket::process_connection(int consocket, int client_id)
|
||||
void ServerSocket::process_connection(int consocket, const string& client_id)
|
||||
{
|
||||
data_signal.lock();
|
||||
#ifdef DEBUG_NETWORKING
|
||||
@@ -175,7 +183,7 @@ void ServerSocket::process_connection(int consocket, int client_id)
|
||||
data_signal.unlock();
|
||||
}
|
||||
|
||||
int ServerSocket::get_connection_socket(int id)
|
||||
int ServerSocket::get_connection_socket(const string& id)
|
||||
{
|
||||
data_signal.lock();
|
||||
if (used.find(id) != used.end())
|
||||
@@ -208,14 +216,14 @@ void AnonymousServerSocket::init()
|
||||
pthread_create(&thread, 0, anonymous_accept_thread, this);
|
||||
}
|
||||
|
||||
void AnonymousServerSocket::process_client(int client_id)
|
||||
void AnonymousServerSocket::process_client(const string& client_id)
|
||||
{
|
||||
if (clients.find(client_id) != clients.end())
|
||||
close_client_socket(clients[client_id]);
|
||||
client_connection_queue.push(client_id);
|
||||
}
|
||||
|
||||
int AnonymousServerSocket::get_connection_socket(int& client_id)
|
||||
int AnonymousServerSocket::get_connection_socket(string& client_id)
|
||||
{
|
||||
data_signal.lock();
|
||||
|
||||
@@ -230,7 +238,7 @@ int AnonymousServerSocket::get_connection_socket(int& client_id)
|
||||
return client_socket;
|
||||
}
|
||||
|
||||
void AnonymousServerSocket::remove_client(int client_id)
|
||||
void AnonymousServerSocket::remove_client(const string& client_id)
|
||||
{
|
||||
clients.erase(client_id);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
using namespace std;
|
||||
|
||||
#include <pthread.h>
|
||||
@@ -23,8 +24,8 @@ class ServerSocket
|
||||
{
|
||||
protected:
|
||||
int main_socket, portnum;
|
||||
map<int,int> clients;
|
||||
std::set<int> used;
|
||||
map<string,int> clients;
|
||||
std::set<string> used;
|
||||
Signal data_signal;
|
||||
pthread_t thread;
|
||||
|
||||
@@ -33,9 +34,9 @@ protected:
|
||||
// disable copying
|
||||
ServerSocket(const ServerSocket& other);
|
||||
|
||||
void process_connection(int socket, int client_id);
|
||||
void process_connection(int socket, const string& client_id);
|
||||
|
||||
virtual void process_client(int) {}
|
||||
virtual void process_client(const string&) {}
|
||||
|
||||
public:
|
||||
ServerSocket(int Portnum);
|
||||
@@ -47,9 +48,9 @@ public:
|
||||
|
||||
void wait_for_client_id(int socket, struct sockaddr dest);
|
||||
|
||||
// This depends on clients sending their id as int.
|
||||
// This depends on clients sending their id.
|
||||
// Has to be thread-safe.
|
||||
int get_connection_socket(int number);
|
||||
int get_connection_socket(const string& id);
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -59,9 +60,9 @@ class AnonymousServerSocket : public ServerSocket
|
||||
{
|
||||
private:
|
||||
// No. of accepted connections in this instance
|
||||
queue<int> client_connection_queue;
|
||||
queue<string> client_connection_queue;
|
||||
|
||||
void process_client(int client_id);
|
||||
void process_client(const string& client_id);
|
||||
|
||||
public:
|
||||
AnonymousServerSocket(int Portnum) :
|
||||
@@ -69,9 +70,9 @@ public:
|
||||
void init();
|
||||
|
||||
// Get socket and id for the last client who connected
|
||||
int get_connection_socket(int& client_id);
|
||||
int get_connection_socket(string& client_id);
|
||||
|
||||
void remove_client(int client_id);
|
||||
void remove_client(const string& client_id);
|
||||
};
|
||||
|
||||
#endif /* NETWORKING_SERVERSOCKET_H_ */
|
||||
|
||||
@@ -116,7 +116,6 @@ public:
|
||||
|
||||
mac_key_type get_mac_key() const { return mac_key; }
|
||||
|
||||
size_t data_sent();
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
@@ -210,17 +209,6 @@ public:
|
||||
void generateTriples();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
size_t OTTripleGenerator<T>::data_sent()
|
||||
{
|
||||
size_t res = 0;
|
||||
if (parentPlayer != &globalPlayer)
|
||||
res = globalPlayer.sent;
|
||||
for (auto& player : players)
|
||||
res += player->sent;
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
NamedCommStats OTTripleGenerator<T>::comm_stats()
|
||||
{
|
||||
|
||||
@@ -72,7 +72,7 @@ OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
|
||||
globalPlayer(parentPlayer ? *parentPlayer : *new PlainPlayer(names,
|
||||
- thread_num * names.num_players() * names.num_players())),
|
||||
to_string(thread_num))),
|
||||
parentPlayer(parentPlayer),
|
||||
thread_num(thread_num),
|
||||
mac_key(mac_key),
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
*/
|
||||
|
||||
#include "BaseMachine.h"
|
||||
#include "OnlineOptions.h"
|
||||
#include "Math/Setup.h"
|
||||
|
||||
#include <iostream>
|
||||
@@ -90,10 +91,8 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
|
||||
|
||||
void BaseMachine::print_compiler()
|
||||
{
|
||||
#ifdef VERBOSE
|
||||
if (compiler.size() != 0)
|
||||
if (compiler.size() != 0 and OnlineOptions::singleton.verbose)
|
||||
cerr << "Compiler: " << compiler << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
void BaseMachine::load_program(const string& threadname, const string& filename)
|
||||
|
||||
@@ -20,6 +20,8 @@ class Binary_File_IO
|
||||
{
|
||||
public:
|
||||
|
||||
static string filename(int my_number);
|
||||
|
||||
/*
|
||||
* Append the buffer values as binary to the filename.
|
||||
* Throws file_error.
|
||||
|
||||
@@ -6,6 +6,13 @@
|
||||
* Intended for application specific file IO.
|
||||
*/
|
||||
|
||||
inline string Binary_File_IO::filename(int my_number)
|
||||
{
|
||||
string dir = "Persistence";
|
||||
mkdir_p(dir.c_str());
|
||||
return dir + "/Transactions-P" + to_string(my_number) + ".data";
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Binary_File_IO::write_to_file(const string filename, const vector< T >& buffer)
|
||||
{
|
||||
|
||||
@@ -58,6 +58,7 @@ public:
|
||||
vector< array<long long, N_DATA_FIELD_TYPE> > inputs;
|
||||
array<map<DataTag, long long>, N_DATA_FIELD_TYPE> extended;
|
||||
map<pair<bool, int>, long long> edabits;
|
||||
map<array<int, 3>, long long> matmuls;
|
||||
|
||||
DataPositions(int num_players = 0);
|
||||
DataPositions(const Player& P) : DataPositions(P.num_players()) {}
|
||||
@@ -126,7 +127,7 @@ public:
|
||||
virtual void prune() {}
|
||||
virtual void purge() {}
|
||||
|
||||
virtual size_t data_sent() { return 0; }
|
||||
virtual size_t data_sent() { return comm_stats().sent; }
|
||||
virtual NamedCommStats comm_stats() { return {}; }
|
||||
|
||||
virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0;
|
||||
@@ -288,7 +289,7 @@ class Data_Files
|
||||
|
||||
void reset_usage() { usage.reset(); skipped.reset(); }
|
||||
|
||||
size_t data_sent() { return DataFp.data_sent() + DataF2.data_sent(); }
|
||||
NamedCommStats comm_stats() { return DataFp.comm_stats() + DataF2.comm_stats(); }
|
||||
};
|
||||
|
||||
template<class T> inline
|
||||
|
||||
@@ -46,7 +46,10 @@ int ExternalClients::get_client_connection(int portnum_base)
|
||||
}
|
||||
cerr << "Thread " << this_thread::get_id() << " found server." << endl;
|
||||
int client_id, socket;
|
||||
socket = client_connection_servers[portnum_base]->get_connection_socket(client_id);
|
||||
string client;
|
||||
socket = client_connection_servers[portnum_base]->get_connection_socket(
|
||||
client);
|
||||
client_id = stoi(client);
|
||||
if (ctx == 0)
|
||||
ctx = new ssl_ctx("P" + to_string(get_party_num()));
|
||||
external_client_sockets[client_id] = new ssl_socket(io_service, *ctx, socket,
|
||||
@@ -63,7 +66,8 @@ void ExternalClients::close_connection(int client_id)
|
||||
throw runtime_error("client id not active: " + to_string(client_id));
|
||||
delete external_client_sockets[client_id];
|
||||
external_client_sockets.erase(client_id);
|
||||
client_connection_servers[client_ports[client_id]]->remove_client(client_id);
|
||||
client_connection_servers[client_ports[client_id]]->remove_client(
|
||||
to_string(client_id));
|
||||
}
|
||||
|
||||
int ExternalClients::get_party_num()
|
||||
|
||||
@@ -67,6 +67,7 @@ enum
|
||||
THRESHOLD = 0xE3,
|
||||
PLAYERID = 0xE4,
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
|
||||
@@ -92,9 +92,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case DIVINT:
|
||||
case CONDPRINTPLAIN:
|
||||
case INPUTMASKREG:
|
||||
r[0]=get_int(s);
|
||||
r[1]=get_int(s);
|
||||
r[2]=get_int(s);
|
||||
get_ints(r, s, 3);
|
||||
break;
|
||||
// instructions with 2 register operands
|
||||
case LDMCI:
|
||||
@@ -130,8 +128,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case DABIT:
|
||||
case SHUFFLE:
|
||||
case ACCEPTCLIENTCONNECTION:
|
||||
r[0]=get_int(s);
|
||||
r[1]=get_int(s);
|
||||
get_ints(r, s, 2);
|
||||
break;
|
||||
// instructions with 1 register operand
|
||||
case BIT:
|
||||
@@ -157,6 +154,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case PLAYERID:
|
||||
case LISTEN:
|
||||
case CLOSECLIENTCONNECTION:
|
||||
case CRASH:
|
||||
r[0]=get_int(s);
|
||||
break;
|
||||
// instructions with 2 registers + 1 integer operand
|
||||
@@ -205,8 +203,11 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case STOPPRIVATEOUTPUT:
|
||||
case GSTOPPRIVATEOUTPUT:
|
||||
case DIGESTC:
|
||||
r[0]=get_int(s);
|
||||
r[1]=get_int(s);
|
||||
get_ints(r, s, 2);
|
||||
n = get_int(s);
|
||||
break;
|
||||
case USE_MATMUL:
|
||||
get_ints(r, s, 3);
|
||||
n = get_int(s);
|
||||
break;
|
||||
// instructions with 1 register + 1 integer operand
|
||||
@@ -254,7 +255,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
break;
|
||||
// instructions with no operand
|
||||
case TIME:
|
||||
case CRASH:
|
||||
case STARTGRIND:
|
||||
case STOPGRIND:
|
||||
case CHECK:
|
||||
@@ -320,8 +320,9 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case READSOCKETC:
|
||||
case READSOCKETS:
|
||||
case READSOCKETINT:
|
||||
num_var_args = get_int(s) - 1;
|
||||
num_var_args = get_int(s) - 2;
|
||||
r[0] = get_int(s);
|
||||
n = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
|
||||
@@ -329,9 +330,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case WRITESOCKETC:
|
||||
case WRITESOCKETSHARE:
|
||||
case WRITESOCKETINT:
|
||||
num_var_args = get_int(s) - 2;
|
||||
num_var_args = get_int(s) - 3;
|
||||
r[0] = get_int(s);
|
||||
r[1] = get_int(s);
|
||||
n = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case WRITESOCKETS:
|
||||
@@ -383,9 +385,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
// subtract extra argument
|
||||
num_var_args = get_int(s) - 1;
|
||||
s.read((char*)r, sizeof(r));
|
||||
start.resize(num_var_args);
|
||||
for (int i = 0; i < num_var_args; i++)
|
||||
{ start[i] = get_int(s); }
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case USE_PREP:
|
||||
case GUSE_PREP:
|
||||
@@ -431,6 +431,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case CONVCBITVEC:
|
||||
case CONVCBIT2S:
|
||||
case NOTS:
|
||||
case NOTCB:
|
||||
n = get_int(s);
|
||||
get_ints(r, s, 2);
|
||||
break;
|
||||
@@ -497,6 +498,9 @@ bool Instruction::get_offline_data_usage(DataPositions& usage)
|
||||
case USE_EDABIT:
|
||||
usage.edabits[{r[0], r[1]}] = n;
|
||||
return int(n) >= 0;
|
||||
case USE_MATMUL:
|
||||
usage.matmuls[{{r[0], r[1], r[2]}}] = n;
|
||||
return int(n) >= 0;
|
||||
case USE_PREP:
|
||||
usage.extended[DATA_INT][r] = n;
|
||||
return int(n) >= 0;
|
||||
@@ -552,6 +556,7 @@ int BaseInstruction::get_reg_type() const
|
||||
case USE_PREP:
|
||||
case GUSE_PREP:
|
||||
case USE_EDABIT:
|
||||
case USE_MATMUL:
|
||||
case RUN_TAPE:
|
||||
case CISC:
|
||||
// those use r[] not for registers
|
||||
@@ -709,6 +714,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
}
|
||||
case ANDM:
|
||||
case NOTS:
|
||||
case NOTCB:
|
||||
size = DIV_CEIL(n, 64);
|
||||
break;
|
||||
case CONVCBIT2S:
|
||||
@@ -735,6 +741,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
offset = 2;
|
||||
skip = 4;
|
||||
break;
|
||||
case READSOCKETS:
|
||||
case READSOCKETC:
|
||||
case READSOCKETINT:
|
||||
case WRITESOCKETSHARE:
|
||||
case WRITESOCKETC:
|
||||
case WRITESOCKETINT:
|
||||
size = n;
|
||||
break;
|
||||
}
|
||||
|
||||
if (skip > 0)
|
||||
@@ -1016,11 +1030,11 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Procp.matmuls(Proc.Procp.get_S(), *this, r[1], r[2]);
|
||||
return;
|
||||
case MATMULSM:
|
||||
Proc.Procp.matmulsm(Proc.machine.Mp.MS, *this, Proc.read_Ci(r[1]),
|
||||
Proc.read_Ci(r[2]));
|
||||
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
|
||||
Proc.read_Ci(r[1]), Proc.read_Ci(r[2]));
|
||||
return;
|
||||
case CONV2DS:
|
||||
Proc.Procp.conv2ds(*this);
|
||||
Proc.Procp.protocol.conv2ds(Proc.Procp, *this);
|
||||
return;
|
||||
case TRUNC_PR:
|
||||
Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp);
|
||||
@@ -1096,6 +1110,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case USE:
|
||||
case USE_INP:
|
||||
case USE_EDABIT:
|
||||
case USE_MATMUL:
|
||||
case USE_PREP:
|
||||
case GUSE_PREP:
|
||||
break;
|
||||
@@ -1117,7 +1132,8 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.machine.join_tape(r[0]);
|
||||
break;
|
||||
case CRASH:
|
||||
throw crash_requested();
|
||||
if (Proc.read_Ci(r[0]))
|
||||
throw crash_requested();
|
||||
break;
|
||||
case STARTGRIND:
|
||||
CALLGRIND_START_INSTRUMENTATION;
|
||||
@@ -1160,25 +1176,25 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.external_clients.close_connection(Proc.read_Ci(r[0]));
|
||||
break;
|
||||
case READSOCKETINT:
|
||||
Proc.read_socket_ints(Proc.read_Ci(r[0]), start);
|
||||
Proc.read_socket_ints(Proc.read_Ci(r[0]), start, n);
|
||||
break;
|
||||
case READSOCKETC:
|
||||
Proc.read_socket_vector(Proc.read_Ci(r[0]), start);
|
||||
Proc.read_socket_vector(Proc.read_Ci(r[0]), start, n);
|
||||
break;
|
||||
case READSOCKETS:
|
||||
// read shares and MAC shares
|
||||
Proc.read_socket_private(Proc.read_Ci(r[0]), start, true);
|
||||
Proc.read_socket_private(Proc.read_Ci(r[0]), start, n, true);
|
||||
break;
|
||||
case WRITESOCKETINT:
|
||||
Proc.write_socket(INT, Proc.read_Ci(r[0]), r[1], start);
|
||||
Proc.write_socket(INT, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
break;
|
||||
case WRITESOCKETC:
|
||||
Proc.write_socket(CINT, Proc.read_Ci(r[0]), r[1], start);
|
||||
Proc.write_socket(CINT, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
break;
|
||||
case WRITESOCKETSHARE:
|
||||
// Send only shares, no MACs
|
||||
// N.B. doesn't make sense to have a corresponding read instruction for this
|
||||
Proc.write_socket(SINT, Proc.read_Ci(r[0]), r[1], start);
|
||||
Proc.write_socket(SINT, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
break;
|
||||
case WRITEFILESHARE:
|
||||
// Write shares to file system
|
||||
|
||||
@@ -72,7 +72,6 @@ class Machine : public BaseMachine
|
||||
|
||||
OnlineOptions opts;
|
||||
|
||||
atomic<size_t> data_sent;
|
||||
NamedCommStats comm_stats;
|
||||
ExecutionStats stats;
|
||||
|
||||
@@ -89,6 +88,11 @@ class Machine : public BaseMachine
|
||||
void fill_buffers(int thread_number, int tape_number,
|
||||
Preprocessing<sint> *prep,
|
||||
Preprocessing<typename sint::bit_type> *bit_prep);
|
||||
template<int = 0>
|
||||
void fill_matmul(int thread_numbber, int tape_number,
|
||||
Preprocessing<sint> *prep, true_type);
|
||||
template<int = 0>
|
||||
void fill_matmul(int, int, Preprocessing<sint>*, false_type) {}
|
||||
DataPositions run_tape(int thread_number, int tape_number, int arg);
|
||||
DataPositions join_tape(int thread_number);
|
||||
void run();
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "Memory.hpp"
|
||||
#include "Online-Thread.hpp"
|
||||
#include "Protocols/Hemi.hpp"
|
||||
|
||||
#include "Tools/Exceptions.h"
|
||||
|
||||
@@ -30,8 +31,7 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
: my_number(my_number), N(playerNames),
|
||||
direct(direct), opening_sum(opening_sum),
|
||||
receive_threads(receive_threads), max_broadcast(max_broadcast),
|
||||
use_encryption(use_encryption), live_prep(live_prep), opts(opts),
|
||||
data_sent(0)
|
||||
use_encryption(use_encryption), live_prep(live_prep), opts(opts)
|
||||
{
|
||||
if (opening_sum < 2)
|
||||
this->opening_sum = N.num_players();
|
||||
@@ -49,10 +49,11 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
// make directory for outputs if necessary
|
||||
mkdir_p(PREP_DIR);
|
||||
|
||||
string id = "machine";
|
||||
if (use_encryption)
|
||||
P = new CryptoPlayer(N, 0xF00);
|
||||
P = new CryptoPlayer(N, id);
|
||||
else
|
||||
P = new PlainPlayer(N, 0xF00);
|
||||
P = new PlainPlayer(N, id);
|
||||
|
||||
if (opts.live_prep)
|
||||
{
|
||||
@@ -96,6 +97,13 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
|
||||
load_schedule(progname_str);
|
||||
|
||||
// remove persistence if necessary
|
||||
for (auto& prog : progs)
|
||||
{
|
||||
if (prog.writes_persistance)
|
||||
ofstream(Binary_File_IO::filename(my_number), ios::out);
|
||||
}
|
||||
|
||||
#ifdef VERBOSE
|
||||
progs[0].print_offline_cost();
|
||||
#endif
|
||||
@@ -223,6 +231,49 @@ void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
|
||||
{
|
||||
#ifdef VERBOSE_CENTRAL
|
||||
cerr << "Problem with central bit triple preprocessing: " << e.what() << endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
if (not HemiOptions::singleton.plain_matmul)
|
||||
fill_matmul(thread_number, tape_number, prep, sint::triple_matmul);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
template<int>
|
||||
void Machine<sint, sgf2n>::fill_matmul(int thread_number, int tape_number,
|
||||
Preprocessing<sint>* prep, true_type)
|
||||
{
|
||||
auto usage = progs[tape_number].get_offline_data_used();
|
||||
for (auto it = usage.matmuls.begin(); it != usage.matmuls.end(); it++)
|
||||
{
|
||||
try
|
||||
{
|
||||
auto& source_proc = *dynamic_cast<BufferPrep<sint>&>(*prep).proc;
|
||||
int max_inner = opts.batch_size;
|
||||
int max_cols = opts.batch_size;
|
||||
for (int j = 0; j < it->first[1]; j += max_inner)
|
||||
{
|
||||
for (int k = 0; k < it->first[2]; k += max_cols)
|
||||
{
|
||||
auto subdim = it->first;
|
||||
subdim[1] = min(subdim[1] - j, max_inner);
|
||||
subdim[2] = min(subdim[2] - k, max_cols);
|
||||
auto& source =
|
||||
dynamic_cast<Hemi<sint>&>(source_proc.protocol).get_matrix_prep(
|
||||
subdim, source_proc);
|
||||
auto& dest =
|
||||
dynamic_cast<Hemi<sint>&>(tinfo[thread_number].processor->Procp.protocol).get_matrix_prep(
|
||||
subdim, tinfo[thread_number].processor->Procp);
|
||||
for (int i = 0; i < it->second; i++)
|
||||
dest.push_triple(source.get_triple_no_count(-1));
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (bad_cast& e)
|
||||
{
|
||||
#ifdef VERBOSE_CENTRAL
|
||||
cerr << "Problem with central matmul preprocessing: " << e.what() << endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -320,20 +371,27 @@ void Machine<sint, sgf2n>::run()
|
||||
for (unsigned int i = 0; i < join_timer.size(); i++)
|
||||
cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl;
|
||||
cerr << "Finish timer: " << finish_timer.elapsed() << endl;
|
||||
cerr << "Process timer: " << proc_timer.elapsed() << endl;
|
||||
#endif
|
||||
|
||||
if (opts.verbose)
|
||||
{
|
||||
cerr << "Communication details "
|
||||
"(rounds in parallel threads counted double):" << endl;
|
||||
comm_stats.print();
|
||||
cerr << "CPU time = " << proc_timer.elapsed() << endl;
|
||||
}
|
||||
|
||||
print_timers();
|
||||
|
||||
size_t rounds = 0;
|
||||
for (auto& x : comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << data_sent / 1e6 << " MB in ~" << rounds
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << my_number << ")" << endl;
|
||||
|
||||
auto& P = *this->P;
|
||||
Bundle<octetStream> bundle(P);
|
||||
bundle.mine.store(data_sent.load());
|
||||
bundle.mine.store(comm_stats.sent);
|
||||
P.Broadcast_Receive_no_stats(bundle);
|
||||
size_t global = 0;
|
||||
for (auto& os : bundle)
|
||||
@@ -384,10 +442,11 @@ void Machine<sint, sgf2n>::run()
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "Actual cost of program:" << endl;
|
||||
pos.print_cost();
|
||||
#endif
|
||||
if (opts.verbose)
|
||||
{
|
||||
cerr << "Actual cost of program:" << endl;
|
||||
pos.print_cost();
|
||||
}
|
||||
|
||||
if (pos.any_more(progs[0].get_offline_data_used())
|
||||
and not progs[0].usage_unknown())
|
||||
|
||||
@@ -15,7 +15,7 @@ OfflineMachine<W>::OfflineMachine(int argc, const char** argv,
|
||||
ez::ezOptionParser& opt, OnlineOptions& online_opts, V,
|
||||
int nplayers) :
|
||||
W(argc, argv, opt, online_opts, V(), nplayers), playerNames(
|
||||
W::playerNames), P(*this->new_player())
|
||||
W::playerNames), P(*this->new_player("machine"))
|
||||
{
|
||||
machine.load_schedule(online_opts.progname, false);
|
||||
Program program(playerNames.num_players());
|
||||
|
||||
@@ -49,26 +49,27 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
fprintf(stderr, "\tI am in thread %d\n",num);
|
||||
#endif
|
||||
Player* player;
|
||||
string id = "thread" + to_string(num);
|
||||
if (machine.use_encryption)
|
||||
{
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
cerr << "Using encrypted single-threaded communication" << endl;
|
||||
#endif
|
||||
player = new CryptoPlayer(*(tinfo->Nms), num << 16);
|
||||
player = new CryptoPlayer(*(tinfo->Nms), id);
|
||||
}
|
||||
else if (!machine.receive_threads or machine.direct)
|
||||
{
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
cerr << "Using single-threaded receiving" << endl;
|
||||
#endif
|
||||
player = new PlainPlayer(*(tinfo->Nms), num << 16);
|
||||
player = new PlainPlayer(*(tinfo->Nms), id);
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
cerr << "Using player-specific threads for receiving" << endl;
|
||||
#endif
|
||||
player = new ThreadPlayer(*(tinfo->Nms), num << 16);
|
||||
player = new ThreadPlayer(*(tinfo->Nms), id);
|
||||
}
|
||||
Player& P = *player;
|
||||
#ifdef DEBUG_THREADS
|
||||
@@ -238,6 +239,16 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
*(Zp_Data*) job.supply);
|
||||
queues->finished(job);
|
||||
}
|
||||
else if (job.type == CIPHER_PLAIN_MULT_JOB)
|
||||
{
|
||||
cipher_plain_mult(job, sint::triple_matmul);
|
||||
queues->finished(job);
|
||||
}
|
||||
else if (job.type == MATRX_RAND_MULT_JOB)
|
||||
{
|
||||
matrix_rand_mult(job, sint::triple_matmul);
|
||||
queues->finished(job);
|
||||
}
|
||||
else
|
||||
{ // RUN PROGRAM
|
||||
#ifdef DEBUG_THREADS
|
||||
@@ -303,16 +314,15 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
#endif
|
||||
|
||||
// wind down thread by thread
|
||||
size_t prep_sent = Proc.DataF.data_sent();
|
||||
prep_sent += Proc.share_thread.DataF.data_sent();
|
||||
prep_sent += Proc.Procp.bit_prep.data_sent();
|
||||
auto prep_stats = Proc.DataF.comm_stats();
|
||||
prep_stats += Proc.share_thread.DataF.comm_stats();
|
||||
prep_stats += Proc.Procp.bit_prep.comm_stats();
|
||||
for (auto& x : Proc.Procp.personal_bit_preps)
|
||||
prep_sent += x->data_sent();
|
||||
prep_stats += x->comm_stats();
|
||||
machine.stats += Proc.stats;
|
||||
delete processor;
|
||||
|
||||
machine.data_sent += P.sent + prep_sent;
|
||||
machine.comm_stats += P.comm_stats;
|
||||
machine.comm_stats += P.comm_stats + prep_stats;
|
||||
queues->finished(actual_usage);
|
||||
|
||||
delete MC2;
|
||||
|
||||
@@ -20,7 +20,6 @@ protected:
|
||||
int lg2, opening_sum, max_broadcast;
|
||||
|
||||
Names playerNames;
|
||||
Server* server;
|
||||
|
||||
bool use_encryption, receive_threads;
|
||||
|
||||
@@ -38,7 +37,7 @@ public:
|
||||
template<class T, class U>
|
||||
int run();
|
||||
|
||||
Player* new_player(int id_base = 0);
|
||||
Player* new_player(const string& id_base);
|
||||
};
|
||||
|
||||
class DishonestMajorityMachine : public OnlineMachine
|
||||
|
||||
@@ -28,7 +28,7 @@ template<class V>
|
||||
OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& opt,
|
||||
OnlineOptions& online_opts, int nplayers, V) :
|
||||
argc(argc), argv(argv), online_opts(online_opts), lg2(0),
|
||||
opening_sum(0), max_broadcast(0), server(0),
|
||||
opening_sum(0), max_broadcast(0),
|
||||
use_encryption(false), receive_threads(false),
|
||||
opt(opt), nplayers(nplayers)
|
||||
{
|
||||
@@ -192,7 +192,6 @@ void OnlineMachine::start_networking()
|
||||
int mynum = online_opts.playerno;
|
||||
int playerno = online_opts.playerno;
|
||||
|
||||
server = 0;
|
||||
if (ipFileName.size() > 0) {
|
||||
if (my_port != Names::DEFAULT_PORT)
|
||||
throw runtime_error("cannot set port number when using IP file");
|
||||
@@ -204,7 +203,7 @@ void OnlineMachine::start_networking()
|
||||
{
|
||||
if (nplayers == 0)
|
||||
opt.get("-N")->getInt(nplayers);
|
||||
server = Server::start_networking(playerNames, mynum, nplayers,
|
||||
Server::start_networking(playerNames, mynum, nplayers,
|
||||
hostname, pnbase, my_port);
|
||||
}
|
||||
else
|
||||
@@ -216,7 +215,7 @@ void OnlineMachine::start_networking()
|
||||
}
|
||||
|
||||
inline
|
||||
Player* OnlineMachine::new_player(int id_base)
|
||||
Player* OnlineMachine::new_player(const string& id_base)
|
||||
{
|
||||
if (use_encryption)
|
||||
return new CryptoPlayer(playerNames, id_base);
|
||||
@@ -238,15 +237,13 @@ int OnlineMachine::run()
|
||||
use_encryption, online_opts.live_prep,
|
||||
online_opts).run();
|
||||
|
||||
if (server)
|
||||
delete server;
|
||||
|
||||
#ifdef VERBOSE
|
||||
cerr << "Command line:";
|
||||
for (int i = 0; i < argc; i++)
|
||||
cerr << " " << argv[i];
|
||||
cerr << endl;
|
||||
#endif
|
||||
if (online_opts.verbose)
|
||||
{
|
||||
cerr << "Command line:";
|
||||
for (int i = 0; i < argc; i++)
|
||||
cerr << " " << argv[i];
|
||||
cerr << endl;
|
||||
}
|
||||
}
|
||||
#ifndef INSECURE
|
||||
catch(...)
|
||||
|
||||
@@ -7,12 +7,14 @@
|
||||
#include "BaseMachine.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gfpvar.h"
|
||||
#include "Protocols/HemiOptions.h"
|
||||
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
using namespace std;
|
||||
|
||||
OnlineOptions OnlineOptions::singleton;
|
||||
HemiOptions HemiOptions::singleton;
|
||||
|
||||
OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
{
|
||||
@@ -26,6 +28,11 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
bucket_size = 4;
|
||||
cmd_private_input_file = "Player-Data/Input";
|
||||
cmd_private_output_file = "";
|
||||
#ifdef VERBOSE
|
||||
verbose = true;
|
||||
#else
|
||||
verbose = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
@@ -65,7 +72,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"Prefix for output file path "
|
||||
"(default: output to stdout for party 0 (silent otherwise "
|
||||
"unless interactive mode is active). "
|
||||
"Output will be written to {prefix}-P{id}-{thread_id}.", // Help description.
|
||||
"Output will be written to {prefix}-P{id}-{thread_id}. "
|
||||
"Use '.' for stdout on all parties.", // Help description.
|
||||
"-OF", // Flag token.
|
||||
"--output-file" // Flag token.
|
||||
);
|
||||
@@ -171,6 +179,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"-B", // Flag token.
|
||||
"--bucket-size" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Verbose output", // Help description.
|
||||
"-v", // Flag token.
|
||||
"--verbose" // Flag token.
|
||||
);
|
||||
|
||||
opt.parse(argc, argv);
|
||||
|
||||
@@ -198,6 +215,10 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
|
||||
opt.get("--bucket-size")->getInt(bucket_size);
|
||||
|
||||
#ifndef VERBOSE
|
||||
verbose = opt.isSet("--verbose");
|
||||
#endif
|
||||
|
||||
opt.resetArgs();
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ public:
|
||||
int bucket_size;
|
||||
std::string cmd_private_input_file;
|
||||
std::string cmd_private_output_file;
|
||||
bool verbose;
|
||||
|
||||
OnlineOptions();
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
|
||||
@@ -227,13 +227,15 @@ class Processor : public ArithmeticProcessor
|
||||
void split(const Instruction& instruction);
|
||||
|
||||
// Access to external client sockets for reading clear/shared data
|
||||
void read_socket_ints(int client_id, const vector<int>& registers);
|
||||
|
||||
void write_socket(const RegType reg_type,
|
||||
int socket_id, int message_type, const vector<int>& registers);
|
||||
void read_socket_ints(int client_id, const vector<int>& registers, int size);
|
||||
|
||||
void read_socket_vector(int client_id, const vector<int>& registers);
|
||||
void read_socket_private(int client_id, const vector<int>& registers, bool send_macs);
|
||||
void write_socket(const RegType reg_type, int socket_id, int message_type,
|
||||
const vector<int>& registers, int size);
|
||||
|
||||
void read_socket_vector(int client_id, const vector<int>& registers,
|
||||
int size);
|
||||
void read_socket_private(int client_id, const vector<int>& registers,
|
||||
int size, bool send_macs);
|
||||
|
||||
// Read and write secret numeric data to file (name hardcoded at present)
|
||||
void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector<int>& data_registers);
|
||||
|
||||
@@ -95,10 +95,13 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
shared_prng.SeedGlobally(P, false);
|
||||
|
||||
// only output on party 0 if not interactive
|
||||
bool output = P.my_num() == 0 or machine.opts.interactive;
|
||||
bool always_stdout = machine.opts.cmd_private_output_file == ".";
|
||||
bool output = P.my_num() == 0 or machine.opts.interactive or always_stdout;
|
||||
out.activate(output);
|
||||
Procb.out.activate(output);
|
||||
setup_redirection(P.my_num(), thread_num, opts);
|
||||
|
||||
if (not always_stdout)
|
||||
setup_redirection(P.my_num(), thread_num, opts);
|
||||
|
||||
if (stdout_redirect_file.is_open())
|
||||
{
|
||||
@@ -236,7 +239,8 @@ void Processor<sint, sgf2n>::convcbit2s(const Instruction& instruction)
|
||||
for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++)
|
||||
Procb.S[instruction.get_r(0) + i] = sint::bit_type::constant(
|
||||
Procb.C[instruction.get_r(1) + i], P.my_num(),
|
||||
share_thread.MC->get_alphai());
|
||||
share_thread.MC->get_alphai(),
|
||||
min(unsigned(unit), instruction.get_n() - i * unit));
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -262,8 +266,8 @@ void Processor<sint, sgf2n>::split(const Instruction& instruction)
|
||||
// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to
|
||||
// determine the data structure being sent in a message.
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
|
||||
int socket_id, int message_type, const vector<int>& registers)
|
||||
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, int socket_id,
|
||||
int message_type, const vector<int>& registers, int size)
|
||||
{
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
@@ -273,28 +277,40 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
|
||||
socket_stream.store(message_type);
|
||||
}
|
||||
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
if (reg_type == SINT) {
|
||||
// Send vector of secret shares
|
||||
get_Sp_ref(registers[i]).pack(socket_stream,
|
||||
sint::get_rec_factor(P.my_num(), P.num_players()));
|
||||
for (int j = 0; j < size; j++)
|
||||
{
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
if (reg_type == SINT)
|
||||
{
|
||||
// Send vector of secret shares
|
||||
get_Sp_ref(registers[i] + j).pack(socket_stream,
|
||||
sint::get_rec_factor(P.my_num(), P.num_players()));
|
||||
}
|
||||
else if (reg_type == CINT)
|
||||
{
|
||||
// Send vector of clear public field elements
|
||||
get_Cp_ref(registers[i] + j).pack(socket_stream);
|
||||
}
|
||||
else if (reg_type == INT)
|
||||
{
|
||||
// Send vector of 32-bit clear ints
|
||||
socket_stream.store((int&) get_Ci_ref(registers[i] + j));
|
||||
}
|
||||
else
|
||||
{
|
||||
stringstream ss;
|
||||
ss << "Write socket instruction with unknown reg type "
|
||||
<< reg_type << "." << endl;
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (reg_type == CINT) {
|
||||
// Send vector of clear public field elements
|
||||
get_Cp_ref(registers[i]).pack(socket_stream);
|
||||
}
|
||||
else if (reg_type == INT) {
|
||||
// Send vector of 32-bit clear ints
|
||||
socket_stream.store((int&)get_Ci_ref(registers[i]));
|
||||
}
|
||||
else {
|
||||
stringstream ss;
|
||||
ss << "Write socket instruction with unknown reg type " << reg_type <<
|
||||
"." << endl;
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "send " << socket_stream.get_length() << " to client " << socket_id
|
||||
<< endl;
|
||||
#endif
|
||||
|
||||
try {
|
||||
socket_stream.Send(external_clients.get_socket(socket_id));
|
||||
@@ -308,44 +324,47 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
|
||||
|
||||
// Receive vector of 32-bit clear ints
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>& registers)
|
||||
void Processor<sint, sgf2n>::read_socket_ints(int client_id,
|
||||
const vector<int>& registers, int size)
|
||||
{
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.Receive(external_clients.get_socket(client_id));
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
int val;
|
||||
socket_stream.get(val);
|
||||
write_Ci(registers[i], (long)val);
|
||||
}
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
int val;
|
||||
socket_stream.get(val);
|
||||
write_Ci(registers[i] + j, (long) val);
|
||||
}
|
||||
}
|
||||
|
||||
// Receive vector of public field elements
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>& registers)
|
||||
void Processor<sint, sgf2n>::read_socket_vector(int client_id,
|
||||
const vector<int>& registers, int size)
|
||||
{
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.Receive(external_clients.get_socket(client_id));
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
get_Cp_ref(registers[i]) = socket_stream.get<typename sint::open_type>();
|
||||
}
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int i = 0; i < m; i++)
|
||||
get_Cp_ref(registers[i] + j) =
|
||||
socket_stream.get<typename sint::open_type>();
|
||||
}
|
||||
|
||||
// Receive vector of field element shares over private channel
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::read_socket_private(int client_id, const vector<int>& registers, bool read_macs)
|
||||
void Processor<sint, sgf2n>::read_socket_private(int client_id,
|
||||
const vector<int>& registers, int size, bool read_macs)
|
||||
{
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.Receive(external_clients.get_socket(client_id));
|
||||
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
get_Sp_ref(registers[i]).unpack(socket_stream, read_macs);
|
||||
}
|
||||
for (int j = 0; j < size; j++)
|
||||
for (int i = 0; i < m; i++)
|
||||
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
|
||||
}
|
||||
|
||||
|
||||
@@ -382,11 +401,7 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
|
||||
// Append share data in data_registers to end of file. Expects Persistence directory to exist.
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::write_shares_to_file(const vector<int>& data_registers) {
|
||||
string dir = "Persistence";
|
||||
mkdir_p(dir.c_str());
|
||||
|
||||
string filename;
|
||||
filename = dir + "/Transactions-P" + to_string(P.my_num()) + ".data";
|
||||
string filename = binary_file_io.filename(P.my_num());
|
||||
|
||||
unsigned int size = data_registers.size();
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ void Program::compute_constants()
|
||||
max_mem[reg_type] = max(max_mem[reg_type],
|
||||
p[i].get_mem(RegType(reg_type)));
|
||||
}
|
||||
writes_persistance |= p[i].opcode == WRITEFILESHARE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +42,8 @@ void Program::parse(istream& s)
|
||||
s.peek();
|
||||
while (!s.eof())
|
||||
{ instr.parse(s, p.size());
|
||||
if (s.fail())
|
||||
throw runtime_error("error while parsing " + to_string(instr.opcode));
|
||||
p.push_back(instr);
|
||||
//cerr << "\t" << instr << endl;
|
||||
s.peek();
|
||||
|
||||
@@ -30,8 +30,10 @@ class Program
|
||||
|
||||
public:
|
||||
|
||||
bool writes_persistance;
|
||||
|
||||
Program(int nplayers) : offline_data_used(nplayers),
|
||||
unknown_usage(false)
|
||||
unknown_usage(false), writes_persistance(false)
|
||||
{ compute_constants(); }
|
||||
|
||||
// Read in a program
|
||||
|
||||
@@ -74,6 +74,7 @@ HonestMajorityRingMachineWithSecurity<U, V>::HonestMajorityRingMachineWithSecuri
|
||||
Y(K, 40) \
|
||||
default: \
|
||||
cerr << "not compiled for security parameter " << to_string(opts.S) << endl; \
|
||||
cerr << "add 'Y(K, " << opts.S << ")' to " __FILE__ ", line 76" << endl; \
|
||||
exit(1); \
|
||||
} \
|
||||
break;
|
||||
|
||||
@@ -23,6 +23,8 @@ enum ThreadJobType
|
||||
TRIPLE_SACRIFICE_JOB,
|
||||
CHECK_JOB,
|
||||
FFT_JOB,
|
||||
CIPHER_PLAIN_MULT_JOB,
|
||||
MATRX_RAND_MULT_JOB,
|
||||
NO_JOB
|
||||
};
|
||||
|
||||
|
||||
@@ -6,17 +6,21 @@
|
||||
#include "ThreadQueues.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
|
||||
int ThreadQueues::distribute(ThreadJob job, int n_items, int base,
|
||||
int granularity)
|
||||
{
|
||||
find_available();
|
||||
return distribute_no_setup(job, n_items, base, granularity);
|
||||
if (find_available() > 0)
|
||||
return distribute_no_setup(job, n_items, base, granularity);
|
||||
else
|
||||
return base;
|
||||
}
|
||||
|
||||
int ThreadQueues::find_available()
|
||||
{
|
||||
available.clear();
|
||||
if (not available.empty())
|
||||
return 0;
|
||||
for (size_t i = 1; i < size(); i++)
|
||||
if (at(i)->available())
|
||||
available.push_back(i);
|
||||
@@ -28,7 +32,7 @@ int ThreadQueues::find_available()
|
||||
|
||||
int ThreadQueues::get_n_per_thread(int n_items, int granularity)
|
||||
{
|
||||
int n_per_thread = n_items / (available.size() + 1) / granularity
|
||||
int n_per_thread = ceil(n_items / (available.size() + 1.0)) / granularity
|
||||
* granularity;
|
||||
return n_per_thread;
|
||||
}
|
||||
@@ -39,6 +43,11 @@ int ThreadQueues::distribute_no_setup(ThreadJob job, int n_items, int base,
|
||||
int n_per_thread = get_n_per_thread(n_items, granularity);
|
||||
for (size_t i = 0; i < available.size(); i++)
|
||||
{
|
||||
if (base + (i + 1) * n_per_thread > size_t(n_items))
|
||||
{
|
||||
available.resize(i);
|
||||
return base + i * n_per_thread;
|
||||
}
|
||||
if (supplies)
|
||||
job.supply = supplies->at(i);
|
||||
job.begin = base + i * n_per_thread;
|
||||
@@ -52,4 +61,5 @@ void ThreadQueues::wrap_up(ThreadJob job)
|
||||
{
|
||||
for (int i : available)
|
||||
assert(at(i)->result().output == job.output);
|
||||
available.clear();
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@
|
||||
*dest++ = *op1++ * *op2++) \
|
||||
X(DIVINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
|
||||
auto op2 = &Proc.get_Ci()[r[2]], \
|
||||
*dest++ = *op1++ / *op2++) \
|
||||
if (*op2 == 0) throw division_by_zero(); *dest++ = *op1++ / *op2++) \
|
||||
X(INCINT, auto dest = &Proc.get_Ci()[r[0]]; auto base = Proc.get_Ci()[r[1]], \
|
||||
int inc = (i / start[0]) % start[1]; *dest++ = base + inc * int(n)) \
|
||||
X(EQZC, auto dest = &Ci[r[0]]; auto source = &Ci[r[1]], *dest++ = *source++ == 0) \
|
||||
|
||||
@@ -67,13 +67,7 @@ def write_winner_to_clients(sockets, number_clients, winning_client_id):
|
||||
|
||||
# Setup authenticate result using share of random.
|
||||
# client can validate ∑ winning_client_id * ∑ rnd_from_triple = ∑ auth_result
|
||||
rnd_from_triple = sint.get_random_triple()[0]
|
||||
auth_result = winning_client_id * rnd_from_triple
|
||||
|
||||
@for_range(number_clients)
|
||||
def loop_body(i):
|
||||
sint.write_shares_to_socket(sockets[i], [winning_client_id, rnd_from_triple, auth_result])
|
||||
|
||||
sint.reveal_to_clients(sockets.get_sub(number_clients), [winning_client_id])
|
||||
|
||||
def main():
|
||||
"""Listen in while loop for players to join a game.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
prog = program
|
||||
|
||||
prog.options.binary = -1
|
||||
|
||||
from Compiler.GC.types import *
|
||||
from Compiler.GC.instructions import *
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from Compiler import ml
|
||||
|
||||
debug = False
|
||||
|
||||
program.use_edabit(True)
|
||||
program.options_from_args()
|
||||
|
||||
sfix.set_precision(16, 31)
|
||||
@@ -12,30 +11,39 @@ dim = int(program.args[1])
|
||||
batch = int(program.args[2])
|
||||
|
||||
try:
|
||||
ml.set_n_threads(int(program.args[3]))
|
||||
n_iterations = int(program.args[3])
|
||||
except:
|
||||
n_iterations = 1
|
||||
|
||||
try:
|
||||
ml.set_n_threads(int(program.args[4]))
|
||||
except:
|
||||
ml.set_n_threads(None)
|
||||
|
||||
X_normal = sfix.Matrix(6400, dim)
|
||||
X_pos = sfix.Matrix(6400, dim)
|
||||
N = batch * n_iterations
|
||||
|
||||
dense = ml.Dense(12800, dim, 1)
|
||||
layers = [dense, ml.Output(12800, debug=debug, approx='approx' in program.args)]
|
||||
print('run 1 epoch of logistic regression with %d examples' % (N))
|
||||
|
||||
sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False)
|
||||
dense = ml.Dense(N, dim, 1)
|
||||
sigmoid = ml.Output(N, debug=debug, approx='approx' in program.args)
|
||||
|
||||
for x in dense.X, sigmoid.Y:
|
||||
x.assign_all(0)
|
||||
|
||||
sgd = ml.SGD([dense, sigmoid], 1, debug=debug, report_loss=False)
|
||||
sgd.reset()
|
||||
|
||||
if not ('forward' in program.args or 'backward' in program.args):
|
||||
sgd.reset([X_normal, X_pos])
|
||||
sgd.run(batch_size=batch)
|
||||
|
||||
if 'forward' in program.args:
|
||||
@for_range(1000)
|
||||
@for_range(n_iterations)
|
||||
def _(i):
|
||||
sgd.forward(N=batch)
|
||||
|
||||
if 'backward' in program.args:
|
||||
b = regint.Array(batch)
|
||||
b.assign(regint.inc(batch))
|
||||
@for_range(1000)
|
||||
@for_range(n_iterations)
|
||||
def _(i):
|
||||
sgd.backward(batch=b)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user