Optimized matrix multiplication in Hemi.

This commit is contained in:
Marcel Keller
2021-09-17 14:29:28 +10:00
parent 5c6f101c12
commit 799929b801
151 changed files with 5262 additions and 748 deletions

View File

@@ -79,7 +79,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
auto& MC = this->MC;
this->_id = online_opts.playerno + 1;
Server* server = Server::start_networking(N, online_opts.playerno, nparties,
Server::start_networking(N, online_opts.playerno, nparties,
network_opts.hostname, network_opts.portnum_base);
if (T::dishonest_majority)
P = new PlainPlayer(N, 0);
@@ -159,8 +159,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
MC->Check(*P);
data_sent = P->comm_stats.total_data() + prep->data_sent();
if (server)
delete server;
this->machine.write_memory(this->N.my_num());
}
template<class T>

View File

@@ -25,6 +25,9 @@ static void throw_bad_ip(const char* ip) {
throw std::invalid_argument( "bad ip" );
}
namespace BIU
{
Client::Client(endpoint_t* endpoints, int numservers, ClientUpdatable* updatable, unsigned int max_message_size)
:_max_msg_sz(max_message_size),
_numservers(numservers),
@@ -205,3 +208,5 @@ void Client::_send_blocking(SendBuffer& msg, int id) {
fflush(0);
#endif
}
}

View File

@@ -28,6 +28,8 @@ public:
namespace BIU
{
class Client {
public:
@@ -61,4 +63,6 @@ private:
boost::thread_group threads;
};
}
#endif /* NETWORK_INC_CLIENT_H_ */

View File

@@ -35,7 +35,7 @@ Node::Node(const char* netmap_file, int my_id, NodeUpdatable* updatable, int num
_ready_nodes = new bool[_numparties](); //initialized to false
_clients_connected = new bool[_numparties]();
_server = new BIU::Server(_port, _numparties-1, this, max_message_size);
_client = new Client(_endpoints, _numparties-1, this, max_message_size);
_client = new BIU::Client(_endpoints, _numparties-1, this, max_message_size);
}
Node::~Node() {

View File

@@ -69,7 +69,7 @@ private:
int _numparties;
endpoint_t* _endpoints;
Client* _client;
BIU::Client* _client;
BIU::Server* _server;
bool* _ready_nodes;
volatile bool _connected_to_servers;

View File

@@ -1,5 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.7 (Sep 17, 2021)
- Optimized matrix multiplication in Hemi
- Improved client communication
- Private integer division as per `Veugen and Abspoel
<https://doi.org/10.2478/popets-2021-0073>`
- Compiler option to translate some Python control flow instructions
to run-time instructions
- Functionality to break out of run-time loops
- Run-time range check of data structure accesses
- Improved documentation of network infrastructure
## 0.2.6 (Aug 6, 2021)
- [ATLAS](https://eprint.iacr.org/2021/833)

View File

@@ -52,6 +52,7 @@ opcodes = dict(
CONVCBIT2S = 0x249,
XORCBI = 0x210,
BITDECC = 0x211,
NOTCB = 0x212,
CONVCINT = 0x213,
REVEAL = 0x214,
STMSDCI = 0x215,
@@ -190,6 +191,16 @@ class nots(BinaryVectorInstruction):
code = opcodes['NOTS']
arg_format = ['int','sbw','sb']
class notcb(BinaryVectorInstruction):
""" Bitwise NOT of secret register vector.
:param: number of bits
:param: result (cbit)
:param: operand (cbit)
"""
code = opcodes['NOTCB']
arg_format = ['int','cbw','cb']
class addcb(NonVectorInstruction):
""" Integer addition two single clear bit registers.
@@ -617,4 +628,4 @@ class cond_print_strb(base.IOInstruction):
arg_format = ['cb', 'int']
def __init__(self, cond, val):
super(cond_print_str, self).__init__(cond, self.str_to_int(val))
super(cond_print_strb, self).__init__(cond, self.str_to_int(val))

View File

@@ -169,6 +169,33 @@ class bits(Tape.Register, _structure, _bit):
(str(other), repr(other), type(other), type(self)))
def long_one(self):
return 2**self.n - 1 if self.n != None else None
def is_long_one(self, other):
return util.is_all_ones(other, self.n) or \
(other is None and self.n == None)
def res_type(self, other):
if self.n == None and other.n == None:
n = None
else:
n = max(self.n, other.n)
return self.get_type(n)
@read_mem_value
def __and__(self, other):
if util.is_zero(other):
return 0
elif self.is_long_one(other):
return self
else:
return self._and(other)
@read_mem_value
def __xor__(self, other):
if util.is_zero(other):
return self
elif self.is_long_one(other):
return ~self
else:
return self._xor(other)
__rand__ = __and__
__rxor__ = __xor__
def __repr__(self):
if self.n != None:
suffix = '%d' % self.n
@@ -245,19 +272,20 @@ class cbits(bits):
self.clear_op(other, inst.addcb, inst.addcbi, operator.add)
__sub__ = lambda self, other: \
self.clear_op(-other, inst.addcb, inst.addcbi, operator.add)
def __xor__(self, other):
def _xor(self, other):
if isinstance(other, (sbits, sbitvec)):
return NotImplemented
elif isinstance(other, cbits):
res = cbits.get_type(max(self.n, other.n))()
res = self.res_type(other)()
assert res.size == self.size
assert res.size == other.size
inst.xorcb(res.n, res, self, other)
return res
else:
return self.clear_op(other, None, inst.xorcbi, operator.xor)
def _and(self, other):
return NotImplemented
__radd__ = __add__
__rxor__ = __xor__
def __mul__(self, other):
if isinstance(other, cbits):
return NotImplemented
@@ -278,7 +306,9 @@ class cbits(bits):
inst.shlcbi(res, self, other)
return res
def __invert__(self):
return self ^ self.long_one()
res = type(self)()
inst.notcb(self.n, res, self)
return res
def print_reg(self, desc=''):
inst.print_regb(self, desc)
def print_reg_plain(self):
@@ -287,6 +317,8 @@ class cbits(bits):
def print_if(self, string):
inst.cond_print_strb(self, string)
def output_if(self, cond):
if Program.prog.options.binary:
raise CompilerError('conditional output not supported')
cint(self).output_if(cond)
def reveal(self):
return self
@@ -423,8 +455,7 @@ class sbits(bits):
__radd__ = __add__
__sub__ = __add__
__rsub__ = __add__
__xor__ = __add__
__rxor__ = __add__
_xor = __add__
@read_mem_value
def __mul__(self, other):
if isinstance(other, int):
@@ -440,13 +471,7 @@ class sbits(bits):
except AttributeError:
return NotImplemented
__rmul__ = __mul__
@read_mem_value
def __and__(self, other):
if util.is_zero(other):
return 0
elif util.is_all_ones(other, self.n) or \
(other is None and self.n == None):
return self
def _and(self, other):
res = self.new(n=self.n)
if not isinstance(other, sbits):
other = cbits.get_type(self.n).conv(other)
@@ -456,7 +481,6 @@ class sbits(bits):
assert(self.n == other.n)
inst.ands(self.n, res, self, other)
return res
__rand__ = __and__
def xor_int(self, other):
if other == 0:
return self
@@ -551,12 +575,6 @@ class sbits(bits):
@staticmethod
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
def to_sint(self, n_bits):
""" Convert the :py:obj:`n_bits` least significant bits to
:py:obj:`~Compiler.types.sint`. """
bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0]
bits = sint(bits, size=n_bits)
return sint.bit_compose(bits)
class sbitvec(_vec):
""" Vector of registers of secret bits, effectively a matrix of secret bits.

View File

@@ -591,7 +591,7 @@ class RegintOptimizer:
elif isinstance(inst, IndirectMemoryInstruction):
if inst.args[1] in self.cache:
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
elif isinstance(inst, convint_class):
elif type(inst) == convint_class:
if inst.args[1] in self.cache:
res = self.cache[inst.args[1]]
self.cache[inst.args[0]] = res

View File

@@ -2,6 +2,7 @@ from Compiler.program import Program
from .GC import types as GC_types
import sys
import re, tempfile, os
def run(args, options):
@@ -21,11 +22,62 @@ def run(args, options):
del VARS[i]
print('Compiling file', prog.infile)
f = open(prog.infile, 'rb')
changed = False
if options.flow_optimization:
output = []
if_stack = []
for line in open(prog.infile):
if if_stack and not re.match(if_stack[-1][0], line):
if_stack.pop()
m = re.match(
'(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):',
line)
if m:
output.append('%s@for_range_opt(%s)\n' % (m.group(1),
m.group(3)))
output.append('%sdef _(%s):\n' % (m.group(1), m.group(2)))
changed = True
continue
m = re.match('(\s*)if(\W.*):', line)
if m:
if_stack.append((m.group(1), len(output)))
output.append('%s@if_(%s)\n' % (m.group(1), m.group(2)))
output.append('%sdef _():\n' % (m.group(1)))
changed = True
continue
m = re.match('(\s*)elif\s+', line)
if m:
raise CompilerError('elif not supported')
if if_stack:
m = re.match('%selse:' % if_stack[-1][0], line)
if m:
start = if_stack[-1][1]
ws = if_stack[-1][0]
output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws,
output[start])
output.append('%s@else_\n' % ws)
output.append('%sdef _():\n' % ws)
continue
output.append(line)
if changed:
infile = tempfile.NamedTemporaryFile('w+', delete=False)
for line in output:
infile.write(line)
infile.seek(0)
else:
infile = open(prog.infile)
else:
infile = open(prog.infile)
# make compiler modules directly accessible
sys.path.insert(0, 'Compiler')
# create the tapes
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
exec(compile(infile.read(), infile.name, 'exec'), VARS)
if changed and not options.debug:
os.unlink(infile.name)
prog.finalize()

View File

@@ -562,7 +562,7 @@ def SDiv(a, b, l, kappa, round_nearest=False):
y = a * w
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
x2 = types.sint()
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, True)
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
for i in range(theta-1):
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
@@ -642,7 +642,7 @@ def BitDecFull(a, maybe_mixed=False):
b, bbits = sint.get_edabit(logp, True, size=a.size)
if logp != bit_length:
from .GC.types import sbits
bbits += [sbits.get_type(a.size)(0)]
bbits += [0]
else:
bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
b = sint.bit_compose(bbits)

View File

@@ -409,6 +409,18 @@ class use_edabit(base.Instruction):
code = base.opcodes['USE_EDABIT']
arg_format = ['int','int','int']
class use_matmul(base.Instruction):
""" Matrix multiplication usage. Used for multithreading of
preprocessing.
:param: number of left-hand rows (int)
:param: number of left-hand columns/right-hand rows (int)
:param: number of right-hand columns (int)
:param: number (int, -1 for unknown)
"""
code = base.opcodes['USE_MATMUL']
arg_format = ['int','int','int','int']
class run_tape(base.Instruction):
""" Start tape/bytecode file in another thread.
@@ -432,7 +444,7 @@ class join_tape(base.Instruction):
class crash(base.IOInstruction):
""" Crash runtime. """
code = base.opcodes['CRASH']
arg_format = []
arg_format = ['ci']
class start_grind(base.IOInstruction):
code = base.opcodes['STARTGRIND']
@@ -1559,51 +1571,49 @@ class pubinput(base.PublicFileIOInstruction):
code = base.opcodes['PUBINPUT']
arg_format = ['cw']
@base.vectorize
class readsocketc(base.IOInstruction):
""" Read a variable number of clear values in internal representation
from socket for a specified client id and store them in clear registers.
:param: number of arguments to follow / number of inputs minus one (int)
:param: client id (regint)
:param: vector size (int)
:param: destination (cint)
:param: (repeat destination)...
"""
__slots__ = []
code = base.opcodes['READSOCKETC']
arg_format = tools.chain(['ci'], itertools.repeat('cw'))
arg_format = tools.chain(['ci','int'], itertools.repeat('cw'))
def has_var_args(self):
return True
@base.vectorize
class readsockets(base.IOInstruction):
"""Read a variable number of secret shares + MACs from socket for a client id and store in registers"""
__slots__ = []
code = base.opcodes['READSOCKETS']
arg_format = tools.chain(['ci'], itertools.repeat('sw'))
arg_format = tools.chain(['ci','int'], itertools.repeat('sw'))
def has_var_args(self):
return True
@base.vectorize
class readsocketint(base.IOInstruction):
""" Read a variable number of 32-bit integers from socket for a
specified client id and store them in clear integer registers.
:param: number of arguments to follow / number of inputs minus one (int)
:param: client id (regint)
:param: vector size (int)
:param: destination (regint)
:param: (repeat destination)...
"""
__slots__ = []
code = base.opcodes['READSOCKETINT']
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
arg_format = tools.chain(['ci','int'], itertools.repeat('ciw'))
def has_var_args(self):
return True
@base.vectorize
class writesocketc(base.IOInstruction):
"""
Write a variable number of clear GF(p) values from registers into socket
@@ -1611,29 +1621,28 @@ class writesocketc(base.IOInstruction):
"""
__slots__ = []
code = base.opcodes['WRITESOCKETC']
arg_format = tools.chain(['ci', 'int'], itertools.repeat('c'))
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('c'))
def has_var_args(self):
return True
@base.vectorize
class writesocketshare(base.IOInstruction):
""" Write a variable number of shares (without MACs) from secret
registers into socket for a specified client id.
:param: client id (regint)
:param: message type (must be 0)
:param: vector size (int)
:param: source (sint)
:param: (repeat source)...
"""
__slots__ = []
code = base.opcodes['WRITESOCKETSHARE']
arg_format = tools.chain(['ci', 'int'], itertools.repeat('s'))
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s'))
def has_var_args(self):
return True
@base.vectorize
class writesocketint(base.IOInstruction):
"""
Write a variable number of 32-bit ints from registers into socket
@@ -1641,7 +1650,7 @@ class writesocketint(base.IOInstruction):
"""
__slots__ = []
code = base.opcodes['WRITESOCKETINT']
arg_format = tools.chain(['ci', 'int'], itertools.repeat('ci'))
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('ci'))
def has_var_args(self):
return True
@@ -2266,6 +2275,10 @@ class matmulsm(matmul_base):
for i in range(2):
assert args[8 + i].size == args[4 + i]
def add_usage(self, req_node):
super(matmulsm, self).add_usage(req_node)
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
class conv2ds(base.DataInstruction):
""" Secret 2D convolution.
@@ -2301,6 +2314,12 @@ class conv2ds(base.DataInstruction):
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
self.args[11] * self.args[14]
def add_usage(self, req_node):
super(conv2ds, self).add_usage(req_node)
args = self.args
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)
@base.vectorize
class trunc_pr(base.VarArgsInstruction):
""" Probabilistic truncation if supported by the protocol.

View File

@@ -63,6 +63,7 @@ opcodes = dict(
THRESHOLD = 0xE3,
PLAYERID = 0xE4,
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -456,7 +457,7 @@ def cisc(function):
reset_global_vector_size()
program.curr_tape = old_tape
for x, bl in tape.req_bit_length.items():
old_tape.require_bit_length(bl, x)
old_tape.require_bit_length(bl - 1, x)
from Compiler.allocator import Merger
merger = Merger(block, program.options,
tuple(program.to_merge))
@@ -542,7 +543,13 @@ def cisc(function):
MergeCISC.__name__ = function.__name__
def wrapper(*args, **kwargs):
if program.options.cisc:
same_sizes = True
for arg in args:
try:
same_sizes &= arg.size == args[0].size
except:
pass
if program.options.cisc and same_sizes:
return MergeCISC(*args, **kwargs)
else:
return function(*args, **kwargs)

View File

@@ -143,7 +143,6 @@ def print_str_if(cond, ss, *args):
assert len(subs) == len(args) + 1
if isinstance(cond, localint):
cond = cond._v
cond = cint.conv(cond)
for i, s in enumerate(subs):
if i != 0:
val = args[i - 1]
@@ -203,6 +202,27 @@ def runtime_error(msg='', *args):
print_ln(msg, *args)
crash()
def runtime_error_if(condition, msg='', *args):
""" Conditionally print an error message and abort the runtime.
:param condition: regint/cint/int/cbit
:param msg: message
:param args: list of public values to fit ``%s`` in the message
"""
print_ln_if(condition, msg, *args)
crash(condition)
def crash(condition=None):
""" Crash virtual machine.
:param condition: crash if true (default: true)
"""
if condition == None:
condition = regint(1)
instructions.crash(regint.conv(condition))
def public_input():
""" Public input read from ``Programs/Public-Input/<progname>``. """
res = cint()
@@ -1209,6 +1229,8 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
return loop_body(base + i)
prog = get_program()
thread_args = []
if prog.curr_tape == prog.tapes[0]:
prog.n_running_threads = n_threads
if not util.is_zero(thread_rounds):
tape = prog.new_tape(f, (0,), 'multithread')
for i in range(n_threads - remainder):
@@ -1225,6 +1247,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if len(mem_state):
args[i][1] = mem_state.address
thread_args.append((tape1, i))
prog.n_running_threads = None
threads = prog.run_tapes(thread_args)
for thread in threads:
prog.join_tape(thread)
@@ -1397,6 +1420,7 @@ def do_while(loop_fn, g=None):
# possibly unknown loop count
get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \
name='begin-loop')
get_tape().loop_breaks.append([])
loop_block = instructions.program.curr_block
condition = _run_and_link(loop_fn, g)
if callable(condition):
@@ -1404,8 +1428,15 @@ def do_while(loop_fn, g=None):
branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False)
instructions.program.curr_block.set_exit(branch, loop_block)
get_tape().close_scope(scope, parent_node, 'end-loop')
for loop_break in get_tape().loop_breaks.pop():
loop_break.set_exit(jmp(0, add_to_prog=False), get_block())
return loop_fn
def break_loop():
""" Break out of loop. """
get_tape().loop_breaks[-1].append(get_block())
break_point('break')
def if_then(condition):
class State: pass
state = State()
@@ -1483,10 +1514,18 @@ def if_(condition):
def _():
...
"""
try:
condition = bool(condition)
except:
pass
def decorator(body):
if_then(condition)
_run_and_link(body)
end_if()
if isinstance(condition, bool):
if condition:
_run_and_link(body)
else:
if_then(condition)
_run_and_link(body)
end_if()
return decorator
def if_e(condition):
@@ -1506,15 +1545,30 @@ def if_e(condition):
def _():
...
"""
try:
condition = bool(condition)
except:
pass
def decorator(body):
if_then(condition)
_run_and_link(body)
if isinstance(condition, bool):
get_tape().if_states.append(condition)
if condition:
_run_and_link(body)
else:
if_then(condition)
_run_and_link(body)
return decorator
def else_(body):
else_then()
_run_and_link(body)
end_if()
if_states = get_tape().if_states
if isinstance(if_states[-1], bool):
if not if_states[-1]:
_run_and_link(body)
if_states.pop()
else:
else_then()
_run_and_link(body)
end_if()
def and_(*terms):
res = regint(0)

View File

@@ -292,7 +292,8 @@ class Output(NoVariableLayer):
self.l.write(sum(lse) * \
self.divisor(N, 1))
def eval(self, size, base=0):
def eval(self, size, base=0, top=False):
assert not top
if self.approx:
return approx_sigmoid(self.X.get_vector(base, size), self.approx)
else:
@@ -474,8 +475,14 @@ class MultiOutput(MultiOutputBase):
self.true_X[i] = true_X
self.l.write(sum(tmp.get_vector(0, N)) / N)
def eval(self, N):
def eval(self, N, top=False):
d_out = self.X.sizes[1]
if top:
res = sint.Array(N)
@for_range_opt_multithread(self.n_threads, N)
def _(i):
res[i] = argmax(self.X[i])
return res
res = sfix.Matrix(N, d_out)
if self.approx:
@for_range_opt_multithread(self.n_threads, N)
@@ -1832,7 +1839,7 @@ class Optimizer:
@_no_mem_warnings
def forward(self, N=None, batch=None, keep_intermediate=True,
model_from=None, training=False):
model_from=None, training=False, run_last=True):
""" Compute graph.
:param N: batch size (used if batch not given)
@@ -1851,7 +1858,9 @@ class Optimizer:
break_point()
if self.time_layers:
start_timer(100 + i)
layer.forward(batch=self.batch_for(layer, batch), training=training)
if i != len(self.layers) - 1 or run_last:
layer.forward(batch=self.batch_for(layer, batch),
training=training)
if self.time_layers:
stop_timer(100 + i)
break_point()
@@ -1862,19 +1871,23 @@ class Optimizer:
theta.delete()
@_no_mem_warnings
def eval(self, data, batch_size=None):
def eval(self, data, batch_size=None, top=False):
""" Compute evaluation after training.
:param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample)
:param top: return top prediction instead of probability distribution
"""
if isinstance(self.layers[-1].Y, Array):
res = sfix.Array(len(data))
if isinstance(self.layers[-1].Y, Array) or top:
if top:
res = sint.Array(len(data))
else:
res = sfix.Array(len(data))
else:
res = sfix.Matrix(len(data), self.layers[-1].d_out)
def f(start, batch_size, batch):
batch.assign_vector(regint.inc(batch_size, start))
self.forward(batch=batch)
part = self.layers[-1].eval(batch_size)
self.forward(batch=batch, run_last=not top)
part = self.layers[-1].eval(batch_size, top=top)
res.assign_part_vector(part.get_vector(), start)
self.run_in_batches(f, data, batch_size or len(self.layers[1].X))
return res
@@ -2487,24 +2500,29 @@ class keras:
batch_size = min(batch_size, self.batch_size)
return self.opt.eval(x, batch_size=batch_size)
def solve_linear(A, b, n_iterations, debug=False):
def solve_linear(A, b, n_iterations, progress=False):
""" Iterative linear solution approximation. """
assert len(b) == A.sizes[0]
x = sfix.Array(A.sizes[1])
x.assign_vector(sfix.get_random(-1, 1, size=len(x)))
At = A.transpose()
AtA = sfix.Matrix(len(x), len(x))
AtA[:] = A.direct_trans_mul(A)
v = sfix.Array(A.sizes[1])
v.assign_all(0)
r = Array.create_from(A.transpose() * b - AtA * x)
Av = sfix.Array(len(x))
@for_range(n_iterations)
def _(i):
r = At * (b - A * x)
tmp = A * r
tmp = sfix.dot_product(tmp, tmp)
alpha = (tmp == 0).if_else(0, sfix.dot_product(r, r) / tmp)
x.assign(x + alpha * r)
if debug:
print_ln('%s r=%s tmp=%s r*r=%s tmp*tmp=%s alpha=%s x=%s alpha*r=%s', i,
list(r.reveal()), list(tmp.reveal()),
sfix.dot_product(r, r).reveal(), sfix.dot_product(tmp, tmp).reveal(),
alpha.reveal(), x.reveal_list(), list((alpha * r).reveal()))
v[:] = r - sfix.dot_product(r, Av) / sfix.dot_product(v, Av) * v
Av[:] = AtA * v
v_norm = sfix.dot_product(v, Av)
vr = sfix.dot_product(v, r)
alpha = (v_norm == 0).if_else(0, vr / v_norm)
x[:] = x + alpha * v
r[:] = r - alpha * Av
if progress:
print_ln('%s alpha=%s vr=%s v_norm=%s', i, alpha.reveal(),
vr.reveal(), v_norm.reveal())
return x
def mr(A, n_iterations):

View File

@@ -266,7 +266,7 @@ def tan(x):
@types.vectorize
@instructions_base.sfix_cisc
def exp2_fx(a, zero_output=False):
def exp2_fx(a, zero_output=False, as19=False):
"""
Power of two for fixed-point numbers.
@@ -287,7 +287,7 @@ def exp2_fx(a, zero_output=False):
g = a._new(whole_exp.TruncMul(e.v, 2 * a.k, n_shift,
nearest=a.round_nearest), a.k, a.f)
return g
if types.program.options.ring:
if types.program.options.ring and not as19:
sint = types.sint
intbitint = types.intbitint
# how many bits to use from integer part

View File

@@ -947,7 +947,7 @@ class RefBucket(object):
child.output()
def random_block(length, value_type):
return sum(value_type.bit_type.get_random_bit() << i for i in range(length))
return bit_compose(value_type.bit_type.get_random_bit() for i in range(length))
class List(EndRecursiveEviction):
""" Debugging only. List which accepts secret values as indices

View File

@@ -150,6 +150,8 @@ class Program(object):
self._linear_rounds = False
self.warn_about_mem = [True]
self.relevant_opts = set()
self.n_running_threads = None
self.input_files = {}
Program.prog = self
from . import instructions_base, instructions, types, comparison
instructions.program = self
@@ -370,8 +372,6 @@ class Program(object):
""" Allocate memory from the top """
if not isinstance(size, int):
raise CompilerError('size must be known at compile time')
if not (creator_tape or self.curr_tape).singular:
raise CompilerError('cannot allocate memory outside main thread')
if size == 0:
return
if isinstance(mem_type, type):
@@ -383,6 +383,14 @@ class Program(object):
mem_type = mem_type.reg_type
elif reg_type is not None:
self.types[mem_type] = reg_type
single_size = None
if not (creator_tape or self.curr_tape).singular:
if self.n_running_threads:
single_size = size
size *= self.n_running_threads
else:
raise CompilerError('cannot allocate memory '
'outside main thread')
blocks = self.free_mem_blocks[mem_type]
addr = blocks.pop(size)
if addr is not None:
@@ -393,7 +401,13 @@ class Program(object):
if len(str(addr)) != len(str(addr + size)) and self.verbose:
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
self.allocated_mem_blocks[addr,mem_type] = size
return addr
if single_size:
from .library import get_thread_number, runtime_error_if
tn = get_thread_number()
runtime_error_if(tn > self.n_running_threads, 'malloc')
return addr + single_size * (tn - 1)
else:
return addr
def free(self, addr, mem_type):
""" Free memory """
@@ -424,7 +438,8 @@ class Program(object):
from . import library
self.curr_tape.start_new_basicblock(None, 'memory-usage')
# reset register counter to 0
self.curr_tape.init_registers()
if not self.options.noreallocate:
self.curr_tape.init_registers()
for mem_type,size in sorted(self.allocated_mem.items()):
if size:
#print "Memory of type '%s' of size %d" % (mem_type, size)
@@ -586,6 +601,7 @@ class Tape:
self.functions = []
self.singular = True
self.free_threads = set()
self.loop_breaks = []
class BasicBlock(object):
def __init__(self, parent, name, scope, exit_condition=None):
@@ -827,8 +843,7 @@ class Tape:
block = left.popleft()
alloc(block)
for child in block.children:
if child.instructions:
left.append(child)
left.append(child)
for i,block in enumerate(reversed(self.basicblocks)):
if len(block.instructions) > 1000000:
print('Allocating %s, %d/%d' % \
@@ -878,6 +893,10 @@ class Tape:
self.basicblocks[-1].instructions.append(
Compiler.instructions.use_edabit(True, req[1], num, \
add_to_prog=False))
elif req[0] == 'matmul':
self.basicblocks[-1].instructions.append(
Compiler.instructions.use_matmul(*req[1], num, \
add_to_prog=False))
if not self.is_empty():
# bit length requirement
@@ -1012,6 +1031,9 @@ class Tape:
else:
eda = 'loose edabits'
res += ['%s %s of length %d' % (n, eda, req[1])]
elif domain == 'matmul':
res += ['%s matrix multiplications (%dx%d * %dx%d)' %
(n, req[1][0], req[1][1], req[1][1], req[1][2])]
elif req[0] != 'all':
res += ['%s %s %ss' % (n, domain, req[1])]
if self['all','round']:
@@ -1077,7 +1099,10 @@ class Tape:
def require_bit_length(self, bit_length, t='p'):
if t == 'p':
if self.program.prime:
assert bit_length < self.program.prime.bit_length() - 1
if (bit_length >= self.program.prime.bit_length() - 1):
raise CompilerError(
'required bit length %d too much for %d' % \
(bit_length, self.program.prime))
self.req_bit_length[t] = max(bit_length + 1, \
self.req_bit_length[t])
else:
@@ -1150,7 +1175,9 @@ class Tape:
def _new_by_number(self, i, size=1):
return Tape.Register(self.reg_type, self.program, size=size, i=i)
def get_vector(self, base, size):
def get_vector(self, base=0, size=None):
if size == None:
size = self.size
if base == 0 and size == self.size:
return self
if size == 1:
@@ -1206,7 +1233,8 @@ class Tape:
self.reg_type == RegType.ClearInt
def __bool__(self):
raise CompilerError('cannot derive truth value from register')
raise CompilerError('Cannot derive truth value from register, '
"consider using 'compile.py -l'")
def __str__(self):
return self.reg_type + str(self.i)

View File

@@ -508,7 +508,7 @@ class _structure(object):
a = sfix.Tensor([10, 10])
"""
if len(shape) == 1:
return Array(size[0], cls)
return Array(shape[0], cls)
else:
return MultiArray(shape, cls)
@@ -522,6 +522,75 @@ class _structure(object):
def mem_size():
return 1
class _secret_structure(_structure):
@classmethod
def input_tensor_from(cls, player, shape):
""" Input tensor secretly from player.
:param player: int/regint/cint
:param shape: tensor shape
"""
res = cls.Tensor(shape)
res.input_from(player)
return res
@classmethod
def input_tensor_from_client(cls, client_id, shape):
""" Input tensor secretly from client.
:param client_id: client identifier (public)
:param shape: tensor shape
"""
res = cls.Tensor(shape)
res.assign_vector(cls.receive_from_client(1, client_id,
size=res.total_size())[0])
return res
@classmethod
def input_tensor_via(cls, player, content):
"""
Input tensor-like data via a player. This overwrites the input
file for the relevant player. The following returns an
:py:class:`sint` matrix of dimension 2 by 2::
M = [[1, 2], [3, 4]]
sint.input_tensor_via(0, M)
Make sure to copy ``Player-Data/Input-P<player>-0`` if running
on another host.
"""
if program.curr_tape != program.tapes[0]:
raise CompilerError('only available in main thread')
shape = []
tmp = content
while True:
try:
shape.append(len(tmp))
tmp = tmp[0]
except:
break
if not program.input_files.get(player, None):
program.input_files[player] = open(
'Player-Data/Input-P%d-0' % player, 'w')
f = program.input_files[player]
def traverse(content, level):
assert len(content) == shape[level]
if level == len(shape) - 1:
for x in content:
f.write(' ')
f.write(str(x))
else:
for x in content:
traverse(x, level + 1)
traverse(content, 0)
f.write('\n')
res = cls.Tensor(shape)
res.input_from(player)
return res
class _vec(object):
def link(self, other):
assert len(self.v) == len(other.v)
@@ -835,23 +904,26 @@ class cint(_clear, _int):
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param size: vector size (default 1)
:returns: cint (if n=1) or list of cint
"""
res = [cls() for i in range(n)]
readsocketc(client_id, *res)
readsocketc(client_id, get_global_vector_size(), *res)
if n == 1:
return res[0]
else:
return res
@vectorized_classmethod
@classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of clear values to a client.
:param client_id: Client id (regint)
:param values: list of cint
"""
writesocketc(client_id, message_type, *values)
for value in values:
assert(value.size == values[0].size)
writesocketc(client_id, message_type, values[0].size, *values)
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
@@ -1089,7 +1161,7 @@ class cint(_clear, _int):
cond_print_str(self, string)
def output_if(self, cond):
cond_print_plain(cond, self, cint(0))
cond_print_plain(self.conv(cond), self, cint(0))
class cgf2n(_clear, _gf2n):
@@ -1298,23 +1370,26 @@ class regint(_register, _int):
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param size: vector size (default 1)
:returns: regint (if n=1) or list of regint
"""
res = [cls() for i in range(n)]
readsocketint(client_id, *res)
readsocketint(client_id, get_global_vector_size(), *res)
if n == 1:
return res[0]
else:
return res
@vectorized_classmethod
@classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of clear integers to a client.
:param client_id: Client id (regint)
:param values: list of regint
"""
writesocketint(client_id, message_type, *values)
for value in values:
assert(value.size == values[0].size)
writesocketint(client_id, message_type, values[0].size, *values)
@vectorize_init
def __init__(self, val=None, size=None):
@@ -1388,6 +1463,8 @@ class regint(_register, _int):
""" Clear integer division (rounding to floor).
:param other: regint/cint/int """
if util.is_constant(other) and other >= 2 ** 64:
return 0
return self.int_op(other, divint)
def __rfloordiv__(self, other):
@@ -1546,10 +1623,17 @@ class regint(_register, _int):
""" Output string if value is non-zero.
:param string: Python string """
cint(self).print_if(string)
self._condition().print_if(string)
def output_if(self, cond):
cint(self).output_if(cond)
self._condition().output_if(cond)
def _condition(self):
if program.options.binary:
from GC.types import cbits
return cbits.get_type(64)(self)
else:
return cint(self)
def binary_output(self, player=None):
""" Write 64-bit signed integer to
@@ -1597,6 +1681,9 @@ class personal(object):
def binary_output(self):
self._v.binary_output(self.player)
def bit_decompose(self, length):
return [personal(self.player, x) for x in self._v.bit_decompose(length)]
def _san(self, other):
if isinstance(other, personal):
assert self.player == other.player
@@ -1689,7 +1776,7 @@ class longint:
res += x.bit_decompose(64)
return res[:bit_length]
class _secret(_register):
class _secret(_register, _secret_structure):
__slots__ = []
mov = staticmethod(set_instruction_type(movs))
@@ -2022,7 +2109,7 @@ class sint(_secret, _int):
the bit length.
:param val: initialization (sint/cint/regint/int/cgf2n or list
thereof or sbits/sbitvec)
thereof or sbits/sbitvec/sfix)
:param size: vector size (int), defaults to 1 or size of list
"""
@@ -2130,33 +2217,62 @@ class sint(_secret, _int):
rawinput(player, res)
return res
@classmethod
@vectorized_classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of values input by a client.
:param n: number of inputs (int)
:param client_id: regint
:param size: vector size (default 1)
"""
# send shares of a triple to client
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
sint.write_shares_to_socket(client_id, triples, message_type)
received = cint.read_from_socket(client_id, n)
received = util.tuplify(cint.read_from_socket(client_id, n))
y = [0] * n
for i in range(n):
y[i] = received[i] - triples[i * 3]
return y
@classmethod
def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients.
:param clients: client ids (list or array)
:param values: list of sint to reveal
"""
set_global_vector_size(values[0].size)
to_send = []
for value in values:
assert(value.size == values[0].size)
r = sint.get_random()
to_send += [value, r, value * r]
if isinstance(clients, Array):
n_clients = clients.length
else:
n_clients = len(clients)
@library.for_range(n_clients)
def loop_body(i):
sint.write_shares_to_socket(clients[i], to_send)
reset_global_vector_size()
@vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
""" Receive secret-shared value(s) from client.
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param size: vector size of values (default 1)
:returns: sint (if n=1) or list of sint
"""
res = [cls() for i in range(n)]
readsockets(client_id, *res)
readsockets(client_id, get_global_vector_size(), *res)
if n == 1:
return res[0]
else:
@@ -2165,9 +2281,9 @@ class sint(_secret, _int):
@vectorize
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
""" Send only share to socket """
writesocketshare(client_id, message_type, self)
writesocketshare(client_id, message_type, self.size, self)
@vectorized_classmethod
@classmethod
def write_shares_to_socket(cls, client_id, values,
message_type=ClientMessageType.NoType):
""" Send shares of a list of values to a specified client socket.
@@ -2175,7 +2291,7 @@ class sint(_secret, _int):
:param client_id: regint
:param values: list of sint
"""
writesocketshare(client_id, message_type, *values)
writesocketshare(client_id, message_type, values[0].size, *values)
@classmethod
def read_from_file(cls, start, n_items):
@@ -2227,6 +2343,9 @@ class sint(_secret, _int):
size = val._v.size
super(sint, self).__init__('s', size=size)
inputpersonal(size, val.player, self, self.clear_type.conv(val._v))
elif isinstance(val, _fix):
super(sint, self).__init__('s', size=val.v.size)
self.load_other(val.v.round(val.k, val.f))
else:
super(sint, self).__init__('s', val=val, size=size)
@@ -2498,8 +2617,18 @@ class sint(_secret, _int):
def private_division(self, divisor, active=True, dividend_length=None,
divisor_length=None):
assert active == False
""" Private integer division as per `Veugen and Abspoel
<https://doi.org/10.2478/popets-2021-0073>`_
:param divisor: public (cint/regint) or personal value thereof
:param active: whether to check on the party knowing the
divisor (active security)
:param dividend_length: bit length of the dividend (default:
global bit length)
:param dividend_length: bit length of the divisor (default:
global bit length)
"""
d = divisor
l = divisor_length or program.bit_length
m = dividend_length or program.bit_length
@@ -2515,11 +2644,28 @@ class sint(_secret, _int):
r_prime = sint.get_random_int(m + sigma)
r_pprime = sint.get_random_int(l + sigma)
h = (r + (r_prime << (l + sigma))) * sint(d)
z = ((self << (l + sigma)) + h + r_pprime).reveal_to(0)
d_shared = sint(d)
h = (r + (r_prime << (l + sigma))) * d_shared
z_shared = ((self << (l + sigma)) + h + r_pprime)
z = z_shared.reveal_to(0)
y = sint(z // (d << (l + sigma)))
y_prime = sint((z // d) % (2 ** (l + sigma)))
if active:
z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)]
check = [(x * (1 - x)).reveal() == 0 for x in z_prime]
z_pp = [sint(x) for x in (z % d).bit_decompose(l)]
check += [(x * (1 - x)).reveal() == 0 for x in z_pp]
library.runtime_error_if(sum(check) != len(check),
'private division')
z_pp = sint.bit_compose(z_pp)
beta1 = z_pp.less_than(d_shared, l)
beta2 = z_shared - sint.bit_compose(z_prime) * d_shared - z_pp
library.runtime_error_if(beta1.reveal() != 1, 'private div')
library.runtime_error_if(beta2.reveal() != 0, 'private div')
y_prime = sint.bit_compose(z_prime[:l + sigma])
y = sint.bit_compose(z_prime[l + sigma:])
else:
y = sint(z // (d << (l + sigma)))
y_prime = sint((z // d) % (2 ** (l + sigma)))
b = r.greater_than(y_prime, l + sigma)
w = y - b - r_prime
@@ -3320,15 +3466,16 @@ class cfix(_number, _structure):
:param client_id: Client id (regint)
:param n: number of values (default 1)
:param: vector size (int)
:returns: cfix (if n=1) or list of cfix
"""
cint_input = cint.read_from_socket(client_id, n)
cint_inputs = cint.read_from_socket(client_id, n)
if n == 1:
return cfix._new(cint_inputs)
else:
return list(map(cfix, cint_inputs))
@vectorized_classmethod
return list(map(cfix._new, cint_inputs))
@classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of clear fixed-point values to a client
(represented as clear integers).
@@ -3336,10 +3483,12 @@ class cfix(_number, _structure):
:param client_id: Client id (regint)
:param values: list of cint
"""
for value in values:
assert(value.size == values[0].size)
def cfix_to_cint(fix_val):
return cint(fix_val.v)
cint_values = list(map(cfix_to_cint, values))
writesocketc(client_id, message_type, *cint_values)
writesocketc(client_id, message_type, values[0].size, *cint_values)
@staticmethod
def malloc(size, creator_tape=None):
@@ -3402,6 +3551,9 @@ class cfix(_number, _structure):
def __len__(self):
return len(self.v)
def __getitem__(self, index):
return self._new(self.v[index], k=self.k, f=self.f)
@vectorize
def load_int(self, v):
self.v = cint(v) * (2 ** self.f)
@@ -3601,7 +3753,7 @@ class cfix(_number, _structure):
cint(0), cint(0), cint(0))
def output_if(self, cond):
cond_print_plain(cond, self.v, cint(-self.f))
cond_print_plain(cint.conv(cond), self.v, cint(-self.f))
@vectorize
def binary_output(self, player=None):
@@ -3614,7 +3766,7 @@ class cfix(_number, _structure):
player = -1
floatoutput(player, self.v, cint(-self.f), cint(0), cint(0))
class _single(_number, _structure):
class _single(_number, _secret_structure):
""" Representation as single integer preserving the order """
""" E.g. fixed-point numbers """
__slots__ = ['v']
@@ -3623,7 +3775,7 @@ class _single(_number, _structure):
""" Whether to round deterministically to nearest instead of
probabilistically, e.g. after fixed-point multiplication. """
@classmethod
@vectorized_classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
"""
Securely obtain shares of values input by a client. Assumes client
@@ -3631,12 +3783,22 @@ class _single(_number, _structure):
:param n: number of inputs (int)
:param client_id: regint
:param size: vector size (default 1)
"""
sint_inputs = cls.int_type.receive_from_client(n, client_id,
message_type)
return list(map(cls._new, sint_inputs))
@classmethod
def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients.
:param clients: client ids (list or array)
:param values: list of values of this class
"""
cls.int_type.reveal_to_clients(clients, [x.v for x in values])
@vectorized_classmethod
def write_shares_to_socket(cls, client_id, values,
message_type=ClientMessageType.NoType):
@@ -3926,6 +4088,9 @@ class _fix(_single):
self.v = (1-2*_v.s)*a
elif isinstance(_v, type(self)):
self.v = _v.v
elif isinstance(_v, cfix):
assert _v.f <= self.f
self.v = self.int_type(_v.v << (self.f - _v.f))
elif isinstance(_v, (MemValue, MemFix)):
#this is a memvalue object
self.v = type(self)(_v.read()).v
@@ -4049,9 +4214,9 @@ class _fix(_single):
return revealed_fix._new(val)
class sfix(_fix):
""" Secret fixed-point number represented as secret integer.
This uses integer operations internally, see :py:class:`sint` for security
considerations.
""" Secret fixed-point number represented as secret integer, by
multiplying with ``2^f`` and then rounding. See :py:class:`sint`
for security considerations of the underlying integer operations.
It supports basic arithmetic (``+, -, *, /``), returning
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
@@ -4391,7 +4556,7 @@ class squant_params(object):
shifted = under.if_else(0, shifted)
return squant._new(shifted, params=self)
class sfloat(_number, _structure):
class sfloat(_number, _secret_structure):
"""
Secret floating-point number.
Represents :math:`(1 - 2s) \cdot (1 - z)\cdot v \cdot 2^p`.
@@ -4705,6 +4870,7 @@ class sfloat(_number, _structure):
return -self + other
__rsub__.__doc__ = __sub__.__doc__
@vectorize
def __truediv__(self, other):
""" Secret floating-point division.
@@ -4857,21 +5023,36 @@ def _get_type(t):
else:
return t
class Array(object):
class _vectorizable:
def reveal_to_clients(self, clients):
""" Reveal contents to list of clients.
:param clients: list or array of client identifiers
"""
self.value_type.reveal_to_clients(clients, [self.get_vector()])
class Array(_vectorizable):
"""
Array accessible by public index. That is, ``a[i]`` works for an
array ``a`` and ``i`` being a :py:class:`regint`,
:py:class:`cint`, or a Python integer. ``a[start:stop:step]``
works as well, and so does iteration over an array.
Arrays support a number of element-wise operations if the
underlying basic type does so. These are ``+, -, *, **, /``. The
return type of these is a vector, which can be assigned to an
array of a compatible type using :py:func:`assign`.
:py:class:`cint`, or a Python integer.
:param length: compile-time integer (int) or :py:obj:`None` for unknown length
:param value_type: basic type
:param address: if given (regint/int), the array will not be allocated
You can convert between arrays and register vectors by using slice
indexing. This allows for element-wise operations as long as
supported by the basic type. The following adds 10 secret integers
from the first two parties::
a = sint.Array(10)
a.input_from(0)
b = sint.Array(10)
b.input_from(1)
a[:] += b[:]
"""
@classmethod
def create_from(cls, l):
@@ -4900,6 +5081,7 @@ class Array(object):
self.debug = debug
self.creator_tape = program.curr_tape
self.sink = None
self.check_indices = True
if alloc:
self.alloc()
@@ -4914,11 +5096,17 @@ class Array(object):
def get_address(self, index):
key = str(index)
if isinstance(index, int) and self.length is not None:
index += self.length * (index < 0)
if index >= self.length or index < 0:
raise IndexError('index %s, length %s' % \
(str(index), str(self.length)))
if self.length is not None:
from .GC.types import cbits
if isinstance(index, int):
index += self.length * (index < 0)
if index >= self.length or index < 0:
raise IndexError('index %s, length %s' % \
(str(index), str(self.length)))
elif self.check_indices and not isinstance(index, cbits):
library.runtime_error_if(regint.conv(index) >= self.length,
'overflow: %s/%s',
index, self.length)
if (program.curr_block, key) not in self.address_cache:
n = self.value_type.n_elements()
length = self.length
@@ -4937,21 +5125,24 @@ class Array(object):
def get_slice(self, index):
if index.stop is None and self.length is None:
raise CompilerError('Cannot slice array of unknown length')
return index.start or 0, index.stop or self.length, index.step or 1
if index.step == 0:
raise CompilerError('slice step cannot be zero')
return index.start or 0, \
min(index.stop or self.length, self.length), index.step or 1
def __getitem__(self, index):
""" Reading from array.
:param index: public (regint/cint/int/slice)
:return: array if slice is given, basic type otherwise"""
:return: vector if slice is given, basic type otherwise"""
if isinstance(index, slice):
start, stop, step = self.get_slice(index)
res_length = (stop - start - 1) // step + 1
res = Array(res_length, self.value_type)
@library.for_range(res_length)
def f(i):
res[i] = self[start+i*step]
return res
if step == 1:
return self.get_vector(start, stop - start)
else:
res_length = (stop - start - 1) // step + 1
addresses = regint.inc(res_length, start, step)
return self.get_vector(addresses, res_length)
return self._load(self.get_address(index))
def __setitem__(self, index, value):
@@ -4961,15 +5152,21 @@ class Array(object):
:param value: convertible for relevant basic type """
if isinstance(index, slice):
start, stop, step = self.get_slice(index)
value = Array.create_from(value)
source_index = MemValue(0)
@library.for_range(start, stop, step)
def f(i):
self[i] = value[source_index]
source_index.iadd(1)
return
if step == 1:
return self.assign(value, start)
else:
res_length = (stop - start - 1) // step + 1
addresses = regint.inc(res_length, start, step)
return self.assign(value, addresses)
self._store(value, self.get_address(index))
def get_sub(self, start, stop=None):
if stop is None:
stop = start
start = 0
return Array(stop - start, self.value_type,
address=self.address + start)
def maybe_get(self, condition, index):
""" Return entry if condition is true.
@@ -4989,7 +5186,7 @@ class Array(object):
self.sink = self.value_type.Array(
1, address=self.value_type.malloc(1, creator_tape=program.tapes[0]))
addresses = (condition.if_else(x, y) for x, y in
zip(util.tuplify(self.get_address(index)),
zip(util.tuplify(self.get_address(condition * index)),
util.tuplify(self.sink.get_address(0))))
self._store(value, util.untuplify(tuple(addresses)))
@@ -5034,10 +5231,10 @@ class Array(object):
except:
pass
try:
other.store_in_mem(self.get_address(base))
self.value_type.conv(other).store_in_mem(self.get_address(base))
if len(self) != None and util.is_constant(base):
assert len(self) >= other.size + base
except AttributeError:
except (AttributeError, CompilerError):
if isinstance(other, Array):
@library.for_range_opt(len(other))
def _(i):
@@ -5071,7 +5268,7 @@ class Array(object):
:param base: starting point (regint/cint/int)
:param size: length (compile-time int) """
size = size or self.length
size = size or self.length - base
return self.value_type.load_mem(self.get_address(base), size=size)
get_part_vector = get_vector
@@ -5249,7 +5446,7 @@ sint.dynamic_array = Array
sgf2n.dynamic_array = Array
class SubMultiArray(object):
class SubMultiArray(_vectorizable):
""" Multidimensional array functionality. Don't construct this
directly, use :py:class:`MultiArray` instead. """
def __init__(self, sizes, value_type, address, index, debug=None):
@@ -5261,6 +5458,7 @@ class SubMultiArray(object):
self.address = None
self.sub_cache = {}
self.debug = debug
self.check_indices = True
if debug:
library.print_ln_if(self.address + reduce(operator.mul, self.sizes) * self.value_type.n_elements() > program.allocated_mem[self.value_type.reg_type], 'AOF%d:' % len(self.sizes) + self.debug)
@@ -5271,11 +5469,17 @@ class SubMultiArray(object):
:return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise"""
if util.is_constant(index) and index >= self.sizes[0]:
raise StopIteration
if isinstance(index, slice) and index == slice(None):
return self.get_vector()
key = program.curr_block, str(index)
if key not in self.sub_cache:
if self.debug:
library.print_ln_if(index >= self.sizes[0], \
'OF%d:' % len(self.sizes) + self.debug)
if util.is_constant(index) and \
(index >= self.sizes[0] or index < 0):
raise CompilerError('index out of range')
elif self.check_indices:
library.runtime_error_if(index >= self.sizes[0],
'overflow: %s/%s',
index, self.sizes)
if len(self.sizes) == 2:
self.sub_cache[key] = \
Array(self.sizes[1], self.value_type, \
@@ -5293,6 +5497,8 @@ class SubMultiArray(object):
:param index: public (regint/cint/int)
:param other: container of matching size and type """
if isinstance(index, slice) and index == slice(None):
return self.assign(other)
self[index].assign(other)
def __len__(self):
@@ -5526,13 +5732,18 @@ class SubMultiArray(object):
self.assign_vector(self.get_vector() + other.get_vector())
def __mul__(self, other):
# legacy function
return self.mul(other)
def mul(self, other, res_params=None):
# legacy function
return self.dot(other, res_params)
def dot(self, other, res_params=None):
""" Matrix-matrix and matrix-vector multiplication.
:param self: two-dimensional
:param other: Matrix or Array of matching size and type """
return self.mul(other)
def mul(self, other, res_params=None):
assert len(self.sizes) == 2
if isinstance(other, Array):
assert len(other) == self.sizes[1]
@@ -5762,11 +5973,16 @@ class SubMultiArray(object):
assert len(self.sizes) == 2
res = Matrix(self.sizes[1], self.sizes[0], self.value_type)
library.break_point()
@library.for_range_opt(self.sizes[1])
def _(i):
if self.value_type.n_elements() == 1:
@library.for_range_opt(self.sizes[0])
def _(j):
res[i][j] = self[j][i]
res.set_column(j, self[j][:])
else:
@library.for_range_opt(self.sizes[1])
def _(i):
@library.for_range_opt(self.sizes[0])
def _(j):
res[i][j] = self[j][i]
library.break_point()
return res
@@ -5809,14 +6025,22 @@ class MultiArray(SubMultiArray):
"""
Multidimensional array. The access operator (``a[i]``) allows to a
multi-dimensional array of dimension one less or a simple array
for a two-dimensional array. Element-wise addition and subtraction
is supported, returning a vector, which can be assigned using
:py:func:`assign`. Matrix-vector and matrix-matrix multiplication
is supported as well.
for a two-dimensional array.
:param sizes: shape (compile-time list of integers)
:param value_type: basic type of entries
You can convert between arrays and register vectors by using slice
indexing. This allows for element-wise operations as long as
supported by the basic type. The following has the first two parties
input a 10x10 secret integer matrix followed by storing the
element-wise multiplications in the same data structure::
a = sint.Tensor([3, 10, 10])
a[0].input_from(0)
a[1].input_from(1)
a[2][:] = a[0][:] * a[1][:]
"""
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
if isinstance(address, Array):

View File

@@ -35,7 +35,7 @@ int main(int argc, const char** argv)
int n_tuples = 1000;
if (not opt.lastArgs.empty())
n_tuples = atoi(opt.lastArgs[0]->c_str());
PlainPlayer P(N);
PlainPlayer P(N, "ecdsa");
P256Element::init();
P256Element::Scalar keyp;

View File

@@ -43,7 +43,7 @@ void run(int argc, const char** argv)
int n_tuples = 1000;
if (not opt.lastArgs.empty())
n_tuples = atoi(opt.lastArgs[0]->c_str());
CryptoPlayer P(N);
CryptoPlayer P(N, "ecdsa");
P256Element::init();
typedef T<P256Element::Scalar> pShare;
OnlineOptions::singleton.batch_size = 1;

View File

@@ -88,7 +88,7 @@ void run(int argc, const char** argv)
int n_tuples = 1000;
if (not opt.lastArgs.empty())
n_tuples = atoi(opt.lastArgs[0]->c_str());
PlainPlayer P(N);
PlainPlayer P(N, "ecdsa");
P256Element::init();
P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false);

31
ExternalIO/Client.h Normal file
View File

@@ -0,0 +1,31 @@
/*
* Client.h
*
*/
#ifndef EXTERNALIO_CLIENT_H_
#define EXTERNALIO_CLIENT_H_
#include "Networking/ssl_sockets.h"
class Client
{
vector<int> plain_sockets;
ssl_ctx ctx;
ssl_service io_service;
public:
vector<ssl_socket*> sockets;
octetStream specification;
Client(const vector<string>& hostnames, int port_base, int my_client_id);
~Client();
template<class T>
void send_private_inputs(const vector<T>& values);
template<class T>
vector<T> receive_outputs(int n);
};
#endif /* EXTERNALIO_CLIENT_H_ */

126
ExternalIO/Client.hpp Normal file
View File

@@ -0,0 +1,126 @@
/*
* Client.cpp
*
*/
#include "Client.h"
inline
Client::Client(const vector<string>& hostnames, int port_base,
int my_client_id) :
ctx("C" + to_string(my_client_id))
{
bigint::init_thread();
// Setup connections from this client to each party socket
int nparties = hostnames.size();
plain_sockets.resize(nparties);
sockets.resize(nparties);
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(plain_sockets[i], hostnames[i].c_str(), port_base + i);
octetStream(to_string(my_client_id)).Send(plain_sockets[i]);
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
"P" + to_string(i), "C" + to_string(my_client_id), true);
if (i == 0)
specification.Receive(sockets[0]);
}
}
inline
Client::~Client()
{
for (auto& socket : sockets)
{
delete socket;
}
}
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
template<class T>
void Client::send_private_inputs(const vector<T>& values)
{
int num_inputs = values.size();
octetStream os;
vector< vector<T> > triples(num_inputs, vector<T>(3));
vector<T> triple_shares(3);
// Receive num_inputs triples from SPDZ
for (size_t j = 0; j < sockets.size(); j++)
{
os.reset_write_head();
os.Receive(sockets[j]);
#ifdef VERBOSE_COMM
cerr << "received " << os.get_length() << " from " << j << endl;
#endif
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
{
triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k];
}
}
}
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
{
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
cerr << "Incorrect triple at " << i << ", aborting\n";
throw mac_fail();
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
{
T y = values[i] + triples[i][0];
y.pack(os);
}
for (auto& socket : sockets)
os.Send(socket);
}
// Receive shares of the result and sum together.
// Also receive authenticating values.
template<class T>
vector<T> Client::receive_outputs(int n)
{
vector<T> triples(3 * n);
octetStream os;
for (auto& socket : sockets)
{
os.reset_write_head();
os.Receive(socket);
#ifdef VERBOSE_COMM
cout << "received " << os.get_length() << endl;
#endif
for (int j = 0; j < 3 * n; j++)
{
T value;
value.unpack(os);
triples[j] += value;
}
}
vector<T> output_values;
for (int i = 0; i < 3 * n; i += 3)
{
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
throw mac_fail();
}
output_values.push_back(triples[i]);
}
return output_values;
}

View File

@@ -39,111 +39,33 @@
#include "Protocols/fake-stuff.h"
#include "Math/gfp.hpp"
#include "Client.hpp"
#include <sodium.h>
#include <iostream>
#include <sstream>
#include <fstream>
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
template<class T>
void send_private_inputs(const vector<T>& values, vector<ssl_socket*>& sockets, int nparties)
{
int num_inputs = values.size();
octetStream os;
vector< vector<T> > triples(num_inputs, vector<T>(3));
vector<T> triple_shares(3);
// Receive num_inputs triples from SPDZ
for (int j = 0; j < nparties; j++)
{
os.reset_write_head();
os.Receive(sockets[j]);
#ifdef VERBOSE_COMM
cerr << "received " << os.get_length() << " from " << j << endl;
#endif
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
{
triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k];
}
}
}
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
{
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
cerr << "Incorrect triple at " << i << ", aborting\n";
throw mac_fail();
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
{
T y = values[i] + triples[i][0];
y.pack(os);
}
for (int j = 0; j < nparties; j++)
os.Send(sockets[j]);
}
// Receive shares of the result and sum together.
// Also receive authenticating values.
template<class T>
T receive_result(vector<ssl_socket*>& sockets, int nparties)
{
vector<T> output_values(3);
octetStream os;
for (int i = 0; i < nparties; i++)
{
os.reset_write_head();
os.Receive(sockets[i]);
for (unsigned int j = 0; j < 3; j++)
{
T value;
value.unpack(os);
output_values[j] += value;
}
}
if (T(output_values[0] * output_values[1]) != output_values[2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
throw mac_fail();
}
return output_values[0];
}
template<class T>
void one_run(T salary_value, vector<ssl_socket*>& sockets, int nparties)
void one_run(T salary_value, Client& client)
{
// Run the computation
send_private_inputs<T>({salary_value}, sockets, nparties);
client.send_private_inputs<T>({salary_value});
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back (client_id of winning client)
T result = receive_result<T>(sockets, nparties);
T result = client.receive_outputs<T>(1)[0];
cout << "Winning client id is : " << result << endl;
}
template<class T>
void run(double salary_value, vector<ssl_socket*>& sockets, int nparties)
void run(double salary_value, Client& client)
{
// sint
one_run<T>(long(round(salary_value)), sockets, nparties);
one_run<T>(long(round(salary_value)), client);
// sfix with f = 16
one_run<T>(long(round(salary_value * exp2(16))), sockets, nparties);
one_run<T>(long(round(salary_value * exp2(16))), client);
}
int main(int argc, char** argv)
@@ -165,7 +87,7 @@ int main(int argc, char** argv)
nparties = atoi(argv[2]);
salary_value = atof(argv[3]);
finish = atoi(argv[4]);
vector<const char*> hostnames(nparties, "localhost");
vector<string> hostnames(nparties, "localhost");
if (argc > 5)
{
@@ -185,19 +107,11 @@ int main(int argc, char** argv)
bigint::init_thread();
// Setup connections from this client to each party socket
vector<int> plain_sockets(nparties);
vector<ssl_socket*> sockets(nparties);
ssl_ctx ctx("C" + to_string(my_client_id));
ssl_service io_service;
octetStream specification;
Client client(hostnames, port_base, my_client_id);
auto& specification = client.specification;
auto& sockets = client.sockets;
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(plain_sockets[i], hostnames[i], port_base + i);
send(plain_sockets[i], (octet*) &my_client_id, sizeof(int));
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
"P" + to_string(i), "C" + to_string(my_client_id), true);
if (i == 0)
specification.Receive(sockets[0]);
octetStream os;
os.store(finish);
os.Send(sockets[i]);
@@ -211,7 +125,7 @@ int main(int argc, char** argv)
{
gfp::init_field(specification.get<bigint>());
cerr << "using prime " << gfp::pr() << endl;
run<gfp>(salary_value, sockets, nparties);
run<gfp>(salary_value, client);
break;
}
case 'R':
@@ -220,13 +134,13 @@ int main(int argc, char** argv)
switch (R)
{
case 64:
run<Z2<64>>(salary_value, sockets, nparties);
run<Z2<64>>(salary_value, client);
break;
case 104:
run<Z2<104>>(salary_value, sockets, nparties);
run<Z2<104>>(salary_value, client);
break;
case 128:
run<Z2<128>>(salary_value, sockets, nparties);
run<Z2<128>>(salary_value, client);
break;
default:
cerr << R << "-bit ring not implemented";
@@ -239,8 +153,5 @@ int main(int argc, char** argv)
exit(1);
}
for (int i = 0; i < nparties; i++)
delete sockets[i];
return 0;
}

View File

@@ -63,6 +63,16 @@ public:
return *this;
}
AddableVector<T> operator-(const AddableVector<T>& y) const
{
if (this->size() != y.size())
throw out_of_range("vector length mismatch");
AddableVector<T> res(y.size());
for (unsigned int i = 0; i < this->size(); i++)
res[i] = (*this)[i] - y[i];
return res;
}
void mul(const AddableVector<T>& x, const AddableVector<T>& y)
{
if (x.size() != y.size())

71
FHE/Diagonalizer.cpp Normal file
View File

@@ -0,0 +1,71 @@
/*
* Diagonalizer.cpp
*
*/
#include "Diagonalizer.h"
Diagonalizer::Diagonalizer(const MatrixVector& matrices,
const FFT_Data& FTD, const FHE_PK& pk) :
FTD(FTD)
{
assert(not matrices.empty());
for (auto& matrix : matrices)
{
assert(matrix.n_cols == matrices[0].n_cols);
assert(matrix.n_rows == matrices[0].n_rows);
}
n_rows = matrices[0].n_rows;
n_cols = matrices[0].n_cols;
assert(n_rows * matrices.size() <= size_t(FTD.num_slots()));
for (size_t i = 0; i < n_cols; i++)
{
Plaintext_<FFT_Data> plaintext(FTD, Evaluation);
for (size_t k = 0; k < matrices.size(); k++)
{
for (size_t j = 0; j < n_rows; j++)
{
auto entry = matrices.at(k)[{j, (j + i) % n_cols}];
plaintext.set_element(k * n_rows + j, entry);
}
}
ciphertexts.push_back(pk.encrypt(plaintext));
}
}
Plaintext_<FFT_Data> Diagonalizer::get_plaintext(
const MatrixVector& matrices, int left_col,
int right_col)
{
Plaintext_<FFT_Data> plaintext(FTD, Evaluation);
for (size_t k = 0; k < matrices.size(); k++)
for (size_t j = 0; j < n_rows; j++)
plaintext.set_element(k * n_rows + j,
matrices.at(k)[{(left_col + j) % n_cols, right_col}]);
return plaintext;
}
Diagonalizer::MatrixVector Diagonalizer::decrypt(
const vector<Ciphertext>& products, int n_matrices, FHE_SK& sk)
{
vector<Plaintext_<FFT_Data>> plaintexts;
for (auto& x : products)
plaintexts.push_back(sk.decrypt(x, FTD));
return dediag(plaintexts, n_matrices);
}
Diagonalizer::MatrixVector Diagonalizer::dediag(
const vector<Plaintext_<FFT_Data>>& products, int n_matrices)
{
int n_cols_out = products.size();
MatrixVector res(n_matrices, {int(n_rows), n_cols_out});
for (int i = 0; i < n_cols_out; i++)
{
auto& c = products.at(i);
for (int j = 0; j < n_matrices; j++)
for (size_t k = 0; k < n_rows; k++)
res.at(j)[{k, i}] = c.element(j * n_rows + k);
}
return res;
}

36
FHE/Diagonalizer.h Normal file
View File

@@ -0,0 +1,36 @@
/*
* Diagonalizer.h
*
*/
#ifndef FHE_DIAGONALIZER_H_
#define FHE_DIAGONALIZER_H_
#include "Math/gfpvar.h"
#include "Ciphertext.h"
#include "Protocols/ShareMatrix.h"
class Diagonalizer
{
const FFT_Data& FTD;
size_t n_rows, n_cols;
public:
typedef AddableVector<ValueMatrix<gfpvar>> MatrixVector;
vector<Ciphertext> ciphertexts;
Diagonalizer(const MatrixVector& matrices,
const FFT_Data& FTD, const FHE_PK& pk);
Plaintext_<FFT_Data> get_plaintext(const MatrixVector& matrices,
int left_col, int right_col);
MatrixVector decrypt(const vector<Ciphertext>&, int n_matrices, FHE_SK& sk);
MatrixVector dediag(const vector<Plaintext_<FFT_Data>>& plaintexts,
int n_matrices);
};
#endif /* FHE_DIAGONALIZER_H_ */

View File

@@ -185,7 +185,8 @@ void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
int start = queues.distribute(job, n / 2);
for (int i = start; i < n / 2; i++)
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
queues.wrap_up(job);
if (start > 0)
queues.wrap_up(job);
}
else
for (int i = 0; i < n / 2; i++)

View File

@@ -359,11 +359,19 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess,
const Random_Coins& rc) const;
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<P2Data>& mess,
const Random_Coins& rc) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess,
const Random_Coins& rc) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
template void FHE_SK::decrypt(Plaintext_<FFT_Data>&, const Ciphertext& c) const;
template void FHE_SK::decrypt(Plaintext_<P2Data>&, const Ciphertext& c) const;
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
const FFT_Data& FieldD);
template Plaintext_<P2Data> FHE_SK::decrypt(const Ciphertext& c,

View File

@@ -44,6 +44,20 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
template <class FD>
void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
const Ciphertext& enc_a, const Rq_Element& b, OT_ROLE role)
{
if (role & SENDER)
{
timers["Ciphertext multiplication"].start();
C.mul(enc_a, b);
timers["Ciphertext multiplication"].stop();
}
add(res, C, role);
}
template <class FD>
void Multiplier<FD>::add(Plaintext_<FD>& res, const Ciphertext& c,
OT_ROLE role, int n_summands)
{
o.reset_write_head();
@@ -51,9 +65,6 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
{
PRNG G;
G.ReSeed();
timers["Ciphertext multiplication"].start();
C.mul(enc_a, b);
timers["Ciphertext multiplication"].stop();
timers["Mask randomization"].start();
product_share.randomize(G);
bigint B = 6 * machine.setup<FD>().params.get_R();
@@ -63,12 +74,13 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
B *= NonInteractiveProof::slack(machine.sec,
machine.setup<FD>().params.phi_m());
B <<= machine.extra_slack;
B *= n_summands;
rc.generateUniform(G, 0, B, B);
timers["Mask randomization"].stop();
timers["Encryption"].start();
other_pk.encrypt(mask, product_share, rc);
timers["Encryption"].stop();
mask += C;
mask += c;
mask.pack(o);
res -= product_share;
}

View File

@@ -47,6 +47,8 @@ public:
const Plaintext_<FD>& b);
void multiply_and_add(Plaintext_<FD>& res, const Ciphertext& C,
const Rq_Element& b, OT_ROLE role = BOTH);
void add(Plaintext_<FD>& res, const Ciphertext& C, OT_ROLE role = BOTH,
int n_summands = 1);
void multiply_alpha_and_add(Plaintext_<FD>& res, const Rq_Element& b,
OT_ROLE role = BOTH);
int get_offset() { return P.get_offset(); }

View File

@@ -17,7 +17,7 @@ PairwiseMachine::PairwiseMachine(Player& P) :
}
PairwiseMachine::PairwiseMachine(int argc, const char** argv) :
MachineBase(argc, argv), P(*new PlainPlayer(N, 0xffff << 16)),
MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")),
other_pks(N.num_players(), {setup_p.params, 0}),
pk(other_pks[N.my_num()]), sk(pk)
{

View File

@@ -31,7 +31,7 @@ public:
map<string, Timer> timers;
GeneratorBase(int thread_num, const Names& N, Player* player = 0) :
player(player ? 0 : new PlainPlayer(N, thread_num << 16)),
player(player ? 0 : new PlainPlayer(N, to_string(thread_num))),
thread_num(thread_num),
P(player ? *player : *this->player), thread(0), total(0)
{

View File

@@ -216,7 +216,7 @@ void MultiplicativeMachine::generate_setup(int slack)
template <class FD>
void MultiplicativeMachine::fake_keys(int slack)
{
PlainPlayer P(N, -1 * N.num_players() * N.num_players());
PlainPlayer P(N, "fake");
octetStream os;
PartSetup<FD>& part_setup = setup.part<FD>();
if (P.my_num() == 0)

View File

@@ -6,6 +6,7 @@
#include "AtlasSecret.h"
#include "TinyMC.h"
#include "Protocols/Shamir.hpp"
#include "Protocols/ShamirMC.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Secret.hpp"

View File

@@ -92,9 +92,9 @@ public:
}
}
size_t data_sent()
NamedCommStats comm_stats()
{
return part_prep.data_sent();
return part_prep.comm_stats();
}
};

View File

@@ -108,7 +108,8 @@ public:
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
static FakeSecret input(int from, word input, int n_bits);
static FakeSecret constant(clear value, int = 0, mac_key_type = {}) { return value; }
static FakeSecret constant(clear value, int = 0, mac_key_type = {}, int = -1)
{ return value; }
FakeSecret() {}
template <class T>

View File

@@ -68,6 +68,7 @@ enum
CLEAR_WRITE = 0x210,
XORCBI = 0x210,
BITDECC = 0x211,
NOTCB = 0x212,
CONVCINT = 0x213,
REVEAL = 0x214,
STMSDCI = 0x215,

View File

@@ -84,12 +84,6 @@ public:
return new HashMaliciousRepMC<U>;
}
static U constant(const BitVec& other, int my_num, const BitVec& alphai)
{
(void) my_num, (void) alphai;
return other;
}
MalRepSecretBase() {}
template<class T>
MalRepSecretBase(const T& other) : super(other) {}

View File

@@ -143,7 +143,7 @@ public:
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; }
static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; }
NoShare() {}

View File

@@ -86,6 +86,7 @@ public:
void xors(const vector<int>& args, size_t start, size_t end);
void xorc(const ::BaseInstruction& instruction);
void nots(const ::BaseInstruction& instruction);
void notcb(const ::BaseInstruction& instruction);
void andm(const ::BaseInstruction& instruction);
void and_(const vector<int>& args, bool repeat);
void andrs(const vector<int>& args) { and_(args, true); }

View File

@@ -257,6 +257,19 @@ void Processor<T>::nots(const ::BaseInstruction& instruction)
}
}
template<class T>
void Processor<T>::notcb(const ::BaseInstruction& instruction)
{
int total = instruction.get_n();
int unit = Clear::N_BITS;
for (int i = 0; i < DIV_CEIL(total, unit); i++)
{
int n = min(unit, total - i * unit);
C[instruction.get_r(0) + i] =
Clear(~C[instruction.get_r(1) + i].get()).mask(n);
}
}
template<class T>
void Processor<T>::andm(const ::BaseInstruction& instruction)
{

View File

@@ -30,7 +30,7 @@ public:
static MC* new_mc(typename super::mac_key_type) { return new MC; }
static This constant(const typename super::clear& constant, int my_num,
typename super::mac_key_type = {})
typename super::mac_key_type = {}, int = -1)
{
return Rep4Share<typename super::clear>::constant(constant, my_num);
}

View File

@@ -10,6 +10,7 @@
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Replicated.hpp"
#include "OT/NPartyTripleGenerator.hpp"
namespace GC
@@ -65,12 +66,12 @@ void SemiPrep::buffer_bits()
}
}
size_t SemiPrep::data_sent()
NamedCommStats SemiPrep::comm_stats()
{
if (triple_generator)
return triple_generator->data_sent();
return triple_generator->comm_stats();
else
return 0;
return {};
}
} /* namespace GC */

View File

@@ -53,7 +53,7 @@ public:
throw not_implemented();
}
size_t data_sent();
NamedCommStats comm_stats();
};
} /* namespace GC */

View File

@@ -102,16 +102,16 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
if (not this->machine.use_encryption and not T::dishonest_majority)
insecure("unencrypted communication");
Server* server = network_opts.start_networking(this->N, my_num);
network_opts.start_networking(this->N, my_num);
if (online_opts.live_prep)
if (T::needs_ot)
{
Player* P;
if (this->machine.use_encryption)
P = new CryptoPlayer(this->N, 0xFFFF);
P = new CryptoPlayer(this->N, "shareparty");
else
P = new PlainPlayer(this->N, 0xFFFF);
P = new PlainPlayer(this->N, "shareparty");
for (int i = 0; i < this->machine.nthreads; i++)
this->machine.ot_setups.push_back({*P, true});
delete P;
@@ -133,9 +133,6 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
this->run();
this->machine.write_memory(this->N.my_num());
if (server)
delete server;
}
template<class T>

View File

@@ -171,8 +171,8 @@ class ReplicatedSecret : public RepSecretBase<U, 2>
public:
typedef ReplicatedBase Protocol;
static ReplicatedSecret constant(const typename super::clear& value, int my_num,
typename super::mac_key_type)
static ReplicatedSecret constant(const typename super::clear& value,
int my_num, typename super::mac_key_type, int = -1)
{
ReplicatedSecret res;
if (my_num < 2)

View File

@@ -58,8 +58,8 @@ public:
void pre_run();
void post_run() { ShareThread<T>::post_run(); }
size_t data_sent()
{ return Thread<T>::data_sent() + this->DataF.data_sent(); }
NamedCommStats comm_stats()
{ return Thread<T>::comm_stats() + this->DataF.comm_stats(); }
};
template<class T>

View File

@@ -56,7 +56,7 @@ public:
void join_tape();
void finish();
virtual size_t data_sent();
virtual NamedCommStats comm_stats();
};
template<class T>

View File

@@ -51,10 +51,11 @@ void Thread<T>::run()
singleton = this;
BaseMachine::s().thread_num = thread_num;
secure_prng.ReSeed();
string id = "T" + to_string(thread_num);
if (machine.use_encryption)
P = new CryptoPlayer(N, thread_num << 16);
P = new CryptoPlayer(N, id);
else
P = new PlainPlayer(N, thread_num << 16);
P = new PlainPlayer(N, id);
processor.open_input_file(N.my_num(), thread_num,
master.opts.cmd_private_input_file);
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
@@ -98,10 +99,10 @@ void Thread<T>::finish()
}
template<class T>
size_t GC::Thread<T>::data_sent()
NamedCommStats Thread<T>::comm_stats()
{
assert(P);
return P->comm_stats.total_data();
return P->comm_stats;
}
} /* namespace GC */

View File

@@ -58,7 +58,7 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
template<class T>
void ThreadMaster<T>::run()
{
P = new PlainPlayer(N, 0xff << 24);
P = new PlainPlayer(N, "main");
machine.load_schedule(progname);
@@ -87,12 +87,10 @@ void ThreadMaster<T>::run()
NamedCommStats stats = P->comm_stats;
ExecutionStats exe_stats;
size_t data_sent = P->comm_stats.total_data();
for (auto thread : threads)
{
stats += thread->P->comm_stats;
exe_stats += thread->processor.stats;
data_sent += thread->data_sent();
delete thread;
}
@@ -102,7 +100,7 @@ void ThreadMaster<T>::run()
stats.print();
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
cerr << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl;
}
} /* namespace GC */

View File

@@ -48,7 +48,7 @@ public:
void set_protocol(typename T::Protocol& protocol);
size_t data_sent();
NamedCommStats comm_stats();
};
}

View File

@@ -92,13 +92,13 @@ void GC::TinierSharePrep<T>::buffer_bits()
}
template<class T>
size_t TinierSharePrep<T>::data_sent()
NamedCommStats TinierSharePrep<T>::comm_stats()
{
size_t res = 0;
NamedCommStats res;
if (triple_generator)
res += triple_generator->data_sent();
res += triple_generator->comm_stats();
if (real_triple_generator)
res += real_triple_generator->data_sent();
res += real_triple_generator->comm_stats();
return res;
}

View File

@@ -70,11 +70,14 @@ public:
T::reveal_inst(processor, args);
}
static This constant(BitVec other, int my_num, mac_key_type alphai)
static This constant(BitVec other, int my_num, mac_key_type alphai,
int n_bits = -1)
{
if (n_bits < 0)
n_bits = other.length();
This res;
res.resize_regs(other.length());
for (int i = 0; i < other.length(); i++)
res.resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
res.get_reg(i) = part_type::constant(other.get_bit(i), my_num, alphai);
return res;
}

View File

@@ -43,6 +43,7 @@
X(XORCB, processor.xorc(instruction)) \
X(XORCBI, C0.xor_(PC1, IMM)) \
X(NOTS, processor.nots(INST)) \
X(NOTCB, processor.notcb(INST)) \
X(ANDRS, T::andrs(PROC, EXTRA)) \
X(ANDS, T::ands(PROC, EXTRA)) \
X(ADDCB, C0 = PC1 + PC2) \
@@ -140,6 +141,7 @@
X(NPLAYERS, I0 = Thread<T>::s().P->num_players()) \
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
X(CRASH, if (I0.get()) throw crash_requested()) \
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS

View File

@@ -226,7 +226,7 @@ OTMachine::OTMachine(int argc, const char** argv)
N.push_back(new Names(my_num, portnum_base + 1000 * N.size(), names));
}
P = new RealTwoPartyPlayer(*N[0], 1 - my_num, 500);
P = new RealTwoPartyPlayer(*N[0], 1 - my_num, "machine");
timeval baseOTstart, baseOTend;
gettimeofday(&baseOTstart, NULL);
@@ -319,7 +319,8 @@ void OTMachine::run()
}
// now setup resources for each thread
// round robin with the names
players[i] = new RealTwoPartyPlayer(*N[i%N.size()], 1 - my_num, (i+1) * 1000);
players[i] = new RealTwoPartyPlayer(*N[i % N.size()], 1 - my_num,
"thread" + to_string(i));
tinfos[i].thread_num = i+1;
tinfos[i].other_player_num = 1 - my_num;
tinfos[i].nOTs = nOTs;

View File

@@ -138,7 +138,10 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
opt.get("-S")->getInt(z2s);
// doesn't work with Montgomery multiplication
gfpvar1::init_field(prime, false);
if (prime)
gfpvar1::init_field(prime, false);
else
gfpvar1::init_default(128, false);
gf2n_long::init_field(128);
gf2n_short::init_field(40);
@@ -175,7 +178,7 @@ void TripleMachine::run()
nConnections = 2;
}
// do the base OTs
PlainPlayer P(N[0], 0xF000);
PlainPlayer P(N[0], "base");
OTTripleSetup setup(P, true);
vector<GeneratorThread*> generators(nthreads);

View File

@@ -4,6 +4,7 @@
*/
#include "Protocols/HemiShare.h"
#include "Protocols/HemiOptions.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "FHE/P2Data.h"
@@ -22,6 +23,7 @@
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/Hemi.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/SemiHonestRepPrep.h"
#include "Math/gfp.hpp"
@@ -29,6 +31,7 @@
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
HemiOptions::singleton = {opt, argc, argv};
DishonestMajorityFieldMachine<HemiShare, HemiShare, gf2n_short>(argc, argv,
opt);
}

View File

@@ -9,7 +9,7 @@ NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp))
PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp))
FHEOFFLINE = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o
FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o
GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR)
GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
@@ -17,32 +17,38 @@ GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp))
OT_EXE = ot.x ot-offline.x
COMMON = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o Processor/DataPositions.o Processor/ThreadQueues.o Processor/ThreadQueue.o
COMMONOBJS = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o Processor/DataPositions.o Processor/ThreadQueues.o Processor/ThreadQueue.o
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT)
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o
BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp))
VM = $(PROCESSOR) $(COMMON) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
VM = $(MINI_OT) $(SHAREDLIB)
COMMON = $(SHAREDLIB)
LIB = libSPDZ.a
SHAREDLIB = libSPDZ.so
FHEOFFLINE = libFHE.so
LIBRELEASE = librelease.a
ifeq ($(AVX_OT), 0)
VM += ECDSA/P256Element.o
OT += ECDSA/P256Element.o
MINI_OT += ECDSA/P256Element.o
else
LIBSIMPLEOT = SimpleOT/libsimpleot.a
endif
# used for dependency generation
OBJS = $(BMR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp Utils/*.cpp))
OBJS = $(BMR) $(FHEOBJS) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp Utils/*.cpp))
DEPS := $(wildcard */*.d */*/*.d)
# never delete
.SECONDARY: $(OBJS)
.SECONDARY: $(OBJS) $(patsubst %.cpp,%.o,$(wildcard */*.cpp))
all: arithmetic binary gen_input online offline externalIO bmr ecdsa doc
all: arithmetic binary gen_input online offline externalIO bmr ecdsa
vm: arithmetic binary
.PHONY: doc
@@ -112,13 +118,22 @@ sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) Fake-ECDSA.x
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMON) $(OT) $(GC)
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(OT) $(GC)
$(AR) -csr $@ $^
CFLAGS += -fPIC
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)
$(SHAREDLIB): $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o
$(CXX) $(CFLAGS) -shared -o $@ $^ $(LDLIBS)
$(FHEOFFLINE): $(FHEOBJS) $(SHAREDLIB)
$(CXX) $(CFLAGS) -shared -o $@ $^ $(LDLIBS)
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VM) $(OT) $(LIBSIMPLEOT)
static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static-dir:
@@ -141,7 +156,7 @@ gc-emulate.x: $(VM) GC/FakeSecret.o GC/square64.o
bmr-%.x: $(BMR) $(VM) Machines/bmr-%.cpp $(LIBSIMPLEOT)
$(CXX) -o $@ $(CFLAGS) $^ $(BOOST) $(LDLIBS)
%-bmr-party.x: Machines/%-bmr-party.o $(BMR) $(VM) $(LIBSIMPLEOT)
%-bmr-party.x: Machines/%-bmr-party.o $(BMR) $(SHAREDLIB) $(MINI_OT)
$(CXX) -o $@ $(CFLAGS) $^ $(BOOST) $(LDLIBS)
bmr-clean:
@@ -170,15 +185,15 @@ default-prime-length.x: Utils/default-prime-length.o
secure.x: Utils/secure.o
$(CXX) -o $@ $(CFLAGS) $^
Fake-Offline.x: Utils/Fake-Offline.o $(VM)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
%.x: Utils/%.o $(COMMON)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
%.x: Machines/%.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
%.x: Machines/%.o $(MINI_OT) $(SHAREDLIB)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
%gear-party.x: Machines/%gear-party.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) -lntl
%-ecdsa-party.x: ECDSA/%-ecdsa-party.o ECDSA/P256Element.o $(VM)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
@@ -188,7 +203,7 @@ replicated-field-party.x: GC/square64.o
brain-party.x: GC/square64.o
malicious-rep-bin-party.x: GC/square64.o
ps-rep-bin-party.x: GC/PostSacriBin.o
semi-bin-party.x: $(VM) $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
tiny-party.x: $(OT)
tinier-party.x: $(OT)
spdz2k-party.x: $(OT) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
@@ -202,12 +217,12 @@ chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
atlas-party.x: GC/AtlasSecret.o
static/hemi-party.x: $(FHEOFFLINE)
static/soho-party.x: $(FHEOFFLINE)
static/cowgear-party.x: $(FHEOFFLINE)
static/chaigear-party.x: $(FHEOFFLINE)
static/lowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
static/highgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
static/hemi-party.x: $(FHEOBJS)
static/soho-party.x: $(FHEOBJS)
static/cowgear-party.x: $(FHEOBJS)
static/chaigear-party.x: $(FHEOBJS)
static/lowgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
static/highgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
mascot-party.x: Machines/SPDZ.o $(OT)
static/mascot-party.x: Machines/SPDZ.o
Player-Online.x: Machines/SPDZ.o $(OT)
@@ -226,7 +241,6 @@ real-bmr-party.x: $(OT)
paper-example.x: $(VM) $(OT) $(FHEOFFLINE)
mascot-offline.x: $(VM) $(OT)
cowgear-offline.x: $(OT) $(FHEOFFLINE)
Fake-Offline.x: $(VM)
static/rep-bmr-party.x: $(BMR)
static/mal-rep-bmr-party.x: $(BMR)
static/shamir-bmr-party.x: $(BMR)
@@ -269,7 +283,7 @@ mpir: mpir-setup
./configure --enable-cxx --prefix=$(CURDIR)/local
$(MAKE) -C mpir install
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,./local/lib -L./local/lib >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
mac-setup:
brew install openssl boost libsodium mpir yasm ntl
@@ -281,4 +295,4 @@ simde/simde:
git submodule update --init simde
clean:
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so

View File

@@ -52,21 +52,19 @@ public:
BitVec_& operator-=(const BitVec_& other) { *this ^= other; return *this; }
BitVec_ extend_bit() const { return -(this->a & 1); }
BitVec_ mask(int n) const { return n < n_bits ? *this & ((1L << n) - 1) : *this; }
void extend_bit(BitVec_& res, int) const { res = extend_bit(); }
void mask(BitVec_& res, int n) const { res = mask(n); }
void add(octetStream& os) { *this += os.get<BitVec_>(); }
void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; }
void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = mask(n); }
void randomize(PRNG& G, int n = n_bits) { super::randomize(G); *this = this->mask(n); }
void pack(octetStream& os) const { os.store_int<sizeof(T)>(this->a); }
void unpack(octetStream& os) { this->a = os.get_int<sizeof(T)>(); }
void pack(octetStream& os, int n) const { os.store_int(mask(n).a, DIV_CEIL(n, 8)); }
void pack(octetStream& os, int n) const { os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); }
void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); }
static BitVec_ unpack_new(octetStream& os, int n = n_bits)

View File

@@ -85,6 +85,9 @@ public:
T& operator^=(const IntBase& other) { return a ^= other.a; }
T& operator&=(const IntBase& other) { return a &= other.a; }
IntBase mask(int n) const { return n < N_BITS ? *this & ((1L << n) - 1) : *this; }
void mask(IntBase& res, int n) const { res = mask(n); }
friend ostream& operator<<(ostream& s, const IntBase& x) { x.output(s, true); return s; }
friend istream& operator>>(istream& s, IntBase& x) { x.input(s, true); return s; }

View File

@@ -30,7 +30,8 @@ void Z2<K>::reqbl(int n)
}
else if (n > 0)
{
throw Processor_Error("Program compiled for fields not rings");
throw Processor_Error("Program compiled for fields not rings, "
"run compile.py with '-R " + to_string(K) + "'");
}
}

View File

@@ -133,7 +133,7 @@ inline void Zp_Data::Add<1>(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y
#else
*ans = *x + *y;
asm goto ("jc %l[sub]" :::: sub);
if (*ans >= *prA)
if (mpn_cmp(ans, prA, 1) >= 0)
sub:
*ans -= *prA;
#endif
@@ -251,13 +251,17 @@ inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
break;
CASE(1)
CASE(2)
#if MAX_MOD_SZ >= 5
#if MAX_MOD_SZ >= 4
CASE(3)
CASE(4)
#endif
#if MAX_MOD_SZ >= 5
CASE(5)
#endif
#if MAX_MOD_SZ >= 10
#if MAX_MOD_SZ >= 6
CASE(6)
#endif
#if MAX_MOD_SZ >= 10
CASE(7)
CASE(8)
CASE(9)

View File

@@ -166,7 +166,8 @@ void gfp_<X, L>::reqbl(int n)
}
else if ((int)n < 0)
{
throw Processor_Error("Program compiled for rings not fields");
throw Processor_Error("Program compiled for rings not fields, "
"run compile.py without -R");
}
}

View File

@@ -26,9 +26,9 @@ void ssl_error(string side, string pronoun, string other, string server)
<< "with `Scripts/setup-ssl.sh` expire after a month." << endl;
}
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
MultiPlayer<ssl_socket*>(Nms, id_base), plaintext_player(Nms, id_base),
other_player(Nms, id_base + Nms.num_players()),
CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
MultiPlayer<ssl_socket*>(Nms), plaintext_player(Nms, id_base),
other_player(Nms, id_base + "recv"),
ctx("P" + to_string(my_num()))
{
sockets.resize(num_players());
@@ -57,6 +57,11 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
}
}
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
CryptoPlayer(Nms, to_string(id_base))
{
}
CryptoPlayer::~CryptoPlayer()
{
close_client_socket(plaintext_player.socket(my_num()));
@@ -124,13 +129,6 @@ void CryptoPlayer::pass_around_no_stats(const octetStream& to_send,
}
}
template<>
void MultiPlayer<ssl_socket*>::setup_sockets(const vector<string>& names,
const vector<int>& ports, int id_base, ServerSocket& server)
{
(void)names, (void)ports, (void)id_base, (void)server;
}
void CryptoPlayer::send_receive_all_no_stats(const vector<vector<bool>>& channels,
const vector<octetStream>& to_send,
vector<octetStream>& to_receive) const

View File

@@ -12,6 +12,12 @@
#include <boost/asio/ssl.hpp>
#include <boost/asio.hpp>
/**
* Encrypted multi-party communication.
* Uses OpenSSL and certificates issued to "P<player_no>".
* Sending and receiving is done in separate threads to allow
* for bidirectional communication.
*/
class CryptoPlayer : public MultiPlayer<ssl_socket*>
{
PlainPlayer plaintext_player, other_player;
@@ -24,7 +30,14 @@ class CryptoPlayer : public MultiPlayer<ssl_socket*>
vector<Receiver<ssl_socket*>*> receivers;
public:
CryptoPlayer(const Names& Nms, int id_base=0);
/**
* Start a new set of encrypted connections.
* @param Nms network setup
* @param id unique identifier
*/
CryptoPlayer(const Names& Nms, const string& id);
// legacy interface
CryptoPlayer(const Names& Nms, int id_base = 0);
~CryptoPlayer();
bool is_encrypted() { return true; }

View File

@@ -22,24 +22,20 @@ void Names::init(int player,int pnb,int my_port,const char* servername)
setup_server();
}
void Names::init(int player,int pnb,vector<string> Nms)
Names::Names(int player, int nplayers, const string& servername, int pnb,
int my_port) :
Names()
{
vector<octet*> names;
for (auto& name : Nms)
names.push_back((octet*)name.c_str());
init(player, pnb, names);
Server::start_networking(*this, player, nplayers, servername, pnb, my_port);
}
void Names::init(int player,int pnb,vector<octet*> Nms)
void Names::init(int player,int pnb,vector<string> Nms)
{
player_no=player;
portnum_base=pnb;
nplayers=Nms.size();
names.resize(nplayers);
names=Nms;
setup_ports();
for (int i=0; i<nplayers; i++) {
names[i]=(char*)Nms[i];
}
setup_server();
}
@@ -112,7 +108,7 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
cerr << usage;
exit(1);
}
global_server = network_opts.start_networking(*this, player_no);
network_opts.start_networking(*this, player_no);
}
void Names::setup_ports()
@@ -130,7 +126,7 @@ void Names::setup_names(const char *servername, int my_port)
int socket_num;
int pn = portnum_base - 1;
set_up_client_socket(socket_num, servername, pn);
send(socket_num, (octet*)&player_no, sizeof(player_no));
octetStream("P" + to_string(player_no)).Send(socket_num);
#ifdef DEBUG_NETWORKING
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
#endif
@@ -191,16 +187,12 @@ Names::Names(const Names& other)
names = other.names;
ports = other.ports;
server = 0;
global_server = 0;
}
Names::~Names()
{
if (server != 0)
delete server;
if (global_server != 0)
delete global_server;
}
@@ -213,18 +205,27 @@ Player::Player(const Names& Nms) :
template<class T>
MultiPlayer<T>::MultiPlayer(const Names& Nms, int id) :
MultiPlayer<T>::MultiPlayer(const Names& Nms) :
Player(Nms), send_to_self_socket(0)
{
if (Nms.num_players() > 1)
setup_sockets(Nms.names, Nms.ports, id, *Nms.server);
else
sockets.resize(Nms.num_players());
sockets.resize(Nms.num_players());
}
template<>
MultiPlayer<int>::~MultiPlayer()
PlainPlayer::PlainPlayer(const Names& Nms, const string& id) :
MultiPlayer<int>(Nms)
{
if (Nms.num_players() > 1)
setup_sockets(Nms.names, Nms.ports, id, *Nms.server);
}
PlainPlayer::PlainPlayer(const Names& Nms, int id_base) :
PlainPlayer(Nms, to_string(id_base))
{
}
PlainPlayer::~PlainPlayer()
{
if (num_players() > 1)
{
@@ -258,33 +259,38 @@ PlayerBase::~PlayerBase()
// Set up nmachines client and server sockets to send data back and fro
// A machine is a server between it and player i if i<=my_number
// Can also communicate with myself, but only with send_to and receive_from
template<>
void MultiPlayer<int>::setup_sockets(const vector<string>& names,const vector<int>& ports,int id_base,ServerSocket& server)
void PlainPlayer::setup_sockets(const vector<string>& names,
const vector<int>& ports, const string& id_base, ServerSocket& server)
{
sockets.resize(nplayers);
// Set up the client side
for (int i=player_no; i<nplayers; i++) {
int pn=id_base+player_no;
auto pn=id_base+"P"+to_string(player_no);
if (i==player_no) {
const char* localhost = "127.0.0.1";
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Setting up send to self socket to %s:%d with id 0x%x\n",localhost,ports[i],pn);
fprintf(stderr,
"Setting up send to self socket to %s:%d with id %s\n",
localhost, ports[i], pn.c_str());
#endif
set_up_client_socket(sockets[i],localhost,ports[i]);
} else {
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Setting up client to %s:%d with id 0x%x\n",names[i].c_str(),ports[i],pn);
fprintf(stderr, "Setting up client to %s:%d with id %s\n",
names[i].c_str(), ports[i], pn.c_str());
#endif
set_up_client_socket(sockets[i],names[i].c_str(),ports[i]);
}
send(sockets[i], (unsigned char*)&pn, sizeof(pn));
octetStream(pn).Send(sockets[i]);
}
send_to_self_socket = sockets[player_no];
// Setting up the server side
for (int i=0; i<=player_no; i++) {
int id=id_base+i;
auto id=id_base+"P"+to_string(i);
#ifdef DEBUG_NETWORKING
fprintf(stderr, "As a server, waiting for client with id 0x%x to connect.\n",id);
fprintf(stderr,
"As a server, waiting for client with id %s to connect.\n",
id.c_str());
#endif
sockets[i] = server.get_connection_socket(id);
}
@@ -539,11 +545,10 @@ void Player::send_receive_all(const vector<vector<bool>>& channels,
send_receive_all_no_stats(channels, to_send, to_receive);
}
void Player::partial_broadcast(const vector<bool>& senders,
vector<octetStream>& os) const
void Player::partial_broadcast(const vector<bool>&,
const vector<bool>&, vector<octetStream>& os) const
{
partial_broadcast(senders, vector<bool>(num_players(), senders[my_num()]),
os);
unchecked_broadcast(os);
}
template<class T>
@@ -570,7 +575,8 @@ void MultiPlayer<T>::send_receive_all_no_stats(
}
ThreadPlayer::ThreadPlayer(const Names& Nms, int id_base) : PlainPlayer(Nms, id_base)
ThreadPlayer::ThreadPlayer(const Names& Nms, const string& id_base) :
PlainPlayer(Nms, id_base)
{
for (int i = 0; i < Nms.num_players(); i++)
{
@@ -625,35 +631,41 @@ void ThreadPlayer::send_all(const octetStream& o) const
}
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player, int id) :
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player, const string& id) :
TwoPartyPlayer(Nms.my_num()), other_player(other_player)
{
is_server = Nms.my_num() > other_player;
setup_sockets(other_player, Nms, Nms.ports[other_player], id);
}
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player,
int id_base) : RealTwoPartyPlayer(Nms, other_player, to_string(id_base))
{
}
RealTwoPartyPlayer::~RealTwoPartyPlayer()
{
close_client_socket(socket);
}
void RealTwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, int id)
void RealTwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, string id)
{
id += 0xF << 28;
id += "2";
const char *hostname = nms.names[other_player].c_str();
ServerSocket *server = nms.server;
if (is_server) {
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Setting up server with id %d\n",id);
fprintf(stderr, "Setting up server with id %s\n", id.c_str());
#endif
socket = server->get_connection_socket(id);
}
else {
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, portNum, id);
fprintf(stderr, "Setting up client to %s:%d with id %s\n", hostname,
portNum, id.c_str());
#endif
set_up_client_socket(socket, hostname, portNum);
::send(socket, (unsigned char*)&id, sizeof(id));
octetStream(id).Send(socket);
}
}
@@ -739,6 +751,10 @@ void TwoPartyPlayer::Broadcast_Receive(vector<octetStream>& o) const
o[1 - my_num()] = os[1];
}
NamedCommStats::NamedCommStats() : sent(0)
{
}
CommStats& CommStats::operator +=(const CommStats& other)
{
data += other.data;
@@ -749,6 +765,7 @@ CommStats& CommStats::operator +=(const CommStats& other)
NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other)
{
sent += other.sent;
for (auto it = other.begin(); it != other.end(); it++)
(*this)[it->first] += it->second;
return *this;
@@ -772,6 +789,7 @@ CommStats& CommStats::operator -=(const CommStats& other)
NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const
{
NamedCommStats res = *this;
res.sent = sent - other.sent;
for (auto it = other.begin(); it != other.end(); it++)
res[it->first] -= it->second;
return res;

View File

@@ -27,15 +27,22 @@ template<class T> class MultiPlayer;
class Server;
class ServerSocket;
/* Class to get the names off the server */
/**
* Network setup (hostnames and port numbers)
*/
class Names
{
friend class Player;
friend class PlainPlayer;
friend class RealTwoPartyPlayer;
vector<string> names;
vector<int> ports;
int nplayers;
int portnum_base;
int player_no;
Server* global_server;
ServerSocket* server;
int default_port(int playerno) { return portnum_base + playerno; }
void setup_ports();
@@ -48,28 +55,63 @@ class Names
static const int DEFAULT_PORT = -1;
mutable ServerSocket* server;
/**
* Initialize with central server
* @param player my number
* @param pnb base port number (server listens one below)
* @param my_port my port number (`DEFAULT_PORT` for default,
* which is base port number plus player number)
* @param servername location of server
*/
void init(int player,int pnb,int my_port,const char* servername);
Names(int player,int pnb,int my_port,const char* servername) : Names()
{ init(player,pnb,my_port,servername); }
// Set up names when we KNOW who we are going to be using before hand
void init(int player,int pnb,vector<octet*> Nms);
Names(int player,int pnb,vector<octet*> Nms) : Names()
{ init(player,pnb,Nms); }
/**
* Initialize with central server running on player 0
* @param player my number
* @param nplayers number of players
* @param servername location of player 0
* @param pnb base port number
* @param my_port my port number (`DEFAULT_PORT` for default,
* which is base port number plus player number)
*/
Names(int player, int nplayers, const string& servername, int pnb,
int my_port = DEFAULT_PORT);
/**
* Initialize without central server
* @param player my number
* @param pnb base port number
* @param Nms locations of all parties
*/
void init(int player,int pnb,vector<string> Nms);
Names(int player,int pnb,vector<string> Nms) : Names()
{ init(player,pnb,Nms); }
// nplayers = 0 for taking it from hostsfile
/**
* Initialize from file. One party per line, format ``<hostname>[:<port>]``
* @param player my number
* @param pnb base port number
* @param hostsfile filename
* @param players number of players (0 to take from file)
*/
void init(int player, int pnb, const string& hostsfile, int players = 0);
Names(int player, int pnb, const string& hostsfile) : Names()
{ init(player, pnb, hostsfile); }
// initialize from command-line options
/**
* Initialize from command-line options
* @param opt option parser instance
* @param argc number of command-line arguments
* @param argv command-line arguments
* @param default_nplayers default number of players
* (used if not given in arguments)
*/
Names(ez::ezOptionParser& opt, int argc, const char** argv,
int default_nplayers = 2);
Names() : nplayers(1), portnum_base(-1), player_no(0), global_server(0), server(0) { ; }
Names() : nplayers(1), portnum_base(-1), player_no(0), server(0) { ; }
Names(const Names& other);
~Names();
@@ -77,11 +119,6 @@ class Names
int my_num() const { return player_no; }
const string get_name(int i) const { return names[i]; }
int get_portnum_base() const { return portnum_base; }
friend class PlayerBase;
friend class Player;
template<class T> friend class MultiPlayer;
friend class RealTwoPartyPlayer;
};
@@ -108,6 +145,10 @@ struct CommStats
class NamedCommStats : public map<string, CommStats>
{
public:
size_t sent;
NamedCommStats();
NamedCommStats& operator+=(const NamedCommStats& other);
NamedCommStats operator+(const NamedCommStats& other) const;
NamedCommStats operator-(const NamedCommStats& other) const;
@@ -123,17 +164,20 @@ public:
#endif
};
/**
* Abstract class for two- and multi-player communication
*/
class PlayerBase
{
protected:
int player_no;
public:
mutable size_t sent;
size_t& sent;
mutable Timer timer;
mutable NamedCommStats comm_stats;
PlayerBase(int player_no) : player_no(player_no), sent(0) {}
PlayerBase(int player_no) : player_no(player_no), sent(comm_stats.sent) {}
virtual ~PlayerBase();
int my_real_num() const { return player_no; }
@@ -146,6 +190,11 @@ public:
{ Broadcast_Receive(o); }
};
/**
* Abstract class for multi-player communication.
* ``*_no_stats`` functions are called by their equivalents
* after accounting for communications statistics.
*/
class Player : public PlayerBase
{
protected:
@@ -159,7 +208,13 @@ public:
Player(const Names& Nms);
virtual ~Player();
/**
* Get number of players
*/
int num_players() const { return nplayers; }
/**
* Get my player number
*/
int my_num() const { return player_no; }
int get_offset(int other_player) const { return positive_modulo(other_player - my_num(), num_players()); }
@@ -173,61 +228,115 @@ public:
// The following functions generally update the statistics
// and then call the *_no_stats equivalent specified by a subclass.
// send the same to all other players
/**
* Send the same to all other players
*/
virtual void send_all(const octetStream& o) const;
// send to a specific player
/**
* Send to a specific player
*/
void send_to(int player,const octetStream& o) const;
virtual void send_to_no_stats(int player,const octetStream& o) const = 0;
// receive from all other players
/**
* Receive from all other players.
* Information from player 0 at ``os[0]`` etc.
*/
void receive_all(vector<octetStream>& os) const;
// receive from a specific player
/**
* Receive from a specific player
*/
void receive_player(int i,octetStream& o) const;
virtual void receive_player_no_stats(int i,octetStream& o) const = 0;
virtual void receive_player(int i,FlexBuffer& buffer) const;
// Communication relative to my number
// send to all other players by offset
/**
* Send to all other players by offset.
* ``o[0]`` gets sent to the next player etc.
*/
void send_relative(const vector<octetStream>& o) const;
// send to other player specified by offset
/*
* Send to other player specified by offset.
* 1 stands for the next player etc.
*/
void send_relative(int offset, const octetStream& o) const;
// receive from all other players by offset
/**
* Receive from all other players by offset.
* ``o[0]`` will contain data from the next player etc.
*/
void receive_relative(vector<octetStream>& o) const;
// receive from other palyer specified by offset
/**
* Receive from other player specified by offset.
* 1 stands for the next player etc.
*/
void receive_relative(int offset, octetStream& o) const;
// exchange data with minimal memory usage
// exchange information with one other party
/**
* Exchange information with one other party,
* reusing the buffer if possible.
*/
void exchange(int other, const octetStream& to_send, octetStream& ot_receive) const;
virtual void exchange_no_stats(int other, const octetStream& to_send, octetStream& ot_receive) const = 0;
/**
* Exchange information with one other party, reusing the buffer.
*/
void exchange(int other, octetStream& o) const;
// exchange with one other partiy specified by offset
/**
* Exchange information with one other party specified by offset,
* reusing the buffer if possible.
*/
void exchange_relative(int offset, octetStream& o) const;
// send information to party while receiving from another by offset
/**
* Send information to a party while receiving from another by offset,
* The default is to send to the next party while receiving from the previous.
* The buffer is reused.
*/
void pass_around(octetStream& o, int offset = 1) const { pass_around(o, o, offset); }
/**
* Send information to a party while receiving from another by offset.
* The default is to send to the next party while receiving from the previous.
*/
void pass_around(octetStream& to_send, octetStream& to_receive, int offset) const;
virtual void pass_around_no_stats(const octetStream& to_send,
octetStream& to_receive, int offset) const = 0;
/* Broadcast and Receive data to/from all players
* - Assumes o[player_no] contains the thing broadcast by me
/**
* Broadcast and receive data to/from all players.
* Assumes o[player_no] contains the data to be broadcast by me.
*/
virtual void unchecked_broadcast(vector<octetStream>& o) const;
// broadcast with eventual verification
/**
* Broadcast and receive data to/from all players with eventual verification.
* Assumes o[player_no] contains the data to be broadcast by me.
*/
virtual void Broadcast_Receive(vector<octetStream>& o) const;
virtual void Broadcast_Receive_no_stats(vector<octetStream>& o) const = 0;
/* Run Protocol To Verify Broadcast Is Correct
* - Resets the blk_SHA_CTX at the same time
/**
* Run protocol to verify broadcast is correct
*/
virtual void Check_Broadcast() const;
// send something different to all
/**
* Send something different to each player.
*/
void send_receive_all(const vector<octetStream>& to_send,
vector<octetStream>& to_receive) const;
// specified senders only send something different to all
/**
* Specified senders only send something different to each player.
* @param senders set whether a player sends or not,
* must be equal on all players
* @param to_send data to send by player number
* @param to_receive received data by player number
*/
void send_receive_all(const vector<bool>& senders,
const vector<octetStream>& to_send, vector<octetStream>& to_receive) const;
// send something different only one specified channels
/**
* Send something different only one specified channels.
* @param channels ``channel[i][j]`` indicates whether party ``i`` sends
* to party ``j``
* @param to_send data to send by player number
* @param to_receive received data by player number
*/
void send_receive_all(const vector<vector<bool>>& channels,
const vector<octetStream>& to_send,
vector<octetStream>& to_receive) const;
@@ -235,11 +344,15 @@ public:
const vector<octetStream>& to_send,
vector<octetStream>& to_receive) const = 0;
// specified senders broadcast information
/**
* Specified senders broadcast information to specified receivers.
* @param senders specify which parties do send
* @param receivers specify which parties do send
* @param os data to send at ``os[my_number]``, received data elsewhere
*/
virtual void partial_broadcast(const vector<bool>& senders,
const vector<bool>& receivers,
vector<octetStream>& os) const;
virtual void partial_broadcast(const vector<bool>&, const vector<bool>&,
vector<octetStream>& os) const { unchecked_broadcast(os); }
// dummy functions for compatibility
virtual void request_receive(int i, octetStream& o) const { (void)i; (void)o; }
@@ -247,6 +360,11 @@ public:
{ receive_player(i, o); }
};
/**
* Multi-player communication helper class.
* ``T = int`` for unencrypted BSD sockets and
* ``T = ssl_socket*`` for Boost SSL streams.
*/
template<class T>
class MultiPlayer : public Player
{
@@ -254,16 +372,12 @@ protected:
vector<T> sockets;
T send_to_self_socket;
void setup_sockets(const vector<string>& names,const vector<int>& ports,int id_base,ServerSocket& server);
T socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; }
friend class CryptoPlayer;
public:
// The offset is used for the multi-threaded call, to ensure different
// portnum bases in each thread
MultiPlayer(const Names& Nms,int id_base=0);
MultiPlayer(const Names& Nms);
virtual ~MultiPlayer();
@@ -296,16 +410,34 @@ public:
vector<octetStream>& to_receive) const;
};
typedef MultiPlayer<int> PlainPlayer;
/**
* Plaintext multi-player communication
*/
class PlainPlayer : public MultiPlayer<int>
{
void setup_sockets(const vector<string>& names, const vector<int>& ports,
const string& id_base, ServerSocket& server);
public:
/**
* Start a new set of unencrypted connections.
* @param Nms network setup
* @param id unique identifier
*/
PlainPlayer(const Names& Nms, const string& id);
// legacy interface
PlainPlayer(const Names& Nms, int id_base = 0);
~PlainPlayer();
};
class ThreadPlayer : public MultiPlayer<int>
class ThreadPlayer : public PlainPlayer
{
public:
mutable vector<Receiver<int>*> receivers;
mutable vector<Sender<int>*> senders;
ThreadPlayer(const Names& Nms,int id_base=0);
ThreadPlayer(const Names& Nms, const string& id_base);
virtual ~ThreadPlayer();
void request_receive(int i, octetStream& o) const;
@@ -335,14 +467,16 @@ class RealTwoPartyPlayer : public TwoPartyPlayer
{
private:
// setup sockets for comm. with only one other player
void setup_sockets(int other_player, const Names &nms, int portNum, int id);
void setup_sockets(int other_player, const Names &nms, int portNum, string id);
int socket;
bool is_server;
int other_player;
public:
RealTwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0);
RealTwoPartyPlayer(const Names& Nms, int other_player, const string& id);
// legacy
RealTwoPartyPlayer(const Names& Nms, int other_player, int id_base = 0);
~RealTwoPartyPlayer();
void send(octetStream& o) const;

View File

@@ -116,7 +116,7 @@ void Server::start()
#ifdef DEBUG_NETWORKING
cerr << "Waiting for player " << i << endl;
#endif
socket_num[i] = server.get_connection_socket(i);
socket_num[i] = server.get_connection_socket("P" + to_string(i));
#ifdef DEBUG_NETWORKING
cerr << "Connected to player " << i << endl;
#endif

View File

@@ -7,6 +7,7 @@
#include <Networking/sockets.h>
#include "Tools/Exceptions.h"
#include "Tools/time-func.h"
#include "Tools/octetStream.h"
#include <netinet/ip.h>
#include <netinet/tcp.h>
@@ -53,7 +54,10 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum), thread(0)
while (fl!=0 and timer.elapsed() < 600)
{ fl=::bind(main_socket, (struct sockaddr *)&serv, sizeof(struct sockaddr));
if (fl != 0)
{ cerr << "Binding to socket on " << my_name << ":" << Portnum << " failed, trying again in a second ..." << endl;
{
cerr << "Binding to socket on " << my_name << ":" << Portnum
<< " failed (" << strerror(errno)
<< "), trying again in a second ..." << endl;
sleep(1);
}
#ifdef DEBUG_NETWORKING
@@ -109,11 +113,11 @@ ServerSocket::~ServerSocket()
void ServerSocket::wait_for_client_id(int socket, struct sockaddr dest)
{
(void) dest;
int client_id;
try
{
receive(socket, (unsigned char*) &client_id, sizeof(client_id));
process_connection(socket, client_id);
octetStream client_id;
client_id.Receive(socket);
process_connection(socket, client_id.str());
}
catch (closed_connection&)
{
@@ -138,9 +142,13 @@ void ServerSocket::accept_clients()
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
if (consocket<0) { error("set_up_socket:accept"); }
int client_id;
if (receive_all_or_nothing(consocket, (unsigned char*)&client_id, sizeof(client_id)))
process_connection(consocket, client_id);
octetStream client_id;
char buf[1];
if (recv(consocket, buf, 1, MSG_PEEK | MSG_DONTWAIT) > 0)
{
client_id.Receive(consocket);
process_connection(consocket, client_id.str());
}
else
{
#ifdef DEBUG_NETWORKING
@@ -162,7 +170,7 @@ void ServerSocket::accept_clients()
}
}
void ServerSocket::process_connection(int consocket, int client_id)
void ServerSocket::process_connection(int consocket, const string& client_id)
{
data_signal.lock();
#ifdef DEBUG_NETWORKING
@@ -175,7 +183,7 @@ void ServerSocket::process_connection(int consocket, int client_id)
data_signal.unlock();
}
int ServerSocket::get_connection_socket(int id)
int ServerSocket::get_connection_socket(const string& id)
{
data_signal.lock();
if (used.find(id) != used.end())
@@ -208,14 +216,14 @@ void AnonymousServerSocket::init()
pthread_create(&thread, 0, anonymous_accept_thread, this);
}
void AnonymousServerSocket::process_client(int client_id)
void AnonymousServerSocket::process_client(const string& client_id)
{
if (clients.find(client_id) != clients.end())
close_client_socket(clients[client_id]);
client_connection_queue.push(client_id);
}
int AnonymousServerSocket::get_connection_socket(int& client_id)
int AnonymousServerSocket::get_connection_socket(string& client_id)
{
data_signal.lock();
@@ -230,7 +238,7 @@ int AnonymousServerSocket::get_connection_socket(int& client_id)
return client_socket;
}
void AnonymousServerSocket::remove_client(int client_id)
void AnonymousServerSocket::remove_client(const string& client_id)
{
clients.erase(client_id);
}

View File

@@ -9,6 +9,7 @@
#include <map>
#include <set>
#include <queue>
#include <string>
using namespace std;
#include <pthread.h>
@@ -23,8 +24,8 @@ class ServerSocket
{
protected:
int main_socket, portnum;
map<int,int> clients;
std::set<int> used;
map<string,int> clients;
std::set<string> used;
Signal data_signal;
pthread_t thread;
@@ -33,9 +34,9 @@ protected:
// disable copying
ServerSocket(const ServerSocket& other);
void process_connection(int socket, int client_id);
void process_connection(int socket, const string& client_id);
virtual void process_client(int) {}
virtual void process_client(const string&) {}
public:
ServerSocket(int Portnum);
@@ -47,9 +48,9 @@ public:
void wait_for_client_id(int socket, struct sockaddr dest);
// This depends on clients sending their id as int.
// This depends on clients sending their id.
// Has to be thread-safe.
int get_connection_socket(int number);
int get_connection_socket(const string& id);
};
/*
@@ -59,9 +60,9 @@ class AnonymousServerSocket : public ServerSocket
{
private:
// No. of accepted connections in this instance
queue<int> client_connection_queue;
queue<string> client_connection_queue;
void process_client(int client_id);
void process_client(const string& client_id);
public:
AnonymousServerSocket(int Portnum) :
@@ -69,9 +70,9 @@ public:
void init();
// Get socket and id for the last client who connected
int get_connection_socket(int& client_id);
int get_connection_socket(string& client_id);
void remove_client(int client_id);
void remove_client(const string& client_id);
};
#endif /* NETWORKING_SERVERSOCKET_H_ */

View File

@@ -116,7 +116,6 @@ public:
mac_key_type get_mac_key() const { return mac_key; }
size_t data_sent();
NamedCommStats comm_stats();
};
@@ -210,17 +209,6 @@ public:
void generateTriples();
};
template<class T>
size_t OTTripleGenerator<T>::data_sent()
{
size_t res = 0;
if (parentPlayer != &globalPlayer)
res = globalPlayer.sent;
for (auto& player : players)
res += player->sent;
return res;
}
template<class T>
NamedCommStats OTTripleGenerator<T>::comm_stats()
{

View File

@@ -72,7 +72,7 @@ OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
globalPlayer(parentPlayer ? *parentPlayer : *new PlainPlayer(names,
- thread_num * names.num_players() * names.num_players())),
to_string(thread_num))),
parentPlayer(parentPlayer),
thread_num(thread_num),
mac_key(mac_key),

View File

@@ -4,6 +4,7 @@
*/
#include "BaseMachine.h"
#include "OnlineOptions.h"
#include "Math/Setup.h"
#include <iostream>
@@ -90,10 +91,8 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
void BaseMachine::print_compiler()
{
#ifdef VERBOSE
if (compiler.size() != 0)
if (compiler.size() != 0 and OnlineOptions::singleton.verbose)
cerr << "Compiler: " << compiler << endl;
#endif
}
void BaseMachine::load_program(const string& threadname, const string& filename)

View File

@@ -20,6 +20,8 @@ class Binary_File_IO
{
public:
static string filename(int my_number);
/*
* Append the buffer values as binary to the filename.
* Throws file_error.

View File

@@ -6,6 +6,13 @@
* Intended for application specific file IO.
*/
inline string Binary_File_IO::filename(int my_number)
{
string dir = "Persistence";
mkdir_p(dir.c_str());
return dir + "/Transactions-P" + to_string(my_number) + ".data";
}
template<class T>
void Binary_File_IO::write_to_file(const string filename, const vector< T >& buffer)
{

View File

@@ -58,6 +58,7 @@ public:
vector< array<long long, N_DATA_FIELD_TYPE> > inputs;
array<map<DataTag, long long>, N_DATA_FIELD_TYPE> extended;
map<pair<bool, int>, long long> edabits;
map<array<int, 3>, long long> matmuls;
DataPositions(int num_players = 0);
DataPositions(const Player& P) : DataPositions(P.num_players()) {}
@@ -126,7 +127,7 @@ public:
virtual void prune() {}
virtual void purge() {}
virtual size_t data_sent() { return 0; }
virtual size_t data_sent() { return comm_stats().sent; }
virtual NamedCommStats comm_stats() { return {}; }
virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0;
@@ -288,7 +289,7 @@ class Data_Files
void reset_usage() { usage.reset(); skipped.reset(); }
size_t data_sent() { return DataFp.data_sent() + DataF2.data_sent(); }
NamedCommStats comm_stats() { return DataFp.comm_stats() + DataF2.comm_stats(); }
};
template<class T> inline

View File

@@ -46,7 +46,10 @@ int ExternalClients::get_client_connection(int portnum_base)
}
cerr << "Thread " << this_thread::get_id() << " found server." << endl;
int client_id, socket;
socket = client_connection_servers[portnum_base]->get_connection_socket(client_id);
string client;
socket = client_connection_servers[portnum_base]->get_connection_socket(
client);
client_id = stoi(client);
if (ctx == 0)
ctx = new ssl_ctx("P" + to_string(get_party_num()));
external_client_sockets[client_id] = new ssl_socket(io_service, *ctx, socket,
@@ -63,7 +66,8 @@ void ExternalClients::close_connection(int client_id)
throw runtime_error("client id not active: " + to_string(client_id));
delete external_client_sockets[client_id];
external_client_sockets.erase(client_id);
client_connection_servers[client_ports[client_id]]->remove_client(client_id);
client_connection_servers[client_ports[client_id]]->remove_client(
to_string(client_id));
}
int ExternalClients::get_party_num()

View File

@@ -67,6 +67,7 @@ enum
THRESHOLD = 0xE3,
PLAYERID = 0xE4,
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
// Addition
ADDC = 0x20,
ADDS = 0x21,

View File

@@ -92,9 +92,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case DIVINT:
case CONDPRINTPLAIN:
case INPUTMASKREG:
r[0]=get_int(s);
r[1]=get_int(s);
r[2]=get_int(s);
get_ints(r, s, 3);
break;
// instructions with 2 register operands
case LDMCI:
@@ -130,8 +128,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case DABIT:
case SHUFFLE:
case ACCEPTCLIENTCONNECTION:
r[0]=get_int(s);
r[1]=get_int(s);
get_ints(r, s, 2);
break;
// instructions with 1 register operand
case BIT:
@@ -157,6 +154,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case PLAYERID:
case LISTEN:
case CLOSECLIENTCONNECTION:
case CRASH:
r[0]=get_int(s);
break;
// instructions with 2 registers + 1 integer operand
@@ -205,8 +203,11 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case STOPPRIVATEOUTPUT:
case GSTOPPRIVATEOUTPUT:
case DIGESTC:
r[0]=get_int(s);
r[1]=get_int(s);
get_ints(r, s, 2);
n = get_int(s);
break;
case USE_MATMUL:
get_ints(r, s, 3);
n = get_int(s);
break;
// instructions with 1 register + 1 integer operand
@@ -254,7 +255,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
break;
// instructions with no operand
case TIME:
case CRASH:
case STARTGRIND:
case STOPGRIND:
case CHECK:
@@ -320,8 +320,9 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case READSOCKETC:
case READSOCKETS:
case READSOCKETINT:
num_var_args = get_int(s) - 1;
num_var_args = get_int(s) - 2;
r[0] = get_int(s);
n = get_int(s);
get_vector(num_var_args, start, s);
break;
@@ -329,9 +330,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case WRITESOCKETC:
case WRITESOCKETSHARE:
case WRITESOCKETINT:
num_var_args = get_int(s) - 2;
num_var_args = get_int(s) - 3;
r[0] = get_int(s);
r[1] = get_int(s);
n = get_int(s);
get_vector(num_var_args, start, s);
break;
case WRITESOCKETS:
@@ -383,9 +385,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
// subtract extra argument
num_var_args = get_int(s) - 1;
s.read((char*)r, sizeof(r));
start.resize(num_var_args);
for (int i = 0; i < num_var_args; i++)
{ start[i] = get_int(s); }
get_vector(num_var_args, start, s);
break;
case USE_PREP:
case GUSE_PREP:
@@ -431,6 +431,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case CONVCBITVEC:
case CONVCBIT2S:
case NOTS:
case NOTCB:
n = get_int(s);
get_ints(r, s, 2);
break;
@@ -497,6 +498,9 @@ bool Instruction::get_offline_data_usage(DataPositions& usage)
case USE_EDABIT:
usage.edabits[{r[0], r[1]}] = n;
return int(n) >= 0;
case USE_MATMUL:
usage.matmuls[{{r[0], r[1], r[2]}}] = n;
return int(n) >= 0;
case USE_PREP:
usage.extended[DATA_INT][r] = n;
return int(n) >= 0;
@@ -552,6 +556,7 @@ int BaseInstruction::get_reg_type() const
case USE_PREP:
case GUSE_PREP:
case USE_EDABIT:
case USE_MATMUL:
case RUN_TAPE:
case CISC:
// those use r[] not for registers
@@ -709,6 +714,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
}
case ANDM:
case NOTS:
case NOTCB:
size = DIV_CEIL(n, 64);
break;
case CONVCBIT2S:
@@ -735,6 +741,14 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
offset = 2;
skip = 4;
break;
case READSOCKETS:
case READSOCKETC:
case READSOCKETINT:
case WRITESOCKETSHARE:
case WRITESOCKETC:
case WRITESOCKETINT:
size = n;
break;
}
if (skip > 0)
@@ -1016,11 +1030,11 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.Procp.matmuls(Proc.Procp.get_S(), *this, r[1], r[2]);
return;
case MATMULSM:
Proc.Procp.matmulsm(Proc.machine.Mp.MS, *this, Proc.read_Ci(r[1]),
Proc.read_Ci(r[2]));
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
Proc.read_Ci(r[1]), Proc.read_Ci(r[2]));
return;
case CONV2DS:
Proc.Procp.conv2ds(*this);
Proc.Procp.protocol.conv2ds(Proc.Procp, *this);
return;
case TRUNC_PR:
Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp);
@@ -1096,6 +1110,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case USE:
case USE_INP:
case USE_EDABIT:
case USE_MATMUL:
case USE_PREP:
case GUSE_PREP:
break;
@@ -1117,7 +1132,8 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.machine.join_tape(r[0]);
break;
case CRASH:
throw crash_requested();
if (Proc.read_Ci(r[0]))
throw crash_requested();
break;
case STARTGRIND:
CALLGRIND_START_INSTRUMENTATION;
@@ -1160,25 +1176,25 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.external_clients.close_connection(Proc.read_Ci(r[0]));
break;
case READSOCKETINT:
Proc.read_socket_ints(Proc.read_Ci(r[0]), start);
Proc.read_socket_ints(Proc.read_Ci(r[0]), start, n);
break;
case READSOCKETC:
Proc.read_socket_vector(Proc.read_Ci(r[0]), start);
Proc.read_socket_vector(Proc.read_Ci(r[0]), start, n);
break;
case READSOCKETS:
// read shares and MAC shares
Proc.read_socket_private(Proc.read_Ci(r[0]), start, true);
Proc.read_socket_private(Proc.read_Ci(r[0]), start, n, true);
break;
case WRITESOCKETINT:
Proc.write_socket(INT, Proc.read_Ci(r[0]), r[1], start);
Proc.write_socket(INT, Proc.read_Ci(r[0]), r[1], start, n);
break;
case WRITESOCKETC:
Proc.write_socket(CINT, Proc.read_Ci(r[0]), r[1], start);
Proc.write_socket(CINT, Proc.read_Ci(r[0]), r[1], start, n);
break;
case WRITESOCKETSHARE:
// Send only shares, no MACs
// N.B. doesn't make sense to have a corresponding read instruction for this
Proc.write_socket(SINT, Proc.read_Ci(r[0]), r[1], start);
Proc.write_socket(SINT, Proc.read_Ci(r[0]), r[1], start, n);
break;
case WRITEFILESHARE:
// Write shares to file system

View File

@@ -72,7 +72,6 @@ class Machine : public BaseMachine
OnlineOptions opts;
atomic<size_t> data_sent;
NamedCommStats comm_stats;
ExecutionStats stats;
@@ -89,6 +88,11 @@ class Machine : public BaseMachine
void fill_buffers(int thread_number, int tape_number,
Preprocessing<sint> *prep,
Preprocessing<typename sint::bit_type> *bit_prep);
template<int = 0>
void fill_matmul(int thread_numbber, int tape_number,
Preprocessing<sint> *prep, true_type);
template<int = 0>
void fill_matmul(int, int, Preprocessing<sint>*, false_type) {}
DataPositions run_tape(int thread_number, int tape_number, int arg);
DataPositions join_tape(int thread_number);
void run();

View File

@@ -5,6 +5,7 @@
#include "Memory.hpp"
#include "Online-Thread.hpp"
#include "Protocols/Hemi.hpp"
#include "Tools/Exceptions.h"
@@ -30,8 +31,7 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
: my_number(my_number), N(playerNames),
direct(direct), opening_sum(opening_sum),
receive_threads(receive_threads), max_broadcast(max_broadcast),
use_encryption(use_encryption), live_prep(live_prep), opts(opts),
data_sent(0)
use_encryption(use_encryption), live_prep(live_prep), opts(opts)
{
if (opening_sum < 2)
this->opening_sum = N.num_players();
@@ -49,10 +49,11 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
// make directory for outputs if necessary
mkdir_p(PREP_DIR);
string id = "machine";
if (use_encryption)
P = new CryptoPlayer(N, 0xF00);
P = new CryptoPlayer(N, id);
else
P = new PlainPlayer(N, 0xF00);
P = new PlainPlayer(N, id);
if (opts.live_prep)
{
@@ -96,6 +97,13 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
load_schedule(progname_str);
// remove persistence if necessary
for (auto& prog : progs)
{
if (prog.writes_persistance)
ofstream(Binary_File_IO::filename(my_number), ios::out);
}
#ifdef VERBOSE
progs[0].print_offline_cost();
#endif
@@ -223,6 +231,49 @@ void Machine<sint, sgf2n>::fill_buffers(int thread_number, int tape_number,
{
#ifdef VERBOSE_CENTRAL
cerr << "Problem with central bit triple preprocessing: " << e.what() << endl;
#endif
}
}
if (not HemiOptions::singleton.plain_matmul)
fill_matmul(thread_number, tape_number, prep, sint::triple_matmul);
}
template<class sint, class sgf2n>
template<int>
void Machine<sint, sgf2n>::fill_matmul(int thread_number, int tape_number,
Preprocessing<sint>* prep, true_type)
{
auto usage = progs[tape_number].get_offline_data_used();
for (auto it = usage.matmuls.begin(); it != usage.matmuls.end(); it++)
{
try
{
auto& source_proc = *dynamic_cast<BufferPrep<sint>&>(*prep).proc;
int max_inner = opts.batch_size;
int max_cols = opts.batch_size;
for (int j = 0; j < it->first[1]; j += max_inner)
{
for (int k = 0; k < it->first[2]; k += max_cols)
{
auto subdim = it->first;
subdim[1] = min(subdim[1] - j, max_inner);
subdim[2] = min(subdim[2] - k, max_cols);
auto& source =
dynamic_cast<Hemi<sint>&>(source_proc.protocol).get_matrix_prep(
subdim, source_proc);
auto& dest =
dynamic_cast<Hemi<sint>&>(tinfo[thread_number].processor->Procp.protocol).get_matrix_prep(
subdim, tinfo[thread_number].processor->Procp);
for (int i = 0; i < it->second; i++)
dest.push_triple(source.get_triple_no_count(-1));
}
}
}
catch (bad_cast& e)
{
#ifdef VERBOSE_CENTRAL
cerr << "Problem with central matmul preprocessing: " << e.what() << endl;
#endif
}
}
@@ -320,20 +371,27 @@ void Machine<sint, sgf2n>::run()
for (unsigned int i = 0; i < join_timer.size(); i++)
cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl;
cerr << "Finish timer: " << finish_timer.elapsed() << endl;
cerr << "Process timer: " << proc_timer.elapsed() << endl;
#endif
if (opts.verbose)
{
cerr << "Communication details "
"(rounds in parallel threads counted double):" << endl;
comm_stats.print();
cerr << "CPU time = " << proc_timer.elapsed() << endl;
}
print_timers();
size_t rounds = 0;
for (auto& x : comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << data_sent / 1e6 << " MB in ~" << rounds
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << my_number << ")" << endl;
auto& P = *this->P;
Bundle<octetStream> bundle(P);
bundle.mine.store(data_sent.load());
bundle.mine.store(comm_stats.sent);
P.Broadcast_Receive_no_stats(bundle);
size_t global = 0;
for (auto& os : bundle)
@@ -384,10 +442,11 @@ void Machine<sint, sgf2n>::run()
}
#endif
#ifdef VERBOSE
cerr << "Actual cost of program:" << endl;
pos.print_cost();
#endif
if (opts.verbose)
{
cerr << "Actual cost of program:" << endl;
pos.print_cost();
}
if (pos.any_more(progs[0].get_offline_data_used())
and not progs[0].usage_unknown())

View File

@@ -15,7 +15,7 @@ OfflineMachine<W>::OfflineMachine(int argc, const char** argv,
ez::ezOptionParser& opt, OnlineOptions& online_opts, V,
int nplayers) :
W(argc, argv, opt, online_opts, V(), nplayers), playerNames(
W::playerNames), P(*this->new_player())
W::playerNames), P(*this->new_player("machine"))
{
machine.load_schedule(online_opts.progname, false);
Program program(playerNames.num_players());

View File

@@ -49,26 +49,27 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
fprintf(stderr, "\tI am in thread %d\n",num);
#endif
Player* player;
string id = "thread" + to_string(num);
if (machine.use_encryption)
{
#ifdef VERBOSE_OPTIONS
cerr << "Using encrypted single-threaded communication" << endl;
#endif
player = new CryptoPlayer(*(tinfo->Nms), num << 16);
player = new CryptoPlayer(*(tinfo->Nms), id);
}
else if (!machine.receive_threads or machine.direct)
{
#ifdef VERBOSE_OPTIONS
cerr << "Using single-threaded receiving" << endl;
#endif
player = new PlainPlayer(*(tinfo->Nms), num << 16);
player = new PlainPlayer(*(tinfo->Nms), id);
}
else
{
#ifdef VERBOSE_OPTIONS
cerr << "Using player-specific threads for receiving" << endl;
#endif
player = new ThreadPlayer(*(tinfo->Nms), num << 16);
player = new ThreadPlayer(*(tinfo->Nms), id);
}
Player& P = *player;
#ifdef DEBUG_THREADS
@@ -238,6 +239,16 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
*(Zp_Data*) job.supply);
queues->finished(job);
}
else if (job.type == CIPHER_PLAIN_MULT_JOB)
{
cipher_plain_mult(job, sint::triple_matmul);
queues->finished(job);
}
else if (job.type == MATRX_RAND_MULT_JOB)
{
matrix_rand_mult(job, sint::triple_matmul);
queues->finished(job);
}
else
{ // RUN PROGRAM
#ifdef DEBUG_THREADS
@@ -303,16 +314,15 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
#endif
// wind down thread by thread
size_t prep_sent = Proc.DataF.data_sent();
prep_sent += Proc.share_thread.DataF.data_sent();
prep_sent += Proc.Procp.bit_prep.data_sent();
auto prep_stats = Proc.DataF.comm_stats();
prep_stats += Proc.share_thread.DataF.comm_stats();
prep_stats += Proc.Procp.bit_prep.comm_stats();
for (auto& x : Proc.Procp.personal_bit_preps)
prep_sent += x->data_sent();
prep_stats += x->comm_stats();
machine.stats += Proc.stats;
delete processor;
machine.data_sent += P.sent + prep_sent;
machine.comm_stats += P.comm_stats;
machine.comm_stats += P.comm_stats + prep_stats;
queues->finished(actual_usage);
delete MC2;

View File

@@ -20,7 +20,6 @@ protected:
int lg2, opening_sum, max_broadcast;
Names playerNames;
Server* server;
bool use_encryption, receive_threads;
@@ -38,7 +37,7 @@ public:
template<class T, class U>
int run();
Player* new_player(int id_base = 0);
Player* new_player(const string& id_base);
};
class DishonestMajorityMachine : public OnlineMachine

View File

@@ -28,7 +28,7 @@ template<class V>
OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& opt,
OnlineOptions& online_opts, int nplayers, V) :
argc(argc), argv(argv), online_opts(online_opts), lg2(0),
opening_sum(0), max_broadcast(0), server(0),
opening_sum(0), max_broadcast(0),
use_encryption(false), receive_threads(false),
opt(opt), nplayers(nplayers)
{
@@ -192,7 +192,6 @@ void OnlineMachine::start_networking()
int mynum = online_opts.playerno;
int playerno = online_opts.playerno;
server = 0;
if (ipFileName.size() > 0) {
if (my_port != Names::DEFAULT_PORT)
throw runtime_error("cannot set port number when using IP file");
@@ -204,7 +203,7 @@ void OnlineMachine::start_networking()
{
if (nplayers == 0)
opt.get("-N")->getInt(nplayers);
server = Server::start_networking(playerNames, mynum, nplayers,
Server::start_networking(playerNames, mynum, nplayers,
hostname, pnbase, my_port);
}
else
@@ -216,7 +215,7 @@ void OnlineMachine::start_networking()
}
inline
Player* OnlineMachine::new_player(int id_base)
Player* OnlineMachine::new_player(const string& id_base)
{
if (use_encryption)
return new CryptoPlayer(playerNames, id_base);
@@ -238,15 +237,13 @@ int OnlineMachine::run()
use_encryption, online_opts.live_prep,
online_opts).run();
if (server)
delete server;
#ifdef VERBOSE
cerr << "Command line:";
for (int i = 0; i < argc; i++)
cerr << " " << argv[i];
cerr << endl;
#endif
if (online_opts.verbose)
{
cerr << "Command line:";
for (int i = 0; i < argc; i++)
cerr << " " << argv[i];
cerr << endl;
}
}
#ifndef INSECURE
catch(...)

View File

@@ -7,12 +7,14 @@
#include "BaseMachine.h"
#include "Math/gfp.h"
#include "Math/gfpvar.h"
#include "Protocols/HemiOptions.h"
#include "Math/gfp.hpp"
using namespace std;
OnlineOptions OnlineOptions::singleton;
HemiOptions HemiOptions::singleton;
OnlineOptions::OnlineOptions() : playerno(-1)
{
@@ -26,6 +28,11 @@ OnlineOptions::OnlineOptions() : playerno(-1)
bucket_size = 4;
cmd_private_input_file = "Player-Data/Input";
cmd_private_output_file = "";
#ifdef VERBOSE
verbose = true;
#else
verbose = false;
#endif
}
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
@@ -65,7 +72,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"Prefix for output file path "
"(default: output to stdout for party 0 (silent otherwise "
"unless interactive mode is active). "
"Output will be written to {prefix}-P{id}-{thread_id}.", // Help description.
"Output will be written to {prefix}-P{id}-{thread_id}. "
"Use '.' for stdout on all parties.", // Help description.
"-OF", // Flag token.
"--output-file" // Flag token.
);
@@ -171,6 +179,15 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"-B", // Flag token.
"--bucket-size" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Verbose output", // Help description.
"-v", // Flag token.
"--verbose" // Flag token.
);
opt.parse(argc, argv);
@@ -198,6 +215,10 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
opt.get("--bucket-size")->getInt(bucket_size);
#ifndef VERBOSE
verbose = opt.isSet("--verbose");
#endif
opt.resetArgs();
}

View File

@@ -28,6 +28,7 @@ public:
int bucket_size;
std::string cmd_private_input_file;
std::string cmd_private_output_file;
bool verbose;
OnlineOptions();
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,

View File

@@ -227,13 +227,15 @@ class Processor : public ArithmeticProcessor
void split(const Instruction& instruction);
// Access to external client sockets for reading clear/shared data
void read_socket_ints(int client_id, const vector<int>& registers);
void write_socket(const RegType reg_type,
int socket_id, int message_type, const vector<int>& registers);
void read_socket_ints(int client_id, const vector<int>& registers, int size);
void read_socket_vector(int client_id, const vector<int>& registers);
void read_socket_private(int client_id, const vector<int>& registers, bool send_macs);
void write_socket(const RegType reg_type, int socket_id, int message_type,
const vector<int>& registers, int size);
void read_socket_vector(int client_id, const vector<int>& registers,
int size);
void read_socket_private(int client_id, const vector<int>& registers,
int size, bool send_macs);
// Read and write secret numeric data to file (name hardcoded at present)
void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector<int>& data_registers);

View File

@@ -95,10 +95,13 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
shared_prng.SeedGlobally(P, false);
// only output on party 0 if not interactive
bool output = P.my_num() == 0 or machine.opts.interactive;
bool always_stdout = machine.opts.cmd_private_output_file == ".";
bool output = P.my_num() == 0 or machine.opts.interactive or always_stdout;
out.activate(output);
Procb.out.activate(output);
setup_redirection(P.my_num(), thread_num, opts);
if (not always_stdout)
setup_redirection(P.my_num(), thread_num, opts);
if (stdout_redirect_file.is_open())
{
@@ -236,7 +239,8 @@ void Processor<sint, sgf2n>::convcbit2s(const Instruction& instruction)
for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++)
Procb.S[instruction.get_r(0) + i] = sint::bit_type::constant(
Procb.C[instruction.get_r(1) + i], P.my_num(),
share_thread.MC->get_alphai());
share_thread.MC->get_alphai(),
min(unsigned(unit), instruction.get_n() - i * unit));
}
template<class sint, class sgf2n>
@@ -262,8 +266,8 @@ void Processor<sint, sgf2n>::split(const Instruction& instruction)
// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to
// determine the data structure being sent in a message.
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
int socket_id, int message_type, const vector<int>& registers)
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, int socket_id,
int message_type, const vector<int>& registers, int size)
{
int m = registers.size();
socket_stream.reset_write_head();
@@ -273,28 +277,40 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
socket_stream.store(message_type);
}
for (int i = 0; i < m; i++)
{
if (reg_type == SINT) {
// Send vector of secret shares
get_Sp_ref(registers[i]).pack(socket_stream,
sint::get_rec_factor(P.my_num(), P.num_players()));
for (int j = 0; j < size; j++)
{
for (int i = 0; i < m; i++)
{
if (reg_type == SINT)
{
// Send vector of secret shares
get_Sp_ref(registers[i] + j).pack(socket_stream,
sint::get_rec_factor(P.my_num(), P.num_players()));
}
else if (reg_type == CINT)
{
// Send vector of clear public field elements
get_Cp_ref(registers[i] + j).pack(socket_stream);
}
else if (reg_type == INT)
{
// Send vector of 32-bit clear ints
socket_stream.store((int&) get_Ci_ref(registers[i] + j));
}
else
{
stringstream ss;
ss << "Write socket instruction with unknown reg type "
<< reg_type << "." << endl;
throw Processor_Error(ss.str());
}
}
}
else if (reg_type == CINT) {
// Send vector of clear public field elements
get_Cp_ref(registers[i]).pack(socket_stream);
}
else if (reg_type == INT) {
// Send vector of 32-bit clear ints
socket_stream.store((int&)get_Ci_ref(registers[i]));
}
else {
stringstream ss;
ss << "Write socket instruction with unknown reg type " << reg_type <<
"." << endl;
throw Processor_Error(ss.str());
}
}
#ifdef VERBOSE_COMM
cerr << "send " << socket_stream.get_length() << " to client " << socket_id
<< endl;
#endif
try {
socket_stream.Send(external_clients.get_socket(socket_id));
@@ -308,44 +324,47 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
// Receive vector of 32-bit clear ints
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>& registers)
void Processor<sint, sgf2n>::read_socket_ints(int client_id,
const vector<int>& registers, int size)
{
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.get_socket(client_id));
for (int i = 0; i < m; i++)
{
int val;
socket_stream.get(val);
write_Ci(registers[i], (long)val);
}
for (int j = 0; j < size; j++)
for (int i = 0; i < m; i++)
{
int val;
socket_stream.get(val);
write_Ci(registers[i] + j, (long) val);
}
}
// Receive vector of public field elements
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>& registers)
void Processor<sint, sgf2n>::read_socket_vector(int client_id,
const vector<int>& registers, int size)
{
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.get_socket(client_id));
for (int i = 0; i < m; i++)
{
get_Cp_ref(registers[i]) = socket_stream.get<typename sint::open_type>();
}
for (int j = 0; j < size; j++)
for (int i = 0; i < m; i++)
get_Cp_ref(registers[i] + j) =
socket_stream.get<typename sint::open_type>();
}
// Receive vector of field element shares over private channel
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_socket_private(int client_id, const vector<int>& registers, bool read_macs)
void Processor<sint, sgf2n>::read_socket_private(int client_id,
const vector<int>& registers, int size, bool read_macs)
{
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.get_socket(client_id));
for (int i = 0; i < m; i++)
{
get_Sp_ref(registers[i]).unpack(socket_stream, read_macs);
}
for (int j = 0; j < size; j++)
for (int i = 0; i < m; i++)
get_Sp_ref(registers[i] + j).unpack(socket_stream, read_macs);
}
@@ -382,11 +401,7 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
// Append share data in data_registers to end of file. Expects Persistence directory to exist.
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_shares_to_file(const vector<int>& data_registers) {
string dir = "Persistence";
mkdir_p(dir.c_str());
string filename;
filename = dir + "/Transactions-P" + to_string(P.my_num()) + ".data";
string filename = binary_file_io.filename(P.my_num());
unsigned int size = data_registers.size();

View File

@@ -23,6 +23,7 @@ void Program::compute_constants()
max_mem[reg_type] = max(max_mem[reg_type],
p[i].get_mem(RegType(reg_type)));
}
writes_persistance |= p[i].opcode == WRITEFILESHARE;
}
}
@@ -41,6 +42,8 @@ void Program::parse(istream& s)
s.peek();
while (!s.eof())
{ instr.parse(s, p.size());
if (s.fail())
throw runtime_error("error while parsing " + to_string(instr.opcode));
p.push_back(instr);
//cerr << "\t" << instr << endl;
s.peek();

View File

@@ -30,8 +30,10 @@ class Program
public:
bool writes_persistance;
Program(int nplayers) : offline_data_used(nplayers),
unknown_usage(false)
unknown_usage(false), writes_persistance(false)
{ compute_constants(); }
// Read in a program

View File

@@ -74,6 +74,7 @@ HonestMajorityRingMachineWithSecurity<U, V>::HonestMajorityRingMachineWithSecuri
Y(K, 40) \
default: \
cerr << "not compiled for security parameter " << to_string(opts.S) << endl; \
cerr << "add 'Y(K, " << opts.S << ")' to " __FILE__ ", line 76" << endl; \
exit(1); \
} \
break;

View File

@@ -23,6 +23,8 @@ enum ThreadJobType
TRIPLE_SACRIFICE_JOB,
CHECK_JOB,
FFT_JOB,
CIPHER_PLAIN_MULT_JOB,
MATRX_RAND_MULT_JOB,
NO_JOB
};

View File

@@ -6,17 +6,21 @@
#include "ThreadQueues.h"
#include <assert.h>
#include <math.h>
int ThreadQueues::distribute(ThreadJob job, int n_items, int base,
int granularity)
{
find_available();
return distribute_no_setup(job, n_items, base, granularity);
if (find_available() > 0)
return distribute_no_setup(job, n_items, base, granularity);
else
return base;
}
int ThreadQueues::find_available()
{
available.clear();
if (not available.empty())
return 0;
for (size_t i = 1; i < size(); i++)
if (at(i)->available())
available.push_back(i);
@@ -28,7 +32,7 @@ int ThreadQueues::find_available()
int ThreadQueues::get_n_per_thread(int n_items, int granularity)
{
int n_per_thread = n_items / (available.size() + 1) / granularity
int n_per_thread = ceil(n_items / (available.size() + 1.0)) / granularity
* granularity;
return n_per_thread;
}
@@ -39,6 +43,11 @@ int ThreadQueues::distribute_no_setup(ThreadJob job, int n_items, int base,
int n_per_thread = get_n_per_thread(n_items, granularity);
for (size_t i = 0; i < available.size(); i++)
{
if (base + (i + 1) * n_per_thread > size_t(n_items))
{
available.resize(i);
return base + i * n_per_thread;
}
if (supplies)
job.supply = supplies->at(i);
job.begin = base + i * n_per_thread;
@@ -52,4 +61,5 @@ void ThreadQueues::wrap_up(ThreadJob job)
{
for (int i : available)
assert(at(i)->result().output == job.output);
available.clear();
}

View File

@@ -190,7 +190,7 @@
*dest++ = *op1++ * *op2++) \
X(DIVINT, auto dest = &Proc.get_Ci()[r[0]]; auto op1 = &Proc.get_Ci()[r[1]]; \
auto op2 = &Proc.get_Ci()[r[2]], \
*dest++ = *op1++ / *op2++) \
if (*op2 == 0) throw division_by_zero(); *dest++ = *op1++ / *op2++) \
X(INCINT, auto dest = &Proc.get_Ci()[r[0]]; auto base = Proc.get_Ci()[r[1]], \
int inc = (i / start[0]) % start[1]; *dest++ = base + inc * int(n)) \
X(EQZC, auto dest = &Ci[r[0]]; auto source = &Ci[r[1]], *dest++ = *source++ == 0) \

View File

@@ -67,13 +67,7 @@ def write_winner_to_clients(sockets, number_clients, winning_client_id):
# Setup authenticate result using share of random.
# client can validate ∑ winning_client_id * ∑ rnd_from_triple = ∑ auth_result
rnd_from_triple = sint.get_random_triple()[0]
auth_result = winning_client_id * rnd_from_triple
@for_range(number_clients)
def loop_body(i):
sint.write_shares_to_socket(sockets[i], [winning_client_id, rnd_from_triple, auth_result])
sint.reveal_to_clients(sockets.get_sub(number_clients), [winning_client_id])
def main():
"""Listen in while loop for players to join a game.

View File

@@ -1,5 +1,7 @@
prog = program
prog.options.binary = -1
from Compiler.GC.types import *
from Compiler.GC.instructions import *

View File

@@ -2,7 +2,6 @@ from Compiler import ml
debug = False
program.use_edabit(True)
program.options_from_args()
sfix.set_precision(16, 31)
@@ -12,30 +11,39 @@ dim = int(program.args[1])
batch = int(program.args[2])
try:
ml.set_n_threads(int(program.args[3]))
n_iterations = int(program.args[3])
except:
n_iterations = 1
try:
ml.set_n_threads(int(program.args[4]))
except:
ml.set_n_threads(None)
X_normal = sfix.Matrix(6400, dim)
X_pos = sfix.Matrix(6400, dim)
N = batch * n_iterations
dense = ml.Dense(12800, dim, 1)
layers = [dense, ml.Output(12800, debug=debug, approx='approx' in program.args)]
print('run 1 epoch of logistic regression with %d examples' % (N))
sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False)
dense = ml.Dense(N, dim, 1)
sigmoid = ml.Output(N, debug=debug, approx='approx' in program.args)
for x in dense.X, sigmoid.Y:
x.assign_all(0)
sgd = ml.SGD([dense, sigmoid], 1, debug=debug, report_loss=False)
sgd.reset()
if not ('forward' in program.args or 'backward' in program.args):
sgd.reset([X_normal, X_pos])
sgd.run(batch_size=batch)
if 'forward' in program.args:
@for_range(1000)
@for_range(n_iterations)
def _(i):
sgd.forward(N=batch)
if 'backward' in program.args:
b = regint.Array(batch)
b.assign(regint.inc(batch))
@for_range(1000)
@for_range(n_iterations)
def _(i):
sgd.backward(batch=b)

Some files were not shown because too many files have changed in this diff Show More