mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Maintenance.
This commit is contained in:
@@ -112,8 +112,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
garble_processor.reset(program);
|
||||
this->processor.open_input_file(N.my_num(), 0);
|
||||
|
||||
T::bit_type::mac_key_type::init_field();
|
||||
GC::ShareThread<typename T::bit_type> share_thread(N, online_opts, *P, 0, usage);
|
||||
shared_proc = new SubProcessor<T>(dummy_proc, *MC, *prep, *P);
|
||||
|
||||
auto& inputter = shared_proc->input;
|
||||
|
||||
@@ -243,6 +243,9 @@ public:
|
||||
template <class T>
|
||||
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
|
||||
{ return T::input(from, processor.get_input(n_bits), n_bits); }
|
||||
template<class U>
|
||||
static void reveal_inst(GC::Processor<U>& processor, const vector<int>& args)
|
||||
{ processor.reveal(args); }
|
||||
|
||||
template<class T>
|
||||
static void convcbit(Integer& dest, const GC::Clear& source, T&)
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.2.8 (Nov 4, 2021)
|
||||
|
||||
- Tested on Apple laptop with ARM chip
|
||||
- Restore trusted client interface
|
||||
- Directly accessible softmax function
|
||||
- Signature in preprocessing files to reduce confusing errors
|
||||
- Improved error messages for connection issues
|
||||
- Documentation of low-level share types and protocol pairs
|
||||
|
||||
## 0.2.7 (Sep 17, 2021)
|
||||
|
||||
- Optimized matrix multiplication in Hemi
|
||||
|
||||
4
CONFIG
4
CONFIG
@@ -88,7 +88,9 @@ CPPFLAGS = $(CFLAGS)
|
||||
LD = $(CXX)
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
# for boost with OpenSSL 3
|
||||
CFLAGS += -Wno-error=deprecated-declarations
|
||||
ifeq ($(USE_NTL),1)
|
||||
CFLAGS += -Wno-error=unused-parameter
|
||||
CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy
|
||||
endif
|
||||
endif
|
||||
|
||||
@@ -309,6 +309,8 @@ class cbits(bits):
|
||||
res = type(self)()
|
||||
inst.notcb(self.n, res, self)
|
||||
return res
|
||||
def __eq__(self, other):
|
||||
raise CompilerError('equality not implemented')
|
||||
def print_reg(self, desc=''):
|
||||
inst.print_regb(self, desc)
|
||||
def print_reg_plain(self):
|
||||
|
||||
@@ -19,6 +19,8 @@ class BlockAllocator:
|
||||
self.by_address = {}
|
||||
|
||||
def by_size(self, size):
|
||||
if size >= 2 ** 32:
|
||||
raise CompilerError('size exceeds addressing capability')
|
||||
return self.by_logsize[int(math.log(size, 2))][size]
|
||||
|
||||
def push(self, address, size):
|
||||
|
||||
@@ -290,6 +290,7 @@ def BitDecRing(a, k, m):
|
||||
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
|
||||
|
||||
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
instructions_base.set_global_vector_size(a.size)
|
||||
r_dprime = types.sint()
|
||||
r_prime = types.sint()
|
||||
c = types.cint()
|
||||
@@ -298,6 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
pow2 = two_power(k + kappa)
|
||||
asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
|
||||
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
|
||||
instructions_base.reset_global_vector_size()
|
||||
return res
|
||||
|
||||
def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
|
||||
@@ -1496,6 +1496,14 @@ class cond_print_plain(base.IOInstruction):
|
||||
code = base.opcodes['CONDPRINTPLAIN']
|
||||
arg_format = ['c', 'c', 'c']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
base.Instruction.__init__(self, *args, **kwargs)
|
||||
self.size = args[1].size
|
||||
args[2].set_size(self.size)
|
||||
|
||||
def get_code(self):
|
||||
return base.Instruction.get_code(self, self.size)
|
||||
|
||||
class print_int(base.IOInstruction):
|
||||
""" Output clear integer register.
|
||||
|
||||
@@ -1591,7 +1599,16 @@ class readsocketc(base.IOInstruction):
|
||||
return True
|
||||
|
||||
class readsockets(base.IOInstruction):
|
||||
"""Read a variable number of secret shares + MACs from socket for a client id and store in registers"""
|
||||
""" Read a variable number of secret shares (potentially with MAC)
|
||||
from a socket for a client id and store them in registers. If the
|
||||
protocol uses MACs, the client should be different for every party.
|
||||
|
||||
:param: client id (regint)
|
||||
:param: vector size (int)
|
||||
:param: source (sint)
|
||||
:param: (repeat source)...
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['READSOCKETS']
|
||||
arg_format = tools.chain(['ci','int'], itertools.repeat('sw'))
|
||||
@@ -1628,6 +1645,25 @@ class writesocketc(base.IOInstruction):
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class writesockets(base.IOInstruction):
|
||||
""" Write a variable number of secret shares (potentially with MAC)
|
||||
from registers into a socket for a specified client id. If the
|
||||
protocol uses MACs, the client should be different for every party.
|
||||
|
||||
:param: client id (regint)
|
||||
:param: message type (must be 0)
|
||||
:param: vector size (int)
|
||||
:param: source (sint)
|
||||
:param: (repeat source)...
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['WRITESOCKETS']
|
||||
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class writesocketshare(base.IOInstruction):
|
||||
""" Write a variable number of shares (without MACs) from secret
|
||||
registers into socket for a specified client id.
|
||||
@@ -2384,5 +2420,76 @@ class lts(base.CISC):
|
||||
subs(a, self.args[1], self.args[2])
|
||||
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
|
||||
|
||||
# placeholder for documentation
|
||||
class cisc:
|
||||
""" Meta instruction for emulation. This instruction is only generated
|
||||
when using ``-K`` with ``compile.py``. The header looks as follows:
|
||||
|
||||
:param: number of arguments after name plus one
|
||||
:param: name (16 bytes, zero-padded)
|
||||
|
||||
Currently, the following names are supported:
|
||||
|
||||
LTZ
|
||||
Less than zero.
|
||||
|
||||
:param: number of arguments in this unit (must be 6)
|
||||
:param: vector size
|
||||
:param: result (sint)
|
||||
:param: input (sint)
|
||||
:param: bit length
|
||||
:param: (ignored)
|
||||
:param: (repeat)...
|
||||
|
||||
Trunc
|
||||
Truncation.
|
||||
|
||||
:param: number of arguments in this unit (must be 8)
|
||||
:param: vector size
|
||||
:param: result (sint)
|
||||
:param: input (sint)
|
||||
:param: bit length
|
||||
:param: number of bits to truncate
|
||||
:param: (ignored)
|
||||
:param: 0 for unsigned or 1 for signed
|
||||
:param: (repeat)...
|
||||
|
||||
FPDiv
|
||||
Fixed-point division. Division by zero results in zero without error.
|
||||
|
||||
:param: number of arguments in this unit (must be at least 7)
|
||||
:param: vector size
|
||||
:param: result (sint)
|
||||
:param: dividend (sint)
|
||||
:param: divisor (sint)
|
||||
:param: (ignored)
|
||||
:param: fixed-point precision
|
||||
:param: (repeat)...
|
||||
|
||||
exp2_fx
|
||||
Fixed-point power of two.
|
||||
|
||||
:param: number of arguments in this unit (must be at least 6)
|
||||
:param: vector size
|
||||
:param: result (sint)
|
||||
:param: exponent (sint)
|
||||
:param: (ignored)
|
||||
:param: fixed-point precision
|
||||
:param: (repeat)...
|
||||
|
||||
log2_fx
|
||||
Fixed-point logarithm with base 2.
|
||||
|
||||
:param: number of arguments in this unit (must be at least 6)
|
||||
:param: vector size
|
||||
:param: result (sint)
|
||||
:param: input (sint)
|
||||
:param: (ignored)
|
||||
:param: fixed-point precision
|
||||
:param: (repeat)...
|
||||
|
||||
"""
|
||||
code = base.opcodes['CISC']
|
||||
|
||||
# hack for circular dependency
|
||||
from Compiler import comparison
|
||||
|
||||
@@ -481,7 +481,16 @@ def cisc(function):
|
||||
|
||||
def expand_merged(self, skip):
|
||||
if function.__name__ in skip:
|
||||
return [self], 0
|
||||
good = True
|
||||
for call in self.calls:
|
||||
if not good:
|
||||
break
|
||||
for arg in call[0]:
|
||||
if isinstance(arg, program.curr_tape.Register) and \
|
||||
not issubclass(type(self.calls[0][0][0]), type(arg)):
|
||||
good = False
|
||||
if good:
|
||||
return [self], 0
|
||||
tape = program.curr_tape
|
||||
block = tape.BasicBlock(tape, None, None)
|
||||
tape.active_basicblock = block
|
||||
@@ -520,6 +529,7 @@ def cisc(function):
|
||||
String.check(name)
|
||||
res += String.encode(name)
|
||||
for call in self.calls:
|
||||
call[1].pop('nearest', None)
|
||||
assert not call[1]
|
||||
res += int_to_bytes(len(call[0]) + 2)
|
||||
res += int_to_bytes(call[0][0].size)
|
||||
|
||||
136
Compiler/ml.py
136
Compiler/ml.py
@@ -155,6 +155,26 @@ def argmax(x):
|
||||
return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1])
|
||||
return tree_reduce(op, enumerate(x))[0]
|
||||
|
||||
def softmax(x):
|
||||
""" Softmax.
|
||||
|
||||
:param x: vector or list of sfix
|
||||
:returns: sfix vector
|
||||
"""
|
||||
return softmax_from_exp(exp_for_softmax(x)[0])
|
||||
|
||||
def exp_for_softmax(x):
|
||||
m = util.max(x)
|
||||
mv = m.expand_to_vector(len(x))
|
||||
try:
|
||||
x = x.get_vector()
|
||||
except AttributeError:
|
||||
x = sfix(x)
|
||||
return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m
|
||||
|
||||
def softmax_from_exp(x):
|
||||
return x / sum(x)
|
||||
|
||||
report_progress = False
|
||||
|
||||
def progress(x):
|
||||
@@ -464,10 +484,7 @@ class MultiOutput(MultiOutputBase):
|
||||
self.losses[i] = -sfix.dot_product(
|
||||
self.Y[batch[i]].get_vector(), log_e(div))
|
||||
else:
|
||||
m = util.max(self.X[i])
|
||||
mv = m.expand_to_vector(d_out)
|
||||
x = self.X[i].get_vector()
|
||||
e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0)
|
||||
e, m = exp_for_softmax(self.X[i])
|
||||
self.exp[i].assign_vector(e)
|
||||
if self.compute_loss:
|
||||
true_X = sfix.dot_product(self.Y[batch[i]], self.X[i])
|
||||
@@ -532,11 +549,8 @@ class MultiOutput(MultiOutputBase):
|
||||
return
|
||||
@for_range_opt_multithread(self.n_threads, len(batch))
|
||||
def _(i):
|
||||
for j in range(d_out):
|
||||
dividend = self.exp[i][j]
|
||||
divisor = sum(self.exp[i])
|
||||
div = (divisor > 0.1).if_else(dividend / divisor, 0)
|
||||
self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div)
|
||||
div = softmax_from_exp(self.exp[i])
|
||||
self.nabla_X[i][:] = -self.Y[batch[i]][:] + div
|
||||
self.maybe_debug_backward(batch)
|
||||
|
||||
def maybe_debug_backward(self, batch):
|
||||
@@ -588,7 +602,7 @@ class DenseBase(Layer):
|
||||
nablas = lambda self: (self.nabla_W, self.nabla_b)
|
||||
|
||||
def output_weights(self):
|
||||
print_ln('%s', self.W.reveal_nested())
|
||||
self.W.print_reveal_nested()
|
||||
print_ln('%s', self.b.reveal_nested())
|
||||
|
||||
def backward_params(self, f_schur_Y, batch):
|
||||
@@ -1316,7 +1330,7 @@ class ConvBase(BaseLayer):
|
||||
self.bias.input_from(player, raw=raw)
|
||||
|
||||
def output_weights(self):
|
||||
print_ln('%s', self.weights.reveal_nested())
|
||||
self.weights.print_reveal_nested()
|
||||
print_ln('%s', self.bias.reveal_nested())
|
||||
|
||||
def dot_product(self, iv, wv, out_y, out_x, out_c):
|
||||
@@ -1942,7 +1956,12 @@ class Optimizer:
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = regint.Array(n * n_per_epoch)
|
||||
indices_by_label.append(indices)
|
||||
indices.assign(regint.inc(len(indices), 0, 1, 1, len(X)))
|
||||
indices.assign(regint.inc(len(X)))
|
||||
missing = len(indices) - len(X)
|
||||
if missing:
|
||||
indices.assign_vector(
|
||||
regint.get_random(int(math.log2(len(X))), size=missing),
|
||||
base=len(X))
|
||||
if self.always_shuffle or n_per_epoch > 1:
|
||||
indices.shuffle()
|
||||
loss_sum = MemValue(sfix(0))
|
||||
@@ -2050,6 +2069,8 @@ class Optimizer:
|
||||
self.time_layers = 'time_layers' in program.args
|
||||
self.revealing_correctness = not 'no_acc' in program.args
|
||||
self.layers[-1].compute_loss = not 'no_loss' in program.args
|
||||
if 'full_cisc' in program.args:
|
||||
program.options.keep_cisc = 'FPDiv,exp2_fx,log2_fx'
|
||||
model_input = 'model_input' in program.args
|
||||
acc_first = model_input and not 'train_first' in program.args
|
||||
if model_input:
|
||||
@@ -2058,12 +2079,14 @@ class Optimizer:
|
||||
else:
|
||||
self.reset()
|
||||
if 'one_iter' in program.args:
|
||||
print_float_prec(16)
|
||||
self.output_weights()
|
||||
print_ln('loss')
|
||||
print_ln('%s', self.eval(
|
||||
self.layers[0].X.get_part(0, batch_size)).reveal_nested())
|
||||
self.eval(
|
||||
self.layers[0].X.get_part(0, batch_size),
|
||||
batch_size=batch_size).print_reveal_nested()
|
||||
for layer in self.layers:
|
||||
print_ln('%s', layer.X.get_part(0, batch_size).reveal_nested())
|
||||
layer.X.get_part(0, batch_size).print_reveal_nested()
|
||||
print_ln('%s', self.layers[-1].Y.get_part(0, batch_size).reveal_nested())
|
||||
batch = Array.create_from(regint.inc(batch_size))
|
||||
self.forward(batch=batch, training=True)
|
||||
@@ -2083,9 +2106,10 @@ class Optimizer:
|
||||
return
|
||||
N = self.layers[0].X.sizes[0]
|
||||
n_trained = (N + batch_size - 1) // batch_size * batch_size
|
||||
print_ln('train_acc: %s (%s/%s)',
|
||||
cfix(self.n_correct, k=63, f=31) / n_trained,
|
||||
self.n_correct, n_trained)
|
||||
if not acc_first:
|
||||
print_ln('train_acc: %s (%s/%s)',
|
||||
cfix(self.n_correct, k=63, f=31) / n_trained,
|
||||
self.n_correct, n_trained)
|
||||
if test_X and test_Y:
|
||||
n_test = len(test_Y)
|
||||
n_correct, loss = self.reveal_correctness(test_X, test_Y,
|
||||
@@ -2500,16 +2524,30 @@ class keras:
|
||||
batch_size = min(batch_size, self.batch_size)
|
||||
return self.opt.eval(x, batch_size=batch_size)
|
||||
|
||||
def solve_linear(A, b, n_iterations, progress=False):
|
||||
""" Iterative linear solution approximation. """
|
||||
def solve_linear(A, b, n_iterations, progress=False, n_threads=None,
|
||||
stop=False, already_symmetric=False, precond=False):
|
||||
""" Iterative linear solution approximation for :math:`Ax=b`.
|
||||
|
||||
:param progress: print some information on the progress (implies revealing)
|
||||
:param n_threads: number of threads to use
|
||||
:param stop: whether to stop when converged (implies revealing)
|
||||
|
||||
"""
|
||||
assert len(b) == A.sizes[0]
|
||||
x = sfix.Array(A.sizes[1])
|
||||
x.assign_vector(sfix.get_random(-1, 1, size=len(x)))
|
||||
AtA = sfix.Matrix(len(x), len(x))
|
||||
AtA[:] = A.direct_trans_mul(A)
|
||||
if already_symmetric:
|
||||
AtA = A
|
||||
r = Array.create_from(b - AtA * x)
|
||||
else:
|
||||
AtA = sfix.Matrix(len(x), len(x))
|
||||
A.trans_mul_to(A, AtA, n_threads=n_threads)
|
||||
r = Array.create_from(A.transpose() * b - AtA * x)
|
||||
if precond:
|
||||
return solve_linear_diag_precond(AtA, b, x, r, n_iterations,
|
||||
progress, stop)
|
||||
v = sfix.Array(A.sizes[1])
|
||||
v.assign_all(0)
|
||||
r = Array.create_from(A.transpose() * b - AtA * x)
|
||||
Av = sfix.Array(len(x))
|
||||
@for_range(n_iterations)
|
||||
def _(i):
|
||||
@@ -2523,10 +2561,43 @@ def solve_linear(A, b, n_iterations, progress=False):
|
||||
if progress:
|
||||
print_ln('%s alpha=%s vr=%s v_norm=%s', i, alpha.reveal(),
|
||||
vr.reveal(), v_norm.reveal())
|
||||
if stop:
|
||||
return (alpha > 0).reveal()
|
||||
return x
|
||||
|
||||
def mr(A, n_iterations):
|
||||
""" Iterative matrix inverse approximation. """
|
||||
def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False,
|
||||
stop=False):
|
||||
m = 1 / A.diag()
|
||||
mr = Array.create_from(m * r[:])
|
||||
d = Array.create_from(mr)
|
||||
@for_range(n_iterations)
|
||||
def _(i):
|
||||
Ad = A * d
|
||||
d_norm = sfix.dot_product(d, Ad)
|
||||
alpha = (d_norm == 0).if_else(0, sfix.dot_product(r, mr) / d_norm)
|
||||
x[:] = x[:] + alpha * d[:]
|
||||
r_norm = sfix.dot_product(r, mr)
|
||||
r[:] = r[:] - alpha * Ad
|
||||
tmp = m * r[:]
|
||||
beta = (r_norm == 0).if_else(0, sfix.dot_product(r, tmp) / r_norm)
|
||||
mr[:] = tmp
|
||||
d[:] = tmp + beta * d
|
||||
if progress:
|
||||
print_ln('%s alpha=%s beta=%s r_norm=%s d_norm=%s', i,
|
||||
alpha.reveal(), beta.reveal(), r_norm.reveal(),
|
||||
d_norm.reveal())
|
||||
if stop:
|
||||
return (alpha > 0).reveal()
|
||||
return x
|
||||
|
||||
def mr(A, n_iterations, stop=False):
|
||||
""" Iterative matrix inverse approximation.
|
||||
|
||||
:param A: matrix to invert
|
||||
:param n_iterations: maximum number of iterations
|
||||
:param stop: whether to stop when converged (implies revealing)
|
||||
|
||||
"""
|
||||
assert len(A.sizes) == 2
|
||||
assert A.sizes[0] == A.sizes[1]
|
||||
M = A.same_shape()
|
||||
@@ -2536,5 +2607,18 @@ def mr(A, n_iterations):
|
||||
e = sfix.Array(n)
|
||||
e.assign_all(0)
|
||||
e[i] = 1
|
||||
M[i] = solve_linear(A, e, n_iterations)
|
||||
M[i] = solve_linear(A, e, n_iterations, stop=stop)
|
||||
return M.transpose()
|
||||
|
||||
def var(x):
|
||||
""" Variance. """
|
||||
mean = MemValue(type(x[0])(0))
|
||||
@for_range_opt(len(x))
|
||||
def _(i):
|
||||
mean.iadd(x[i])
|
||||
mean /= len(x)
|
||||
res = MemValue(type(x[0])(0))
|
||||
@for_range_opt(len(x))
|
||||
def _(i):
|
||||
res.iadd((x[i] - mean.read()) ** 2)
|
||||
return res.read()
|
||||
|
||||
@@ -690,8 +690,9 @@ class Tape:
|
||||
|
||||
def expand_cisc(self):
|
||||
new_instructions = []
|
||||
if self.parent.program.options.keep_cisc:
|
||||
if self.parent.program.options.keep_cisc != None:
|
||||
skip = ['LTZ', 'Trunc']
|
||||
skip += self.parent.program.options.keep_cisc.split(',')
|
||||
else:
|
||||
skip = []
|
||||
for inst in self.instructions:
|
||||
|
||||
@@ -1161,7 +1161,7 @@ class cint(_clear, _int):
|
||||
cond_print_str(self, string)
|
||||
|
||||
def output_if(self, cond):
|
||||
cond_print_plain(self.conv(cond), self, cint(0))
|
||||
cond_print_plain(self.conv(cond), self, cint(0, size=self.size))
|
||||
|
||||
|
||||
class cgf2n(_clear, _gf2n):
|
||||
@@ -1922,14 +1922,7 @@ class _secret(_register, _secret_structure):
|
||||
r = self.get_dabit()
|
||||
movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit()))
|
||||
elif isinstance(val, sbitvec):
|
||||
assert(sum(x.n for x in val.v) == self.size)
|
||||
for val_part, base in zip(val, range(0, self.size, 64)):
|
||||
left = min(64, self.size - base)
|
||||
r = self.get_dabit(size=left)
|
||||
v = regint(size=left)
|
||||
bitdecint_class(regint((r[1] ^ val_part).reveal()), *v)
|
||||
part = r[0].bit_xor(v)
|
||||
vmovs(left, self.get_vector(base, left), part)
|
||||
movs(self, sint.bit_compose(val))
|
||||
else:
|
||||
self.load_clear(self.clear_type(val))
|
||||
|
||||
@@ -1939,6 +1932,8 @@ class _secret(_register, _secret_structure):
|
||||
|
||||
:param bits: iterable of any type convertible to sint """
|
||||
from Compiler.GC.types import sbits, sbitintvec
|
||||
if isinstance(bits, sbits):
|
||||
bits = bits.bit_decompose()
|
||||
bits = list(bits)
|
||||
if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits):
|
||||
if program.use_edabit():
|
||||
@@ -2112,6 +2107,11 @@ class sint(_secret, _int):
|
||||
thereof or sbits/sbitvec/sfix)
|
||||
:param size: vector size (int), defaults to 1 or size of list
|
||||
|
||||
When converting :py:class:`~Compiler.GC.types.sbits`, the result is a
|
||||
vector of bits, and when converting
|
||||
:py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values
|
||||
with bit length equal the length of the input.
|
||||
|
||||
"""
|
||||
__slots__ = []
|
||||
instruction_type = 'modp'
|
||||
@@ -2278,6 +2278,17 @@ class sint(_secret, _int):
|
||||
else:
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def write_to_socket(cls, client_id, values,
|
||||
message_type=ClientMessageType.NoType):
|
||||
""" Send a list of shares and MAC shares to a client socket.
|
||||
|
||||
:param client_id: regint
|
||||
:param values: list of sint
|
||||
|
||||
"""
|
||||
writesockets(client_id, message_type, values[0].size, *values)
|
||||
|
||||
@vectorize
|
||||
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
|
||||
""" Send only share to socket """
|
||||
@@ -2339,6 +2350,7 @@ class sint(_secret, _int):
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, val=None, size=None):
|
||||
from .GC.types import sbitvec
|
||||
if isinstance(val, personal):
|
||||
size = val._v.size
|
||||
super(sint, self).__init__('s', size=size)
|
||||
@@ -2346,6 +2358,8 @@ class sint(_secret, _int):
|
||||
elif isinstance(val, _fix):
|
||||
super(sint, self).__init__('s', size=val.v.size)
|
||||
self.load_other(val.v.round(val.k, val.f))
|
||||
elif isinstance(val, sbitvec):
|
||||
super(sint, self).__init__('s', val=val, size=val[0].n)
|
||||
else:
|
||||
super(sint, self).__init__('s', val=val, size=size)
|
||||
|
||||
@@ -3747,13 +3761,14 @@ class cfix(_number, _structure):
|
||||
other = parse_type(other, self.k, self.f)
|
||||
return other / self
|
||||
|
||||
@vectorize
|
||||
def print_plain(self):
|
||||
""" Clear fixed-point output. """
|
||||
print_float_plain(cint.conv(self.v), cint(-self.f), \
|
||||
cint(0), cint(0), cint(0))
|
||||
|
||||
def output_if(self, cond):
|
||||
cond_print_plain(cint.conv(cond), self.v, cint(-self.f))
|
||||
cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size))
|
||||
|
||||
@vectorize
|
||||
def binary_output(self, player=None):
|
||||
@@ -4028,7 +4043,8 @@ class _fix(_single):
|
||||
print('Nearest rounding instead of proabilistic '
|
||||
'for fixed-point computation')
|
||||
cls.round_nearest = True
|
||||
if adapt_ring and program.options.ring:
|
||||
if adapt_ring and program.options.ring \
|
||||
and 'fix_ring' not in program.args:
|
||||
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
|
||||
if need != int(program.options.ring):
|
||||
print('Changing computation modulus to 2^%d' % need)
|
||||
@@ -4094,6 +4110,8 @@ class _fix(_single):
|
||||
elif isinstance(_v, (MemValue, MemFix)):
|
||||
#this is a memvalue object
|
||||
self.v = type(self)(_v.read()).v
|
||||
elif isinstance(_v, (list, tuple)):
|
||||
self.v = self.int_type(list(self.conv(x).v for x in _v))
|
||||
else:
|
||||
raise CompilerError('cannot convert %s to sfix' % _v)
|
||||
if not isinstance(self.v, self.int_type):
|
||||
@@ -4250,7 +4268,7 @@ class sfix(_fix):
|
||||
return cls._new(cls.int_type.get_raw_input_from(player))
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_random(cls, lower, upper):
|
||||
def get_random(cls, lower, upper, symmetric=True):
|
||||
""" Uniform secret random number around centre of bounds.
|
||||
Actual range can be smaller but never larger.
|
||||
|
||||
@@ -4261,8 +4279,19 @@ class sfix(_fix):
|
||||
log_range = int(math.log(upper - lower, 2))
|
||||
n_bits = log_range + cls.f
|
||||
average = lower + 0.5 * (upper - lower)
|
||||
lower = average - 0.5 * 2 ** log_range
|
||||
return cls._new(cls.int_type.get_random_int(n_bits)) + lower
|
||||
real_range = (2 ** (n_bits) - 1) / 2 ** cls.f
|
||||
lower = average - 0.5 * real_range
|
||||
real_lower = round(lower * 2 ** cls.f) / 2 ** cls.f
|
||||
r = cls._new(cls.int_type.get_random_int(n_bits)) + lower
|
||||
if symmetric:
|
||||
lowest = math.floor(lower * 2 ** cls.f) / 2 ** cls.f
|
||||
print('randomness range [%f,%f], fringes half the probability' % \
|
||||
(lowest, lowest + 2 ** log_range))
|
||||
return cls.int_type.get_random_bit().if_else(r, -r + 2 * average)
|
||||
else:
|
||||
print('randomness range [%f,%f], %d bits' % \
|
||||
(real_lower, real_lower + real_range, n_bits))
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None):
|
||||
@@ -4985,6 +5014,7 @@ class cfloat(object):
|
||||
""" Helper class for printing revealed sfloats. """
|
||||
__slots__ = ['v', 'p', 'z', 's', 'nan']
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, v, p=None, z=None, s=None, nan=0):
|
||||
""" Parameters as with :py:class:`sfloat` but public. """
|
||||
if s is None:
|
||||
@@ -4993,6 +5023,11 @@ class cfloat(object):
|
||||
parts = [cint.conv(x) for x in (v, p, z, s, nan)]
|
||||
self.v, self.p, self.z, self.s, self.nan = parts
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return self.v.size
|
||||
|
||||
@vectorize
|
||||
def print_float_plain(self):
|
||||
""" Output. """
|
||||
print_float_plain(self.v, self.p, self.z, self.s, self.nan)
|
||||
@@ -5389,15 +5424,7 @@ class Array(_vectorizable):
|
||||
|
||||
def shuffle(self):
|
||||
""" Insecure shuffle in place. """
|
||||
if self.value_type == regint:
|
||||
self.assign(self.get_vector().shuffle())
|
||||
else:
|
||||
@library.for_range(len(self))
|
||||
def _(i):
|
||||
j = regint.get_random(64) % (len(self) - i)
|
||||
tmp = self[i]
|
||||
self[i] = self[i + j]
|
||||
self[i + j] = tmp
|
||||
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
|
||||
|
||||
def reveal(self):
|
||||
""" Reveal the whole array.
|
||||
@@ -5411,6 +5438,13 @@ class Array(_vectorizable):
|
||||
|
||||
reveal_nested = reveal_list
|
||||
|
||||
def print_reveal_nested(self, end='\n'):
|
||||
""" Reveal and print as list.
|
||||
|
||||
:param end: string to print after (default: line break)
|
||||
"""
|
||||
library.print_str('%s' + end, self.get_vector().reveal())
|
||||
|
||||
def reveal_to_binary_output(self, player=None):
|
||||
""" Reveal to binary output if supported by type.
|
||||
|
||||
@@ -5449,6 +5483,8 @@ sgf2n.dynamic_array = Array
|
||||
class SubMultiArray(_vectorizable):
|
||||
""" Multidimensional array functionality. Don't construct this
|
||||
directly, use :py:class:`MultiArray` instead. """
|
||||
check_indices = True
|
||||
|
||||
def __init__(self, sizes, value_type, address, index, debug=None):
|
||||
self.sizes = tuple(sizes)
|
||||
self.value_type = _get_type(value_type)
|
||||
@@ -5458,7 +5494,6 @@ class SubMultiArray(_vectorizable):
|
||||
self.address = None
|
||||
self.sub_cache = {}
|
||||
self.debug = debug
|
||||
self.check_indices = True
|
||||
if debug:
|
||||
library.print_ln_if(self.address + reduce(operator.mul, self.sizes) * self.value_type.n_elements() > program.allocated_mem[self.value_type.reg_type], 'AOF%d:' % len(self.sizes) + self.debug)
|
||||
|
||||
@@ -5467,8 +5502,6 @@ class SubMultiArray(_vectorizable):
|
||||
|
||||
:param index: public (regint/cint/int)
|
||||
:return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise"""
|
||||
if util.is_constant(index) and index >= self.sizes[0]:
|
||||
raise StopIteration
|
||||
if isinstance(index, slice) and index == slice(None):
|
||||
return self.get_vector()
|
||||
key = program.curr_block, str(index)
|
||||
@@ -5490,7 +5523,9 @@ class SubMultiArray(_vectorizable):
|
||||
self.sub_cache[key] = \
|
||||
SubMultiArray(self.sizes[1:], self.value_type, \
|
||||
self.address, index, debug=self.debug)
|
||||
return self.sub_cache[key]
|
||||
res = self.sub_cache[key]
|
||||
res.check_indices = self.check_indices
|
||||
return res
|
||||
|
||||
def __setitem__(self, index, other):
|
||||
""" Part assignment.
|
||||
@@ -5505,6 +5540,9 @@ class SubMultiArray(_vectorizable):
|
||||
""" Size of top dimension. """
|
||||
return self.sizes[0]
|
||||
|
||||
def __iter__(self):
|
||||
return (self[i] for i in range(len(self)))
|
||||
|
||||
def assign_all(self, value):
|
||||
""" Assign the same value to all entries.
|
||||
|
||||
@@ -5877,6 +5915,38 @@ class SubMultiArray(_vectorizable):
|
||||
self.address, other.address, None, 1, other.sizes[1],
|
||||
reduce=reduce, indices=indices)
|
||||
|
||||
def trans_mul_to(self, other, res, n_threads=None):
|
||||
"""
|
||||
Matrix multiplication with the transpose of :py:obj:`self`
|
||||
in the virtual machine.
|
||||
|
||||
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param res: matrix of matching dimension to store result
|
||||
:param n_threads: number of threads (default: single thread)
|
||||
"""
|
||||
@library.for_range_multithread(n_threads, 1, self.sizes[1])
|
||||
def _(i):
|
||||
indices = [regint(i), regint.inc(self.sizes[0])]
|
||||
indices += [regint.inc(i) for i in other.sizes]
|
||||
res[i] = self.direct_trans_mul(other, indices=indices)
|
||||
|
||||
def mul_trans_to(self, other, res, n_threads=None):
|
||||
"""
|
||||
Matrix multiplication with the transpose of :py:obj:`other`
|
||||
in the virtual machine.
|
||||
|
||||
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param res: matrix of matching dimension to store result
|
||||
:param n_threads: number of threads (default: single thread)
|
||||
"""
|
||||
@library.for_range_multithread(n_threads, 1, self.sizes[0])
|
||||
def _(i):
|
||||
indices = [regint(i), regint.inc(self.sizes[1])]
|
||||
indices += [regint.inc(i) for i in reversed(other.sizes)]
|
||||
res[i] = self.direct_mul_trans(other, indices=indices)
|
||||
|
||||
def direct_mul_to_matrix(self, other):
|
||||
""" Matrix multiplication in the virtual machine.
|
||||
|
||||
@@ -5992,6 +6062,13 @@ class SubMultiArray(_vectorizable):
|
||||
assert self.sizes[0] == self.sizes[1]
|
||||
return sum(self[i][i] for i in range(self.sizes[0]))
|
||||
|
||||
def diag(self):
|
||||
""" Matrix diagonal. """
|
||||
assert len(self.sizes) == 2
|
||||
assert self.sizes[0] == self.sizes[1]
|
||||
n = self.sizes[0]
|
||||
return self.array.get(regint.inc(n, 0, n + 1))
|
||||
|
||||
def reveal_list(self):
|
||||
""" Reveal as list. """
|
||||
return list(self.get_vector().reveal())
|
||||
@@ -6007,6 +6084,21 @@ class SubMultiArray(_vectorizable):
|
||||
return [f(sizes[1:]) for i in range(sizes[0])]
|
||||
return f(self.sizes)
|
||||
|
||||
def print_reveal_nested(self, end='\n'):
|
||||
""" Reveal and print as nested list.
|
||||
|
||||
:param end: string to print after (default: line break)
|
||||
"""
|
||||
if self.total_size() < program.options.budget:
|
||||
library.print_str('%s' + end, self.reveal_nested())
|
||||
else:
|
||||
library.print_str('[')
|
||||
@library.for_range(len(self) - 1)
|
||||
def _(i):
|
||||
self[i].print_reveal_nested(end=', ')
|
||||
self[len(self) - 1].print_reveal_nested(end='')
|
||||
library.print_str(']' + end)
|
||||
|
||||
def reveal_to_binary_output(self, player=None):
|
||||
""" Reveal to binary output if supported by type.
|
||||
|
||||
@@ -6042,6 +6134,10 @@ class MultiArray(SubMultiArray):
|
||||
a[2][:] = a[0][:] * a[1][:]
|
||||
|
||||
"""
|
||||
@staticmethod
|
||||
def disable_index_checks():
|
||||
SubMultiArray.check_indices = False
|
||||
|
||||
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
|
||||
if isinstance(address, Array):
|
||||
self.array = address
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#define NO_MIXED_CIRCUITS
|
||||
|
||||
#include "Networking/Server.h"
|
||||
#include "Networking/CryptoPlayer.h"
|
||||
#include "Math/gfp.h"
|
||||
@@ -49,8 +51,6 @@ int main(int argc, const char** argv)
|
||||
ArithmeticProcessor _({}, 0);
|
||||
BaseMachine machine;
|
||||
machine.ot_setups.push_back({P, false});
|
||||
GC::ShareThread<typename pShare::bit_type> thread(N,
|
||||
OnlineOptions::singleton, P, {}, usage);
|
||||
SubProcessor<pShare> proc(_, MCp, prep, P);
|
||||
|
||||
pShare sk, __;
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#define NO_MIXED_CIRCUITS
|
||||
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
@@ -106,8 +106,6 @@ void run(int argc, const char** argv)
|
||||
typename pShare::Direct_MC MCp(keyp);
|
||||
ArithmeticProcessor _({}, 0);
|
||||
typename pShare::TriplePrep sk_prep(0, usage);
|
||||
GC::ShareThread<typename pShare::bit_type> thread(N,
|
||||
OnlineOptions::singleton, P, {}, usage);
|
||||
SubProcessor<pShare> sk_proc(_, MCp, sk_prep, P);
|
||||
pShare sk, __;
|
||||
// synchronize
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "FHEOffline/PairwiseMachine.h"
|
||||
#include "Tools/benchmarking.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
#include "Tools/Bundle.h"
|
||||
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
|
||||
|
||||
@@ -167,6 +167,8 @@ string open_prep_file(ofstream& outf, string data_type, int my_num, int thread_n
|
||||
throw runtime_error("cannot create directory " + dir);
|
||||
string file = prep_filename<T>(data_type, my_num, thread_num, initial, dir);
|
||||
outf.open(file.c_str(),ios::out | ios::binary | (clear ? ios::trunc : ios::app));
|
||||
if (clear)
|
||||
file_signature<Share<T>>().output(outf);
|
||||
if (outf.fail()) { throw file_error(file); }
|
||||
return file;
|
||||
}
|
||||
@@ -516,6 +518,7 @@ InputProducer<FD>::InputProducer(const Player& P, int thread_num,
|
||||
if (thread_num)
|
||||
file << "-" << thread_num;
|
||||
outf[j].open(file.str().c_str(), ios::out | ios::binary);
|
||||
file_signature<Share<T>>().output(outf[j]);
|
||||
if (outf[j].fail())
|
||||
{
|
||||
throw file_error(file.str());
|
||||
|
||||
@@ -24,11 +24,4 @@ AtlasShare::AtlasShare(const AtlasSecret& other) :
|
||||
{
|
||||
}
|
||||
|
||||
void AtlasShare::random()
|
||||
{
|
||||
AtlasSecret tmp;
|
||||
this->get_party().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -63,8 +63,6 @@ public:
|
||||
{
|
||||
*this = input;
|
||||
}
|
||||
|
||||
void random();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
10
GC/CcdPrep.h
10
GC/CcdPrep.h
@@ -20,17 +20,17 @@ class CcdPrep : public BufferPrep<T>
|
||||
{
|
||||
typename T::part_type::LivePrep part_prep;
|
||||
SubProcessor<typename T::part_type>* part_proc;
|
||||
ShareThread<T>& thread;
|
||||
|
||||
public:
|
||||
CcdPrep(DataPositions& usage, ShareThread<T>& thread) :
|
||||
BufferPrep<T>(usage), part_prep(usage, thread), part_proc(0),
|
||||
thread(thread)
|
||||
static const bool use_part = true;
|
||||
|
||||
CcdPrep(DataPositions& usage) :
|
||||
BufferPrep<T>(usage), part_prep(usage), part_proc(0)
|
||||
{
|
||||
}
|
||||
|
||||
CcdPrep(SubProcessor<T>*, DataPositions& usage) :
|
||||
CcdPrep(usage, ShareThread<T>::s())
|
||||
CcdPrep(usage)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ CcdPrep<T>::~CcdPrep()
|
||||
template<class T>
|
||||
void CcdPrep<T>::set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
auto& thread = ShareThread<T>::s();
|
||||
assert(thread.MC);
|
||||
part_proc = new SubProcessor<typename T::part_type>(
|
||||
thread.MC->get_part_MC(), part_prep, protocol.get_part().P);
|
||||
|
||||
@@ -74,13 +74,6 @@ public:
|
||||
*this = input;
|
||||
}
|
||||
|
||||
void random()
|
||||
{
|
||||
CcdSecret<T> tmp;
|
||||
ShareThread<CcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
}
|
||||
|
||||
This& operator^=(const This& other)
|
||||
{
|
||||
*this += other;
|
||||
|
||||
@@ -136,6 +136,8 @@ public:
|
||||
void andrs(int n, const FakeSecret& x, const FakeSecret& y)
|
||||
{ *this = BitVec(x.a * (y.a & 1)).mask(n); }
|
||||
|
||||
void xor_bit(int i, FakeSecret bit) { *this ^= bit << i; }
|
||||
|
||||
void invert(int n, const FakeSecret& x) { *this = BitVec(~x.a).mask(n); }
|
||||
|
||||
void random_bit() { a = random() % 2; }
|
||||
|
||||
@@ -79,13 +79,6 @@ public:
|
||||
*this = input;
|
||||
}
|
||||
|
||||
void random()
|
||||
{
|
||||
MaliciousCcdSecret<T> tmp;
|
||||
ShareThread<MaliciousCcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
}
|
||||
|
||||
This& operator^=(const This& other)
|
||||
{
|
||||
*this += other;
|
||||
|
||||
13
GC/NoShare.h
13
GC/NoShare.h
@@ -7,12 +7,12 @@
|
||||
#define GC_NOSHARE_H_
|
||||
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "BMR/Register.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
#include "Protocols/ShareInterface.h"
|
||||
|
||||
class InputArgs;
|
||||
class ArithmeticProcessor;
|
||||
class BlackHole;
|
||||
class SwitchableOutput;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
@@ -110,7 +110,7 @@ public:
|
||||
|
||||
typedef NoShare small_type;
|
||||
|
||||
typedef BlackHole out_type;
|
||||
typedef SwitchableOutput out_type;
|
||||
|
||||
static const bool is_real = false;
|
||||
|
||||
@@ -124,6 +124,11 @@ public:
|
||||
return "no";
|
||||
}
|
||||
|
||||
static void specification(octetStream&)
|
||||
{
|
||||
fail();
|
||||
}
|
||||
|
||||
static int size()
|
||||
{
|
||||
return 0;
|
||||
@@ -172,6 +177,8 @@ public:
|
||||
|
||||
NoShare get_bit(int) const { fail(); return {}; }
|
||||
|
||||
void xor_bit(int, NoShare) const { fail(); }
|
||||
|
||||
void invert(int, NoShare) { fail(); }
|
||||
|
||||
NoShare mask(int) const { fail(); return {}; }
|
||||
|
||||
@@ -27,7 +27,7 @@ public:
|
||||
static int check_args(const vector<int>& args, int n);
|
||||
|
||||
template<class U>
|
||||
static void check_input(const U& in, int n_bits);
|
||||
static void check_input(const U& in, const int* params);
|
||||
|
||||
Machine<T>* machine;
|
||||
Memories<T>& memories;
|
||||
|
||||
@@ -89,15 +89,15 @@ U GC::Processor<T>::get_long_input(const int* params,
|
||||
else
|
||||
res = input_proc.get_input<FixInput_<U>>(interactive,
|
||||
¶ms[1]).items[0];
|
||||
int n_bits = *params;
|
||||
check_input(res, n_bits);
|
||||
check_input(res, params);
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void GC::Processor<T>::check_input(const U& in, int n_bits)
|
||||
void GC::Processor<T>::check_input(const U& in, const int* params)
|
||||
{
|
||||
int n_bits = *params;
|
||||
auto test = in >> (n_bits - 1);
|
||||
if (n_bits == 1)
|
||||
{
|
||||
@@ -106,9 +106,17 @@ void GC::Processor<T>::check_input(const U& in, int n_bits)
|
||||
}
|
||||
else if (not (test == 0 or test == -1))
|
||||
{
|
||||
throw runtime_error(
|
||||
"input too large for a " + std::to_string(n_bits)
|
||||
+ "-bit signed integer: " + to_string(in));
|
||||
if (params[1] == 0)
|
||||
throw runtime_error(
|
||||
"input out of range for a " + std::to_string(n_bits)
|
||||
+ "-bit signed integer: " + to_string(in));
|
||||
else
|
||||
throw runtime_error(
|
||||
"input out of range for a " + to_string(n_bits)
|
||||
+ "-bit fixed-point number with "
|
||||
+ to_string(params[1]) + "-bit precision: "
|
||||
+ to_string(
|
||||
mpf_class(bigint(in)) * exp2(-params[1])));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ class RepPrep : public PersonalPrep<T>, ShiftableTripleBuffer<T>
|
||||
ReplicatedBase* protocol;
|
||||
|
||||
public:
|
||||
RepPrep(DataPositions& usage, ShareThread<T>& thread);
|
||||
RepPrep(DataPositions& usage, int input_player = PersonalPrep<T>::SECURE);
|
||||
~RepPrep();
|
||||
|
||||
|
||||
@@ -16,13 +16,6 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
RepPrep<T>::RepPrep(DataPositions& usage, ShareThread<T>& thread) :
|
||||
RepPrep<T>(usage)
|
||||
{
|
||||
(void) thread;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
RepPrep<T>::RepPrep(DataPositions& usage, int input_player) :
|
||||
PersonalPrep<T>(usage, input_player), protocol(0)
|
||||
|
||||
@@ -31,6 +31,11 @@ public:
|
||||
{
|
||||
tainted = true;
|
||||
}
|
||||
|
||||
bool is_tainted()
|
||||
{
|
||||
return tainted;
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -84,7 +84,6 @@ public:
|
||||
template <class U>
|
||||
static void store_clear_in_dynamic(U& mem, const vector<ClearWriteAccess>& accesses)
|
||||
{ T::store_clear_in_dynamic(mem, accesses); }
|
||||
static void output(T& reg);
|
||||
|
||||
template<class U, class V>
|
||||
static void load(vector< ReadAccess<V> >& accesses, const U& mem);
|
||||
@@ -113,7 +112,7 @@ public:
|
||||
{ T::inputbvec(processor, input_proc, args); }
|
||||
template<class U>
|
||||
static void reveal_inst(Processor<U>& processor, const vector<int>& args)
|
||||
{ processor.reveal(args); }
|
||||
{ T::reveal_inst(processor, args); }
|
||||
|
||||
template<class U>
|
||||
static void trans(Processor<U>& processor, int n_inputs, const vector<int>& args);
|
||||
@@ -148,7 +147,6 @@ public:
|
||||
}
|
||||
void invert(int n, const Secret<T>& x);
|
||||
void and_(int n, const Secret<T>& x, const Secret<T>& y, bool repeat);
|
||||
void andrs(int n, const Secret<T>& x, const Secret<T>& y) { and_(n, x, y, true); }
|
||||
|
||||
template <class U>
|
||||
void reveal(size_t n_bits, U& x);
|
||||
|
||||
@@ -119,12 +119,6 @@ void Secret<T>::store(U& mem,
|
||||
T::store(mem, accesses);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Secret<T>::output(T& reg)
|
||||
{
|
||||
reg.output();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
Secret<T>::Secret()
|
||||
{
|
||||
|
||||
@@ -15,10 +15,6 @@ namespace GC
|
||||
class SemiHonestRepPrep : public RepPrep<SemiHonestRepSecret>
|
||||
{
|
||||
public:
|
||||
SemiHonestRepPrep(DataPositions& usage, ShareThread<SemiHonestRepSecret>&) :
|
||||
RepPrep<SemiHonestRepSecret>(usage)
|
||||
{
|
||||
}
|
||||
SemiHonestRepPrep(DataPositions& usage, bool = false) :
|
||||
RepPrep<SemiHonestRepSecret>(usage)
|
||||
{
|
||||
|
||||
@@ -16,11 +16,6 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
SemiPrep::SemiPrep(DataPositions& usage, ShareThread<SemiSecret>&) :
|
||||
SemiPrep(usage)
|
||||
{
|
||||
}
|
||||
|
||||
SemiPrep::SemiPrep(DataPositions& usage, bool) :
|
||||
BufferPrep<SemiSecret>(usage), triple_generator(0)
|
||||
{
|
||||
|
||||
@@ -26,7 +26,6 @@ class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret
|
||||
SeededPRNG secure_prng;
|
||||
|
||||
public:
|
||||
SemiPrep(DataPositions& usage, ShareThread<SemiSecret>& thread);
|
||||
SemiPrep(DataPositions& usage, bool = true);
|
||||
~SemiPrep();
|
||||
|
||||
|
||||
@@ -73,6 +73,9 @@ public:
|
||||
void xor_(int n, const SemiSecret& x, const SemiSecret& y)
|
||||
{ *this = BitVec(x ^ y).mask(n); }
|
||||
|
||||
void xor_bit(int i, const SemiSecret& bit)
|
||||
{ *this ^= bit << i; }
|
||||
|
||||
void reveal(size_t n_bits, Clear& x);
|
||||
|
||||
SemiSecret lsb()
|
||||
|
||||
@@ -161,6 +161,9 @@ public:
|
||||
|
||||
This get_bit(int i)
|
||||
{ return (*this >> i) & 1; }
|
||||
|
||||
void xor_bit(int i, const This& bit)
|
||||
{ *this ^= bit << i; }
|
||||
};
|
||||
|
||||
template<class U>
|
||||
|
||||
@@ -32,9 +32,9 @@ public:
|
||||
|
||||
Preprocessing<T>& DataF;
|
||||
|
||||
ShareThread(const Names& N, OnlineOptions& opts, DataPositions& usage);
|
||||
ShareThread(const Names& N, OnlineOptions& opts, Player& P,
|
||||
typename T::mac_key_type mac_key, DataPositions& usage);
|
||||
ShareThread(Preprocessing<T>& prep);
|
||||
ShareThread(Preprocessing<T>& prep, Player& P,
|
||||
typename T::mac_key_type mac_key);
|
||||
virtual ~ShareThread();
|
||||
|
||||
virtual typename T::MC* new_mc(typename T::mac_key_type mac_key)
|
||||
@@ -54,6 +54,7 @@ public:
|
||||
DataPositions usage;
|
||||
|
||||
StandaloneShareThread(int i, ThreadMaster<T>& master);
|
||||
~StandaloneShareThread();
|
||||
|
||||
void pre_run();
|
||||
void post_run() { ShareThread<T>::post_run(); }
|
||||
|
||||
@@ -18,26 +18,28 @@ namespace GC
|
||||
|
||||
template<class T>
|
||||
StandaloneShareThread<T>::StandaloneShareThread(int i, ThreadMaster<T>& master) :
|
||||
ShareThread<T>(master.N, master.opts, usage), Thread<T>(i, master)
|
||||
ShareThread<T>(*Preprocessing<T>::get_new(master.opts.live_prep,
|
||||
master.N, usage)),
|
||||
Thread<T>(i, master)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, DataPositions& usage) :
|
||||
P(0), MC(0), protocol(0), DataF(
|
||||
opts.live_prep ?
|
||||
*static_cast<Preprocessing<T>*>(new typename T::LivePrep(
|
||||
usage, *this)) :
|
||||
*static_cast<Preprocessing<T>*>(new BitPrepFiles<T>(N,
|
||||
get_prep_sub_dir<T>(PREP_DIR, N.num_players()),
|
||||
usage, BaseMachine::thread_num)))
|
||||
StandaloneShareThread<T>::~StandaloneShareThread()
|
||||
{
|
||||
delete &this->DataF;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
ShareThread<T>::ShareThread(Preprocessing<T>& prep) :
|
||||
P(0), MC(0), protocol(0), DataF(prep)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, Player& P,
|
||||
typename T::mac_key_type mac_key, DataPositions& usage) :
|
||||
ShareThread(N, opts, usage)
|
||||
ShareThread<T>::ShareThread(Preprocessing<T>& prep, Player& P,
|
||||
typename T::mac_key_type mac_key) :
|
||||
ShareThread(prep)
|
||||
{
|
||||
pre_run(P, mac_key);
|
||||
}
|
||||
@@ -45,7 +47,6 @@ ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, Player& P,
|
||||
template<class T>
|
||||
ShareThread<T>::~ShareThread()
|
||||
{
|
||||
delete &DataF;
|
||||
if (MC)
|
||||
delete MC;
|
||||
if (protocol)
|
||||
@@ -76,12 +77,6 @@ void ShareThread<T>::post_run()
|
||||
{
|
||||
protocol->check();
|
||||
MC->Check(*this->P);
|
||||
#ifndef INSECURE
|
||||
#ifdef VERBOSE
|
||||
cerr << "Removing used pre-processed data" << endl;
|
||||
#endif
|
||||
DataF.prune();
|
||||
#endif
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -58,10 +58,8 @@ void Thread<T>::run()
|
||||
P = new PlainPlayer(N, id);
|
||||
processor.open_input_file(N.my_num(), thread_num,
|
||||
master.opts.cmd_private_input_file);
|
||||
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
|
||||
processor.setup_redirection(P->my_num(), thread_num, master.opts);
|
||||
if (processor.stdout_redirect_file.is_open())
|
||||
processor.out.redirect_to_file(processor.stdout_redirect_file);
|
||||
processor.setup_redirection(P->my_num(), thread_num, master.opts,
|
||||
processor.out);
|
||||
|
||||
done.push(0);
|
||||
pre_run();
|
||||
|
||||
@@ -58,6 +58,16 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
|
||||
template<class T>
|
||||
void ThreadMaster<T>::run()
|
||||
{
|
||||
#ifndef INSECURE
|
||||
if (not opts.live_prep)
|
||||
{
|
||||
cerr
|
||||
<< "Preprocessing from file not supported by binary virtual machines"
|
||||
<< endl;
|
||||
exit(1);
|
||||
}
|
||||
#endif
|
||||
|
||||
P = new PlainPlayer(N, "main");
|
||||
|
||||
machine.load_schedule(progname);
|
||||
|
||||
@@ -106,11 +106,6 @@ public:
|
||||
party.MC->get_alphai());
|
||||
}
|
||||
|
||||
void random()
|
||||
{
|
||||
*this = get_party().DataF.get_part().get_bit();
|
||||
}
|
||||
|
||||
This lsb() const
|
||||
{
|
||||
return *this;
|
||||
|
||||
@@ -25,7 +25,6 @@ class TinierSharePrep : public PersonalPrep<T>
|
||||
MascotParams params;
|
||||
|
||||
typedef typename T::whole_type secret_type;
|
||||
ShareThread<secret_type>& thread;
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
@@ -39,8 +38,6 @@ class TinierSharePrep : public PersonalPrep<T>
|
||||
void init_real(Player& P);
|
||||
|
||||
public:
|
||||
TinierSharePrep(DataPositions& usage, ShareThread<secret_type>& thread,
|
||||
int input_player = PersonalPrep<T>::SECURE);
|
||||
TinierSharePrep(DataPositions& usage, int input_player =
|
||||
PersonalPrep<T>::SECURE);
|
||||
TinierSharePrep(SubProcessor<T>*, DataPositions& usage);
|
||||
|
||||
@@ -15,16 +15,8 @@ namespace GC
|
||||
|
||||
template<class T>
|
||||
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage, int input_player) :
|
||||
TinierSharePrep<T>(usage, ShareThread<secret_type>::s(), input_player)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage,
|
||||
ShareThread<secret_type>& thread, int input_player) :
|
||||
PersonalPrep<T>(usage, input_player), triple_generator(0),
|
||||
real_triple_generator(0),
|
||||
thread(thread)
|
||||
real_triple_generator(0)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -87,6 +79,7 @@ void TinierSharePrep<T>::buffer_inputs(int player)
|
||||
template<class T>
|
||||
void GC::TinierSharePrep<T>::buffer_bits()
|
||||
{
|
||||
auto& thread = ShareThread<secret_type>::s();
|
||||
this->bits.push_back(
|
||||
BufferPrep<T>::get_random_from_inputs(thread.P->num_players()));
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ template<class T>
|
||||
void TinierSharePrep<T>::init_real(Player& P)
|
||||
{
|
||||
assert(real_triple_generator == 0);
|
||||
auto& thread = ShareThread<secret_type>::s();
|
||||
real_triple_generator = new typename T::whole_type::TripleGenerator(
|
||||
BaseMachine::s().fresh_ot_setup(), P.N, -1,
|
||||
OnlineOptions::singleton.batch_size, 1, params,
|
||||
@@ -24,6 +25,7 @@ void TinierSharePrep<T>::init_real(Player& P)
|
||||
template<class T>
|
||||
void TinierSharePrep<T>::buffer_secret_triples()
|
||||
{
|
||||
auto& thread = ShareThread<secret_type>::s();
|
||||
auto& triple_generator = real_triple_generator;
|
||||
assert(triple_generator != 0);
|
||||
params.generateBits = false;
|
||||
|
||||
@@ -58,6 +58,11 @@ public:
|
||||
return part_type::size() * default_length;
|
||||
}
|
||||
|
||||
static void specification(octetStream& os)
|
||||
{
|
||||
T::specification(os);
|
||||
}
|
||||
|
||||
static void read_or_generate_mac_key(string directory, const Player& P,
|
||||
mac_key_type& key)
|
||||
{
|
||||
@@ -150,6 +155,17 @@ public:
|
||||
return this->get_reg(i);
|
||||
}
|
||||
|
||||
void xor_bit(size_t i, const T& bit)
|
||||
{
|
||||
if (i < this->get_regs().size())
|
||||
XOR(this->get_reg(i), this->get_reg(i), bit);
|
||||
else
|
||||
{
|
||||
this->resize_regs(i + 1);
|
||||
this->get_reg(i) = bit;
|
||||
}
|
||||
}
|
||||
|
||||
void output(ostream& s, bool human = true) const
|
||||
{
|
||||
assert(this->get_regs().size() == default_length);
|
||||
@@ -179,6 +195,12 @@ public:
|
||||
{
|
||||
inputter.finalize(from, n_bits).mask(*this, n_bits);
|
||||
}
|
||||
|
||||
void random_bit()
|
||||
{
|
||||
auto& thread = GC::ShareThread<typename T::whole_type>::s();
|
||||
*this = thread.DataF.get_part().get_bit();
|
||||
}
|
||||
};
|
||||
|
||||
template<int S>
|
||||
|
||||
@@ -74,13 +74,6 @@ public:
|
||||
*this = super::constant(input, party.P->my_num(),
|
||||
party.MC->get_alphai());
|
||||
}
|
||||
|
||||
void random()
|
||||
{
|
||||
TinySecret<S> tmp;
|
||||
this->get_party().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
@@ -158,8 +158,8 @@ GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i,
|
||||
{
|
||||
if (output and i == 0)
|
||||
{
|
||||
T::clear::template generate_setup<T>(PREP_DIR, nplayers, 128);
|
||||
prep_data_dir = get_prep_sub_dir<T>(PREP_DIR, nplayers);
|
||||
T::clear::write_setup(prep_data_dir);
|
||||
write_mac_key(prep_data_dir, my_num, nplayers, mac_key);
|
||||
}
|
||||
|
||||
|
||||
@@ -59,6 +59,9 @@ int main(int argc, const char** argv)
|
||||
online_opts.live_prep, online_opts).run(); \
|
||||
break;
|
||||
X(64) X(128) X(256) X(192) X(384) X(512)
|
||||
#ifdef RING_SIZE
|
||||
X(RING_SIZE)
|
||||
#endif
|
||||
#undef X
|
||||
default:
|
||||
cerr << "Not compiled for " << R << "-bit rings" << endl;
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#define NO_MIXED_CIRCUITS
|
||||
|
||||
#include "BMR/RealProgramParty.hpp"
|
||||
#include "Machines/SPDZ.hpp"
|
||||
|
||||
|
||||
29
Makefile
29
Makefile
@@ -97,16 +97,15 @@ replicated: rep-field rep-ring rep-bin
|
||||
spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Offline.x
|
||||
mascot: mascot-party.x spdz2k mama-party.x
|
||||
|
||||
tldr:
|
||||
-echo ARCH = -march=native >> CONFIG.mine
|
||||
$(MAKE) mascot-party.x
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
tldr: mac-setup
|
||||
else
|
||||
tldr: mpir
|
||||
tldr: mpir linux-machine-setup
|
||||
endif
|
||||
|
||||
tldr:
|
||||
$(MAKE) mascot-party.x
|
||||
|
||||
ifeq ($(MACHINE), aarch64)
|
||||
tldr: simde/simde
|
||||
endif
|
||||
@@ -144,8 +143,6 @@ static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Mac
|
||||
Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) Processor/PrepBase.o
|
||||
$(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS)
|
||||
|
||||
Check-Offline.x: $(PROCESSOR)
|
||||
|
||||
ot.x: $(OT) $(COMMON) Machines/OText_main.o Machines/OTMachine.o $(LIBSIMPLEOT)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
@@ -285,11 +282,21 @@ mpir: mpir-setup
|
||||
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
|
||||
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
|
||||
|
||||
mac-setup:
|
||||
mac-setup: mac-machine-setup
|
||||
brew install openssl boost libsodium mpir yasm ntl
|
||||
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include >> CONFIG.mine
|
||||
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib >> CONFIG.mine
|
||||
-echo USE_NTL = 1 >> CONFIG.mine
|
||||
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I/opt/homebrew/opt/openssl/include -I/opt/homebrew/include >> CONFIG.mine
|
||||
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L/opt/homebrew/lib -L/opt/homebrew/opt/openssl/lib >> CONFIG.mine
|
||||
# -echo USE_NTL = 1 >> CONFIG.mine
|
||||
|
||||
ifeq ($(MACHINE), aarch64)
|
||||
mac-machine-setup:
|
||||
-echo ARCH = >> CONFIG.mine
|
||||
linux-machine-setup:
|
||||
-echo ARCH = -march=armv8.2-a+crypto >> CONFIG.mine
|
||||
else
|
||||
mac-machine-setup:
|
||||
linux-machine-setup:
|
||||
endif
|
||||
|
||||
simde/simde:
|
||||
git submodule update --init simde
|
||||
|
||||
@@ -49,6 +49,11 @@ public:
|
||||
return string(1, T::type_char());
|
||||
}
|
||||
|
||||
static void specification(octetStream& os)
|
||||
{
|
||||
T::specification(os);
|
||||
}
|
||||
|
||||
template<class U, class V>
|
||||
static FixedVec Mul(const FixedVec<U, L>& a, const V& b)
|
||||
{
|
||||
|
||||
@@ -35,6 +35,8 @@ public:
|
||||
static int length() { return N_BITS; }
|
||||
static string type_string() { return "integer"; }
|
||||
|
||||
static void specification(octetStream& os);
|
||||
|
||||
static void init_default(int lgp) { (void)lgp; }
|
||||
|
||||
static bool allows(Dtype type) { return type <= DATA_BIT; }
|
||||
@@ -126,6 +128,8 @@ class Integer : public IntBase<long>
|
||||
Integer(const bigint& x) { *this = (x > 0) ? x.get_ui() : -x.get_ui(); }
|
||||
template<int K>
|
||||
Integer(const Z2<K>& x) : Integer(x.get_limb(0)) {}
|
||||
template<int K>
|
||||
Integer(const SignedZ2<K>& x);
|
||||
template<int X, int L>
|
||||
Integer(const gfp_<X, L>& x);
|
||||
Integer(int128 x) : Integer(x.get_lower()) {}
|
||||
|
||||
@@ -8,6 +8,12 @@
|
||||
template<class T>
|
||||
const int IntBase<T>::N_BITS;
|
||||
|
||||
template<class T>
|
||||
inline void IntBase<T>::specification(octetStream& os)
|
||||
{
|
||||
os.store(sizeof(T));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void IntBase<T>::output(ostream& s,bool human) const
|
||||
{
|
||||
@@ -42,6 +48,15 @@ void Integer::reqbl(int n)
|
||||
}
|
||||
}
|
||||
|
||||
template<int K>
|
||||
Integer::Integer(const SignedZ2<K>& x)
|
||||
{
|
||||
if (K < N_BITS and x.negative())
|
||||
a = -(-x).get_limb(0);
|
||||
else
|
||||
a = x.get_limb(0);
|
||||
}
|
||||
|
||||
inline
|
||||
Integer::Integer(const Integer& x, int n_bits)
|
||||
{
|
||||
|
||||
@@ -139,6 +139,14 @@ void gf2n_<U>::init_multiplication()
|
||||
}
|
||||
|
||||
|
||||
template<class U>
|
||||
void gf2n_<U>::specification(octetStream& os)
|
||||
{
|
||||
os.store(sizeof(U));
|
||||
os.store(degree());
|
||||
}
|
||||
|
||||
|
||||
/* Takes 8bit x and y and returns the 16 bit product in c1 and c0
|
||||
ans = (c1<<8)^c0
|
||||
where c1 and c0 are 8 bit
|
||||
|
||||
@@ -76,6 +76,8 @@ protected:
|
||||
static string type_short() { return "2"; }
|
||||
static string type_string() { return "gf2n_"; }
|
||||
|
||||
static void specification(octetStream& os);
|
||||
|
||||
static int size() { return sizeof(a); }
|
||||
static int size_in_bits() { return sizeof(a) * 8; }
|
||||
|
||||
|
||||
@@ -171,6 +171,29 @@ class gf2n_long : public gf2n_<int128>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__aarch64__) && defined(__clang__)
|
||||
inline __m128i my_slli(int128 x, int i)
|
||||
{
|
||||
if (i < 64)
|
||||
return int128(x.get_upper() << i, x.get_lower() << i).a;
|
||||
else
|
||||
return int128().a;
|
||||
}
|
||||
|
||||
inline __m128i my_srli(int128 x, int i)
|
||||
{
|
||||
if (i < 64)
|
||||
return int128(x.get_upper() >> i, x.get_lower() >> i).a;
|
||||
else
|
||||
return int128().a;
|
||||
}
|
||||
|
||||
#undef _mm_slli_epi64
|
||||
#undef _mm_srli_epi64
|
||||
#define _mm_slli_epi64 my_slli
|
||||
#define _mm_srli_epi64 my_srli
|
||||
#endif
|
||||
|
||||
inline int128 int128::operator<<(const int& other) const
|
||||
{
|
||||
int128 res(_mm_slli_epi64(a, other));
|
||||
|
||||
@@ -15,7 +15,7 @@ Zp_Data gfpvar_<X, L>::ZpD;
|
||||
template<int X, int L>
|
||||
string gfpvar_<X, L>::type_string()
|
||||
{
|
||||
return "gfpvar";
|
||||
return "gfp";
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
@@ -30,6 +30,12 @@ char gfpvar_<X, L>::type_char()
|
||||
return 'p';
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
void gfpvar_<X, L>::specification(octetStream& os)
|
||||
{
|
||||
os.store(pr());
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
int gfpvar_<X, L>::length()
|
||||
{
|
||||
|
||||
@@ -44,6 +44,8 @@ public:
|
||||
static string type_short();
|
||||
static char type_char();
|
||||
|
||||
static void specification(octetStream& os);
|
||||
|
||||
static int length();
|
||||
static int size();
|
||||
static int size_in_bits();
|
||||
|
||||
@@ -14,21 +14,43 @@ void check_ssl_file(string filename)
|
||||
"You can use `Scripts/setup-ssl.sh <nparties>`.");
|
||||
}
|
||||
|
||||
void ssl_error(string side, string pronoun, string other, string server)
|
||||
void ssl_error(string side, string other, string me)
|
||||
{
|
||||
cerr << side << "-side handshake with " << other
|
||||
<< " failed. Make sure " << pronoun
|
||||
<< " have the necessary certificate (" << PREP_DIR << server
|
||||
<< ".pem in the default configuration),"
|
||||
<< " failed. Make sure both sides "
|
||||
<< " have the necessary certificate (" << PREP_DIR << me
|
||||
<< ".pem in the default configuration on their side and "
|
||||
<< PREP_DIR << other << ".pem on ours),"
|
||||
<< " and run `c_rehash <directory>` on its location." << endl
|
||||
<< "The certificates should be the same on every host. "
|
||||
<< "Also make sure that it's still valid. Certificates generated "
|
||||
<< "with `Scripts/setup-ssl.sh` expire after a month." << endl;
|
||||
cerr << "See also "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html"
|
||||
"#handshake-failures" << endl;
|
||||
|
||||
string ids[2];
|
||||
ids[side == "Client"] = other;
|
||||
ids[side != "Client"] = me;
|
||||
cerr << "Signature (should match the other side): ";
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
auto filename = PREP_DIR + ids[i] + ".pem";
|
||||
ifstream cert(filename);
|
||||
stringstream buffer;
|
||||
buffer << cert.rdbuf();
|
||||
if (buffer.str().empty())
|
||||
cerr << "<'" << filename << "' not found>";
|
||||
else
|
||||
cerr << octetStream(buffer.str()).hash();
|
||||
if (i == 0)
|
||||
cerr << "/";
|
||||
}
|
||||
cerr << endl;
|
||||
}
|
||||
|
||||
CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
|
||||
MultiPlayer<ssl_socket*>(Nms), plaintext_player(Nms, id_base),
|
||||
other_player(Nms, id_base + "recv"),
|
||||
MultiPlayer<ssl_socket*>(Nms),
|
||||
ctx("P" + to_string(my_num()))
|
||||
{
|
||||
sockets.resize(num_players());
|
||||
@@ -36,6 +58,16 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
|
||||
senders.resize(num_players());
|
||||
receivers.resize(num_players());
|
||||
|
||||
vector<int> plaintext_sockets[2];
|
||||
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
PlainPlayer player(Nms, id_base + (i ? "recv" : ""));
|
||||
plaintext_sockets[i] = player.sockets;
|
||||
close_client_socket(player.socket(my_num()));
|
||||
player.sockets.clear();
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int)sockets.size(); i++)
|
||||
{
|
||||
if (i == my_num())
|
||||
@@ -47,9 +79,9 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
|
||||
continue;
|
||||
}
|
||||
|
||||
sockets[i] = new ssl_socket(io_service, ctx, plaintext_player.socket(i),
|
||||
sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[0][i],
|
||||
"P" + to_string(i), "P" + to_string(my_num()), i < my_num());
|
||||
other_sockets[i] = new ssl_socket(io_service, ctx, other_player.socket(i),
|
||||
other_sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[1][i],
|
||||
"P" + to_string(i), "P" + to_string(my_num()), i < my_num());
|
||||
|
||||
senders[i] = new Sender<ssl_socket*>(i < my_num() ? sockets[i] : other_sockets[i]);
|
||||
@@ -64,10 +96,6 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
|
||||
|
||||
CryptoPlayer::~CryptoPlayer()
|
||||
{
|
||||
close_client_socket(plaintext_player.socket(my_num()));
|
||||
close_client_socket(other_player.socket(my_num()));
|
||||
plaintext_player.sockets.clear();
|
||||
other_player.sockets.clear();
|
||||
for (int i = 0; i < num_players(); i++)
|
||||
{
|
||||
delete sockets[i];
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
*/
|
||||
class CryptoPlayer : public MultiPlayer<ssl_socket*>
|
||||
{
|
||||
PlainPlayer plaintext_player, other_player;
|
||||
ssl_ctx ctx;
|
||||
boost::asio::io_service io_service;
|
||||
|
||||
|
||||
@@ -373,6 +373,7 @@ protected:
|
||||
T send_to_self_socket;
|
||||
|
||||
T socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; }
|
||||
T socket(int i) const { return sockets[i]; }
|
||||
|
||||
friend class CryptoPlayer;
|
||||
|
||||
@@ -381,8 +382,6 @@ public:
|
||||
|
||||
virtual ~MultiPlayer();
|
||||
|
||||
T socket(int i) const { return sockets[i]; }
|
||||
|
||||
// Send/Receive data to/from player i
|
||||
void send_long(int i, long a) const;
|
||||
long receive_long(int i) const;
|
||||
|
||||
@@ -109,7 +109,9 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
|
||||
throw runtime_error(
|
||||
string() + "cannot connect from " + my_name + " to " + hostname + ":"
|
||||
+ to_string(Portnum) + " after " + to_string(attempts)
|
||||
+ " attempts in one minute because " + strerror(connect_errno));
|
||||
+ " attempts in one minute because " + strerror(connect_errno) + ". "
|
||||
"https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#"
|
||||
"connection-failures has more information on port requirements.");
|
||||
}
|
||||
|
||||
freeaddrinfo(ai);
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
typedef boost::asio::io_service ssl_service;
|
||||
|
||||
void check_ssl_file(string filename);
|
||||
void ssl_error(string side, string pronoun, string other, string server);
|
||||
void ssl_error(string side, string other, string server);
|
||||
|
||||
class ssl_ctx : public boost::asio::ssl::context
|
||||
{
|
||||
@@ -55,7 +55,7 @@ public:
|
||||
handshake(ssl_socket::client);
|
||||
} catch (...)
|
||||
{
|
||||
ssl_error("Client", "we", other, other);
|
||||
ssl_error("Client", other, me);
|
||||
throw;
|
||||
}
|
||||
else
|
||||
@@ -65,7 +65,7 @@ public:
|
||||
handshake(ssl_socket::server);
|
||||
} catch (...)
|
||||
{
|
||||
ssl_error("Server", "they", other, me);
|
||||
ssl_error("Server", other, me);
|
||||
throw;
|
||||
}
|
||||
|
||||
|
||||
@@ -187,7 +187,13 @@ void NPartyTripleGenerator<T>::generate()
|
||||
if (thread_num != 0)
|
||||
ss << "-" << thread_num;
|
||||
if (machine.output)
|
||||
{
|
||||
outputFile.open(ss.str().c_str());
|
||||
if (machine.generateMACs or not T::clear::invertible)
|
||||
file_signature<T>().output(outputFile);
|
||||
else
|
||||
file_signature<typename T::clear>().output(outputFile);
|
||||
}
|
||||
|
||||
if (machine.generateBits)
|
||||
generateBits();
|
||||
|
||||
@@ -89,6 +89,13 @@ DataPositions DataPositions::operator-(const DataPositions& other) const
|
||||
return res;
|
||||
}
|
||||
|
||||
DataPositions DataPositions::operator+(const DataPositions& other) const
|
||||
{
|
||||
DataPositions res = *this;
|
||||
res.increase(other);
|
||||
return res;
|
||||
}
|
||||
|
||||
void DataPositions::print_cost() const
|
||||
{
|
||||
ifstream file("cost");
|
||||
|
||||
@@ -19,6 +19,11 @@ using namespace std;
|
||||
|
||||
template<class T> class dabit;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
template<class T> class ShareThread;
|
||||
}
|
||||
|
||||
class DataTag
|
||||
{
|
||||
int t[4];
|
||||
@@ -74,6 +79,7 @@ public:
|
||||
void increase(const DataPositions& delta);
|
||||
DataPositions& operator-=(const DataPositions& delta);
|
||||
DataPositions operator-(const DataPositions& delta) const;
|
||||
DataPositions operator+(const DataPositions& delta) const;
|
||||
void print_cost() const;
|
||||
bool empty() const;
|
||||
bool any_more(const DataPositions& other) const;
|
||||
@@ -84,10 +90,15 @@ template<class sint, class sgf2n> class Data_Files;
|
||||
template<class sint, class sgf2n> class Machine;
|
||||
template<class T> class SubProcessor;
|
||||
|
||||
/**
|
||||
* Abstract base class for preprocessing
|
||||
*/
|
||||
template<class T>
|
||||
class Preprocessing : public PrepBase
|
||||
{
|
||||
protected:
|
||||
static const bool use_part = false;
|
||||
|
||||
DataPositions& usage;
|
||||
|
||||
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
|
||||
@@ -114,6 +125,8 @@ public:
|
||||
template<class U, class V>
|
||||
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
|
||||
SubProcessor<T>* proc);
|
||||
static Preprocessing<T>* get_new(bool live_prep, const Names& N,
|
||||
DataPositions& usage);
|
||||
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
|
||||
DataPositions& usage);
|
||||
|
||||
@@ -144,11 +157,15 @@ public:
|
||||
void get_input(T& a, typename T::open_type& x, int i);
|
||||
void get(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
|
||||
|
||||
/// Get fresh random multiplication triple
|
||||
virtual array<T, 3> get_triple(int n_bits);
|
||||
virtual array<T, 3> get_triple_no_count(int n_bits);
|
||||
/// Get fresh random bit
|
||||
virtual T get_bit();
|
||||
/// Get fresh random value in domain
|
||||
virtual T get_random();
|
||||
virtual void get_dabit(T&, typename T::bit_type&);
|
||||
/// Store fresh daBit in ``a`` (arithmetic part) and ``b`` (binary part)
|
||||
virtual void get_dabit(T& a, typename T::bit_type& b);
|
||||
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
|
||||
virtual void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs)
|
||||
@@ -156,6 +173,7 @@ public:
|
||||
template<int>
|
||||
void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
|
||||
template<int>
|
||||
/// Get fresh edaBit chunk
|
||||
edabitvec<T> get_edabitvec(bool strict, int n_bits);
|
||||
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
|
||||
|
||||
@@ -270,13 +288,14 @@ class Data_Files
|
||||
|
||||
Preprocessing<sint>& DataFp;
|
||||
Preprocessing<sgf2n>& DataF2;
|
||||
Preprocessing<typename sint::bit_type>& DataFb;
|
||||
|
||||
Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<sint>* procp = 0,
|
||||
SubProcessor<sgf2n>* proc2 = 0);
|
||||
Data_Files(const Names& N);
|
||||
~Data_Files();
|
||||
|
||||
DataPositions tellg();
|
||||
DataPositions tellg() { return usage; }
|
||||
void seekg(DataPositions& pos);
|
||||
void skip(const DataPositions& pos);
|
||||
void prune();
|
||||
@@ -289,7 +308,7 @@ class Data_Files
|
||||
|
||||
void reset_usage() { usage.reset(); skipped.reset(); }
|
||||
|
||||
NamedCommStats comm_stats() { return DataFp.comm_stats() + DataF2.comm_stats(); }
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
template<class T> inline
|
||||
@@ -407,6 +426,13 @@ inline void Data_Files<sint, sgf2n>::purge()
|
||||
{
|
||||
DataFp.purge();
|
||||
DataF2.purge();
|
||||
DataFb.purge();
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
NamedCommStats Data_Files<sint, sgf2n>::comm_stats()
|
||||
{
|
||||
return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "Processor/Processor.h"
|
||||
#include "Protocols/dabit.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "GC/BitPrepFiles.h"
|
||||
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
|
||||
@@ -28,6 +29,19 @@ Preprocessing<T>* Preprocessing<T>::get_new(
|
||||
machine.template prep_dir_prefix<T>(), usage);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Preprocessing<T>* Preprocessing<T>::get_new(
|
||||
bool live_prep, const Names& N,
|
||||
DataPositions& usage)
|
||||
{
|
||||
if (live_prep)
|
||||
return new typename T::LivePrep(usage);
|
||||
else
|
||||
return new GC::BitPrepFiles<T>(N,
|
||||
get_prep_sub_dir<T>(PREP_DIR, N.num_players()), usage,
|
||||
BaseMachine::thread_num);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Sub_Data_Files<T>::Sub_Data_Files(const Names& N, DataPositions& usage,
|
||||
int thread_num) :
|
||||
@@ -96,7 +110,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
|
||||
dabit_buffer.setup(
|
||||
PrepBase::get_filename(prep_data_dir, DATA_DABIT,
|
||||
type_short, my_num, thread_num), 1, type_string,
|
||||
type_short, my_num, thread_num), dabit<T>::size(), type_string,
|
||||
DataPositions::dtype_names[DATA_DABIT]);
|
||||
|
||||
input_buffers.resize(num_players);
|
||||
@@ -106,7 +120,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
type_short, i, my_num, thread_num);
|
||||
if (i == my_num)
|
||||
my_input_buffers.setup(filename,
|
||||
T::size() * 3 / 2, type_string);
|
||||
T::size() + T::clear::size(), type_string);
|
||||
else
|
||||
input_buffers[i].setup(filename,
|
||||
T::size(), type_string);
|
||||
@@ -122,7 +136,10 @@ Data_Files<sint, sgf2n>::Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<
|
||||
SubProcessor<sgf2n>* proc2) :
|
||||
usage(machine.get_N().num_players()),
|
||||
DataFp(*Preprocessing<sint>::get_new(machine, usage, procp)),
|
||||
DataF2(*Preprocessing<sgf2n>::get_new(machine, usage, proc2))
|
||||
DataF2(*Preprocessing<sgf2n>::get_new(machine, usage, proc2)),
|
||||
DataFb(
|
||||
*Preprocessing<typename sint::bit_type>::get_new(machine.live_prep,
|
||||
machine.get_N(), usage))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -130,7 +147,8 @@ template<class sint, class sgf2n>
|
||||
Data_Files<sint, sgf2n>::Data_Files(const Names& N) :
|
||||
usage(N.num_players()),
|
||||
DataFp(*new Sub_Data_Files<sint>(N, usage)),
|
||||
DataF2(*new Sub_Data_Files<sgf2n>(N, usage))
|
||||
DataF2(*new Sub_Data_Files<sgf2n>(N, usage)),
|
||||
DataFb(*new Sub_Data_Files<typename sint::bit_type>(N, usage))
|
||||
{
|
||||
}
|
||||
|
||||
@@ -150,6 +168,7 @@ Data_Files<sint, sgf2n>::~Data_Files()
|
||||
DataF2.data_sent() * 1e-6 << " MB" << endl;
|
||||
#endif
|
||||
delete &DataF2;
|
||||
delete &DataFb;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -166,6 +185,12 @@ Sub_Data_Files<T>::~Sub_Data_Files()
|
||||
template<class T>
|
||||
void Sub_Data_Files<T>::seekg(DataPositions& pos)
|
||||
{
|
||||
if (T::LivePrep::use_part)
|
||||
{
|
||||
get_part().seekg(pos);
|
||||
return;
|
||||
}
|
||||
|
||||
DataFieldType field_type = T::clear::field_type();
|
||||
for (int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
if (T::clear::allows(Dtype(dtype)))
|
||||
@@ -181,6 +206,7 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
|
||||
setup_extended(it->first);
|
||||
extended[it->first].seekg(it->second);
|
||||
}
|
||||
dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -188,6 +214,7 @@ void Data_Files<sint, sgf2n>::seekg(DataPositions& pos)
|
||||
{
|
||||
DataFp.seekg(pos);
|
||||
DataF2.seekg(pos);
|
||||
DataFb.seekg(pos);
|
||||
usage = pos;
|
||||
}
|
||||
|
||||
@@ -210,6 +237,9 @@ void Sub_Data_Files<T>::prune()
|
||||
input_buffers[j].prune();
|
||||
for (auto it : extended)
|
||||
it.second.prune();
|
||||
dabit_buffer.prune();
|
||||
if (part != 0)
|
||||
part->prune();
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -217,6 +247,7 @@ void Data_Files<sint, sgf2n>::prune()
|
||||
{
|
||||
DataFp.prune();
|
||||
DataF2.prune();
|
||||
DataFb.prune();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -229,6 +260,7 @@ void Sub_Data_Files<T>::purge()
|
||||
input_buffers[j].purge();
|
||||
for (auto it : extended)
|
||||
it.second.purge();
|
||||
dabit_buffer.purge();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -280,11 +312,12 @@ void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
|
||||
ifstream* f = new ifstream(filename);
|
||||
if (f->fail())
|
||||
throw runtime_error("cannot open " + filename);
|
||||
check_file_signature<T>(*f, filename);
|
||||
edabit_buffers[n_bits] = f;
|
||||
}
|
||||
auto& buffer = *edabit_buffers[n_bits];
|
||||
if (buffer.peek() == EOF)
|
||||
buffer.seekg(0);
|
||||
buffer.seekg(file_signature<T>().get_length());
|
||||
edabitvec<T> eb;
|
||||
eb.input(n_bits, buffer);
|
||||
this->edabits[{strict, n_bits}].push_back(eb);
|
||||
|
||||
@@ -15,6 +15,9 @@ using namespace std;
|
||||
|
||||
class ArithmeticProcessor;
|
||||
|
||||
/**
|
||||
* Abstract base for input protocols
|
||||
*/
|
||||
template<class T>
|
||||
class InputBase
|
||||
{
|
||||
@@ -45,18 +48,28 @@ public:
|
||||
InputBase(SubProcessor<T>* proc);
|
||||
virtual ~InputBase();
|
||||
|
||||
/// Initialize input round for ``player``
|
||||
virtual void reset(int player) = 0;
|
||||
/// Initialize input round for all players
|
||||
void reset_all(Player& P);
|
||||
|
||||
/// Schedule input from me
|
||||
virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0;
|
||||
/// Schedule input from other player
|
||||
virtual void add_other(int player, int n_bits = -1) = 0;
|
||||
/// Schedule input from all players
|
||||
void add_from_all(const clear& input);
|
||||
|
||||
/// Send my inputs
|
||||
virtual void send_mine() = 0;
|
||||
/// Run input protocol for all players
|
||||
virtual void exchange();
|
||||
|
||||
/// Get share for next input of mine
|
||||
virtual T finalize_mine() = 0;
|
||||
/// Store share for next input from ``player`` from buffer ``o`` in ``target``
|
||||
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
|
||||
/// Get share for next input from ``player`
|
||||
virtual T finalize(int player, int n_bits = -1);
|
||||
|
||||
void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size);
|
||||
|
||||
@@ -19,6 +19,11 @@ struct InputTuple
|
||||
static string type_string()
|
||||
{ return T::type_string(); }
|
||||
|
||||
static void specification(octetStream& os)
|
||||
{
|
||||
T::specification(os);
|
||||
}
|
||||
|
||||
InputTuple() {}
|
||||
|
||||
InputTuple(const T& share, const typename T::open_type& value) : share(share), value(value) {}
|
||||
|
||||
@@ -14,6 +14,7 @@ using namespace std;
|
||||
template<class sint, class sgf2n> class Machine;
|
||||
template<class sint, class sgf2n> class Processor;
|
||||
class ArithmeticProcessor;
|
||||
class SwitchableOutput;
|
||||
|
||||
/*
|
||||
* Opcode constants
|
||||
@@ -306,12 +307,6 @@ enum RegType {
|
||||
MAX_REG_TYPE,
|
||||
};
|
||||
|
||||
enum SecrecyType {
|
||||
SECRET,
|
||||
CLEAR,
|
||||
MAX_SECRECY_TYPE
|
||||
};
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
struct TempVars {
|
||||
typename sgf2n::clear ans2;
|
||||
@@ -387,6 +382,10 @@ public:
|
||||
|
||||
void shuffle(ArithmeticProcessor& Proc) const;
|
||||
void bitdecint(ArithmeticProcessor& Proc) const;
|
||||
|
||||
template<class T>
|
||||
void print(SwitchableOutput& out, T* v, T* p = 0, T* s = 0, T* z = 0,
|
||||
T* nan = 0) const;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -328,6 +328,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
|
||||
// write to external client, input is : opcode num_args, client_id, message_type, var1, var2 ...
|
||||
case WRITESOCKETC:
|
||||
case WRITESOCKETS:
|
||||
case WRITESOCKETSHARE:
|
||||
case WRITESOCKETINT:
|
||||
num_var_args = get_int(s) - 3;
|
||||
@@ -336,8 +337,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
n = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case WRITESOCKETS:
|
||||
throw runtime_error("sending MACs to client not supported any more");
|
||||
case READCLIENTPUBLICKEY:
|
||||
case INITSECURESOCKET:
|
||||
case RESPSECURESOCKET:
|
||||
@@ -1070,31 +1069,19 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
}
|
||||
break;
|
||||
case PRINTREGPLAIN:
|
||||
{
|
||||
Proc.out << Proc.read_Cp(r[0]) << flush;
|
||||
}
|
||||
break;
|
||||
print(Proc.out, &Proc.read_Cp(r[0]));
|
||||
return;
|
||||
case CONDPRINTPLAIN:
|
||||
if (not Proc.read_Cp(r[0]).is_zero())
|
||||
{
|
||||
auto v = Proc.read_Cp(r[1]);
|
||||
auto p = Proc.read_Cp(r[2]);
|
||||
if (p.is_zero())
|
||||
Proc.out << v << flush;
|
||||
else
|
||||
Proc.out << bigint::get_float(v, p, {}, {}) << flush;
|
||||
print(Proc.out, &Proc.read_Cp(r[1]), &Proc.read_Cp(r[2]));
|
||||
}
|
||||
break;
|
||||
return;
|
||||
case PRINTFLOATPLAIN:
|
||||
{
|
||||
auto nan = Proc.read_Cp(start[4]);
|
||||
typename sint::clear v = Proc.read_Cp(start[0]);
|
||||
typename sint::clear p = Proc.read_Cp(start[1]);
|
||||
typename sint::clear z = Proc.read_Cp(start[2]);
|
||||
typename sint::clear s = Proc.read_Cp(start[3]);
|
||||
bigint::output_float(Proc.out, bigint::get_float(v, p, z, s), nan);
|
||||
}
|
||||
break;
|
||||
print(Proc.out, &Proc.read_Cp(start[0]), &Proc.read_Cp(start[1]),
|
||||
&Proc.read_Cp(start[2]), &Proc.read_Cp(start[3]),
|
||||
&Proc.read_Cp(start[4]));
|
||||
return;
|
||||
case CONDPRINTSTR:
|
||||
if (not Proc.read_Cp(r[0]).is_zero())
|
||||
{
|
||||
@@ -1124,9 +1111,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.machine.stop(n);
|
||||
break;
|
||||
case RUN_TAPE:
|
||||
Proc.DataF.skip(
|
||||
Proc.machine.run_tapes(start, &Proc.DataF.DataFp,
|
||||
&Proc.share_thread.DataF));
|
||||
Proc.machine.run_tapes(start, Proc.DataF);
|
||||
break;
|
||||
case JOIN_TAPE:
|
||||
Proc.machine.join_tape(r[0]);
|
||||
@@ -1186,15 +1171,19 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.read_socket_private(Proc.read_Ci(r[0]), start, n, true);
|
||||
break;
|
||||
case WRITESOCKETINT:
|
||||
Proc.write_socket(INT, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
Proc.write_socket(INT, false, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
break;
|
||||
case WRITESOCKETC:
|
||||
Proc.write_socket(CINT, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
Proc.write_socket(CINT, false, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
break;
|
||||
case WRITESOCKETS:
|
||||
// Send shares + MACs
|
||||
Proc.write_socket(SINT, true, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
break;
|
||||
case WRITESOCKETSHARE:
|
||||
// Send only shares, no MACs
|
||||
// N.B. doesn't make sense to have a corresponding read instruction for this
|
||||
Proc.write_socket(SINT, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
Proc.write_socket(SINT, false, Proc.read_Ci(r[0]), r[1], start, n);
|
||||
break;
|
||||
case WRITEFILESHARE:
|
||||
// Write shares to file system
|
||||
@@ -1323,4 +1312,29 @@ void Program::execute(Processor<sint, sgf2n>& Proc) const
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) const
|
||||
{
|
||||
if (size > 1)
|
||||
out << "[";
|
||||
for (int i = 0; i < size; i++)
|
||||
{
|
||||
if (p == 0)
|
||||
out << v[i];
|
||||
else if (s == 0)
|
||||
out << bigint::get_float(v[i], p[i], {}, {});
|
||||
else
|
||||
{
|
||||
assert(z != 0);
|
||||
assert(nan != 0);
|
||||
bigint::output_float(out, bigint::get_float(v[i], p[i], s[i], z[i]),
|
||||
nan[i]);
|
||||
}
|
||||
if (i < size - 1)
|
||||
out << ", ";
|
||||
}
|
||||
if (size > 1)
|
||||
out << "]";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -42,9 +42,6 @@ class Machine : public BaseMachine
|
||||
typename sgf2n::mac_key_type alpha2i;
|
||||
typename sint::bit_type::mac_key_type alphabi;
|
||||
|
||||
// Keep record of used offline data
|
||||
DataPositions pos;
|
||||
|
||||
Player* P;
|
||||
|
||||
void load_program(const string& threadname, const string& filename);
|
||||
@@ -83,8 +80,8 @@ class Machine : public BaseMachine
|
||||
|
||||
const Names& get_N() { return N; }
|
||||
|
||||
DataPositions run_tapes(const vector<int> &args, Preprocessing<sint> *prep,
|
||||
Preprocessing<typename sint::bit_type> *bit_prep);
|
||||
DataPositions run_tapes(const vector<int> &args,
|
||||
Data_Files<sint, sgf2n>& DataF);
|
||||
void fill_buffers(int thread_number, int tape_number,
|
||||
Preprocessing<sint> *prep,
|
||||
Preprocessing<typename sint::bit_type> *bit_prep);
|
||||
@@ -93,7 +90,8 @@ class Machine : public BaseMachine
|
||||
Preprocessing<sint> *prep, true_type);
|
||||
template<int = 0>
|
||||
void fill_matmul(int, int, Preprocessing<sint>*, false_type) {}
|
||||
DataPositions run_tape(int thread_number, int tape_number, int arg);
|
||||
DataPositions run_tape(int thread_number, int tape_number, int arg,
|
||||
const DataPositions& pos);
|
||||
DataPositions join_tape(int thread_number);
|
||||
void run();
|
||||
|
||||
|
||||
@@ -92,9 +92,6 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// Keep record of used offline data
|
||||
pos.set_num_players(N.num_players());
|
||||
|
||||
load_schedule(progname_str);
|
||||
|
||||
// remove persistence if necessary
|
||||
@@ -161,14 +158,16 @@ void Machine<sint, sgf2n>::load_program(const string& threadname,
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
DataPositions Machine<sint, sgf2n>::run_tapes(const vector<int>& args,
|
||||
Preprocessing<sint>* prep, Preprocessing<typename sint::bit_type>* bit_prep)
|
||||
Data_Files<sint, sgf2n>& DataF)
|
||||
{
|
||||
assert(args.size() % 3 == 0);
|
||||
for (unsigned i = 0; i < args.size(); i += 3)
|
||||
fill_buffers(args[i], args[i + 1], prep, bit_prep);
|
||||
fill_buffers(args[i], args[i + 1], &DataF.DataFp, &DataF.DataFb);
|
||||
DataPositions res(N.num_players());
|
||||
for (unsigned i = 0; i < args.size(); i += 3)
|
||||
res.increase(run_tape(args[i], args[i + 1], args[i + 2]));
|
||||
res.increase(
|
||||
run_tape(args[i], args[i + 1], args[i + 2], DataF.tellg() + res));
|
||||
DataF.skip(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -281,7 +280,7 @@ void Machine<sint, sgf2n>::fill_matmul(int thread_number, int tape_number,
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
|
||||
int arg)
|
||||
int arg, const DataPositions& pos)
|
||||
{
|
||||
if (size_t(thread_number) >= tinfo.size())
|
||||
throw overflow("invalid thread number", thread_number, tinfo.size());
|
||||
@@ -294,7 +293,7 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
|
||||
if (progs[tape_number].usage_unknown())
|
||||
{
|
||||
#ifndef INSECURE
|
||||
if (not opts.live_prep)
|
||||
if (not opts.live_prep and thread_number != 0)
|
||||
{
|
||||
cerr << "Internally called tape " << tape_number <<
|
||||
" has unknown offline data usage" << endl;
|
||||
@@ -328,7 +327,7 @@ void Machine<sint, sgf2n>::run()
|
||||
timer[0].start();
|
||||
|
||||
// run main tape
|
||||
pos.increase(run_tape(0, 0, 0));
|
||||
run_tape(0, 0, 0, N.num_players());
|
||||
join_tape(0);
|
||||
|
||||
print_compiler();
|
||||
@@ -341,8 +340,8 @@ void Machine<sint, sgf2n>::run()
|
||||
queues[i]->schedule(-1);
|
||||
}
|
||||
|
||||
// reset to sum actual usage
|
||||
pos.reset();
|
||||
// sum actual usage
|
||||
DataPositions pos(N.num_players());
|
||||
|
||||
#ifdef DEBUG_THREADS
|
||||
cerr << "Waiting for all clients to finish" << endl;
|
||||
|
||||
@@ -15,22 +15,29 @@ template<class T> istream& operator>>(istream& s,Memory<T>& M);
|
||||
#include "Processor/Program.h"
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
template<class T>
|
||||
class MemoryPart : public CheckVector<T>
|
||||
{
|
||||
public:
|
||||
void minimum_size(size_t size);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
class Memory
|
||||
{
|
||||
public:
|
||||
|
||||
CheckVector<T> MS;
|
||||
CheckVector<typename T::clear> MC;
|
||||
MemoryPart<T> MS;
|
||||
MemoryPart<typename T::clear> MC;
|
||||
|
||||
void resize_s(int sz)
|
||||
void resize_s(size_t sz)
|
||||
{ MS.resize(sz); }
|
||||
void resize_c(int sz)
|
||||
void resize_c(size_t sz)
|
||||
{ MC.resize(sz); }
|
||||
|
||||
unsigned size_s()
|
||||
size_t size_s()
|
||||
{ return MS.size(); }
|
||||
unsigned size_c()
|
||||
size_t size_c()
|
||||
{ return MC.size(); }
|
||||
|
||||
template<class U>
|
||||
@@ -40,23 +47,23 @@ class Memory
|
||||
throw overflow("memory", i, M.size());
|
||||
}
|
||||
|
||||
const typename T::clear& read_C(int i) const
|
||||
const typename T::clear& read_C(size_t i) const
|
||||
{
|
||||
check_index(MC, i);
|
||||
return MC[i];
|
||||
}
|
||||
const T& read_S(int i) const
|
||||
const T& read_S(size_t i) const
|
||||
{
|
||||
check_index(MS, i);
|
||||
return MS[i];
|
||||
}
|
||||
|
||||
void write_C(unsigned int i,const typename T::clear& x)
|
||||
void write_C(size_t i,const typename T::clear& x)
|
||||
{
|
||||
check_index(MC, i);
|
||||
MC[i]=x;
|
||||
}
|
||||
void write_S(unsigned int i,const T& x)
|
||||
void write_S(size_t i,const T& x)
|
||||
{
|
||||
check_index(MS, i);
|
||||
MS[i]=x;
|
||||
|
||||
@@ -8,27 +8,23 @@ void Memory<T>::minimum_size(RegType secret_type, RegType clear_type,
|
||||
const Program &program, const string& threadname)
|
||||
{
|
||||
(void) threadname;
|
||||
unsigned sizes[MAX_SECRECY_TYPE];
|
||||
sizes[SECRET]= program.direct_mem(secret_type);
|
||||
sizes[CLEAR] = program.direct_mem(clear_type);
|
||||
if (sizes[SECRET] > size_s())
|
||||
{
|
||||
#ifdef DEBUG_MEMORY
|
||||
cerr << threadname << " needs more secret " << T::type_string() << " memory, resizing to "
|
||||
<< sizes[SECRET] << endl;
|
||||
#endif
|
||||
resize_s(sizes[SECRET]);
|
||||
}
|
||||
if (sizes[CLEAR] > size_c())
|
||||
{
|
||||
#ifdef DEBUG_MEMORY
|
||||
cerr << threadname << " needs more clear " << T::type_string() << " memory, resizing to "
|
||||
<< sizes[CLEAR] << endl;
|
||||
#endif
|
||||
resize_c(sizes[CLEAR]);
|
||||
}
|
||||
MS.minimum_size(program.direct_mem(secret_type));
|
||||
MC.minimum_size(program.direct_mem(clear_type));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void MemoryPart<T>::minimum_size(size_t size)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (size > this->size())
|
||||
this->resize(size);
|
||||
}
|
||||
catch (bad_alloc&)
|
||||
{
|
||||
throw insufficient_memory(size, T::type_string());
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
ostream& operator<<(ostream& s,const Memory<T>& M)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "OfflineMachine.h"
|
||||
#include "Protocols/mac_key.hpp"
|
||||
#include "Tools/Buffer.h"
|
||||
|
||||
template<class W>
|
||||
template<class V>
|
||||
@@ -39,8 +40,8 @@ int OfflineMachine<W>::run()
|
||||
T::bit_type::mac_key_type::init_field();
|
||||
auto binary_mac_key = read_generate_write_mac_key<
|
||||
typename T::bit_type::part_type>(P);
|
||||
GC::ShareThread<typename T::bit_type> thread(playerNames,
|
||||
OnlineOptions::singleton, P, binary_mac_key, usage);
|
||||
typename T::bit_type::LivePrep bit_prep(usage);
|
||||
GC::ShareThread<typename T::bit_type> thread(bit_prep, P, binary_mac_key);
|
||||
|
||||
generate<T>();
|
||||
generate<typename T::bit_type::part_type>();
|
||||
@@ -74,6 +75,7 @@ void OfflineMachine<W>::generate()
|
||||
if (my_usage > 0)
|
||||
{
|
||||
ofstream out(filename, iostream::out | iostream::binary);
|
||||
file_signature<T>().output(out);
|
||||
if (i == DATA_DABIT)
|
||||
{
|
||||
for (long long j = 0;
|
||||
@@ -108,6 +110,7 @@ void OfflineMachine<W>::generate()
|
||||
if (n_inputs > 0)
|
||||
{
|
||||
ofstream out(filename, iostream::out | iostream::binary);
|
||||
file_signature<T>().output(out);
|
||||
InputTuple<T> tuple;
|
||||
for (long long j = 0;
|
||||
j < DIV_CEIL(n_inputs, BUFFER_SIZE) * BUFFER_SIZE; j++)
|
||||
@@ -138,6 +141,7 @@ void OfflineMachine<W>::generate()
|
||||
if (total > 0)
|
||||
{
|
||||
ofstream out(filename, ios::binary);
|
||||
file_signature<T>().output(out);
|
||||
for (int i = 0; i < DIV_CEIL(total, batch) * batch; i++)
|
||||
preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits,
|
||||
out);
|
||||
|
||||
@@ -26,7 +26,7 @@ class thread_info
|
||||
|
||||
static void* Main_Func(void *ptr);
|
||||
|
||||
static void purge_preprocessing(const Names& N);
|
||||
static void purge_preprocessing(const Names& N, int thread_num);
|
||||
|
||||
template<class T>
|
||||
static void print_usage(ostream& o, const vector<T>& regs,
|
||||
|
||||
@@ -352,7 +352,7 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
|
||||
catch (...)
|
||||
{
|
||||
thread_info<sint, sgf2n>* ti = (thread_info<sint, sgf2n>*)ptr;
|
||||
ti->purge_preprocessing(ti->machine->get_N());
|
||||
ti->purge_preprocessing(ti->machine->get_N(), ti->thread_num);
|
||||
throw;
|
||||
}
|
||||
#endif
|
||||
@@ -361,13 +361,17 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
|
||||
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void thread_info<sint, sgf2n>::purge_preprocessing(const Names& N)
|
||||
void thread_info<sint, sgf2n>::purge_preprocessing(const Names& N, int thread_num)
|
||||
{
|
||||
cerr << "Purging preprocessed data because something is wrong" << endl;
|
||||
try
|
||||
{
|
||||
Data_Files<sint, sgf2n> df(N);
|
||||
df.purge();
|
||||
DataPositions pos;
|
||||
Sub_Data_Files<typename sint::bit_type> bit_df(N, pos, thread_num);
|
||||
bit_df.get_part();
|
||||
bit_df.purge();
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
|
||||
@@ -249,7 +249,7 @@ int OnlineMachine::run()
|
||||
catch(...)
|
||||
{
|
||||
if (not online_opts.live_prep)
|
||||
thread_info<T, U>::purge_preprocessing(playerNames);
|
||||
thread_info<T, U>::purge_preprocessing(playerNames, 0);
|
||||
throw;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -36,13 +36,9 @@ OnlineOptions::OnlineOptions() : playerno(-1)
|
||||
}
|
||||
|
||||
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv, int default_batch_size, bool default_live_prep,
|
||||
bool variable_prime_length) :
|
||||
const char** argv, false_type) :
|
||||
OnlineOptions()
|
||||
{
|
||||
if (default_batch_size <= 0)
|
||||
default_batch_size = batch_size;
|
||||
|
||||
opt.syntax = std::string(argv[0]) + " [OPTIONS] [<playerno>] <progname>";
|
||||
|
||||
opt.add(
|
||||
@@ -78,6 +74,58 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"--output-file" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"This player's number (required if not given before program name)", // Help description.
|
||||
"-p", // Flag token.
|
||||
"--player" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Verbose output", // Help description.
|
||||
"-v", // Flag token.
|
||||
"--verbose" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"4", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Batch size for sacrifice (3-5, default: 4)", // Help description.
|
||||
"-B", // Flag token.
|
||||
"--bucket-size" // Flag token.
|
||||
);
|
||||
|
||||
opt.parse(argc, argv);
|
||||
|
||||
interactive = opt.isSet("-I");
|
||||
|
||||
opt.get("-IF")->getString(cmd_private_input_file);
|
||||
opt.get("-OF")->getString(cmd_private_output_file);
|
||||
|
||||
opt.get("--bucket-size")->getInt(bucket_size);
|
||||
|
||||
#ifndef VERBOSE
|
||||
verbose = opt.isSet("--verbose");
|
||||
#endif
|
||||
|
||||
opt.resetArgs();
|
||||
}
|
||||
|
||||
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
const char** argv, int default_batch_size, bool default_live_prep,
|
||||
bool variable_prime_length) :
|
||||
OnlineOptions(opt, argc, argv, false_type())
|
||||
{
|
||||
if (default_batch_size <= 0)
|
||||
default_batch_size = batch_size;
|
||||
|
||||
string default_lgp = to_string(lgp);
|
||||
if (variable_prime_length)
|
||||
{
|
||||
@@ -121,15 +169,6 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"-L", // Flag token.
|
||||
"--live-preprocessing" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"This player's number (required if not given before program name)", // Help description.
|
||||
"-p", // Flag token.
|
||||
"--player" // Flag token.
|
||||
);
|
||||
|
||||
opt.add(
|
||||
to_string(default_batch_size).c_str(), // Default.
|
||||
@@ -170,28 +209,9 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
"-d", // Flag token.
|
||||
"--direct" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"4", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Batch size for sacrifice (3-5, default: 4)", // Help description.
|
||||
"-B", // Flag token.
|
||||
"--bucket-size" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Verbose output", // Help description.
|
||||
"-v", // Flag token.
|
||||
"--verbose" // Flag token.
|
||||
);
|
||||
|
||||
opt.parse(argc, argv);
|
||||
|
||||
interactive = opt.isSet("-I");
|
||||
if (variable_prime_length)
|
||||
{
|
||||
opt.get("--lgp")->getInt(lgp);
|
||||
@@ -208,17 +228,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
|
||||
opt.get("--memory")->getString(memtype);
|
||||
bits_from_squares = opt.isSet("-Q");
|
||||
|
||||
opt.get("-IF")->getString(cmd_private_input_file);
|
||||
opt.get("-OF")->getString(cmd_private_output_file);
|
||||
|
||||
direct = opt.isSet("--direct");
|
||||
|
||||
opt.get("--bucket-size")->getInt(bucket_size);
|
||||
|
||||
#ifndef VERBOSE
|
||||
verbose = opt.isSet("--verbose");
|
||||
#endif
|
||||
|
||||
opt.resetArgs();
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,8 @@ public:
|
||||
bool verbose;
|
||||
|
||||
OnlineOptions();
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
false_type);
|
||||
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
int default_batch_size = 0, bool default_live_prep = true,
|
||||
bool variable_prime_length = false);
|
||||
|
||||
@@ -9,15 +9,8 @@
|
||||
|
||||
string PrepBase::get_suffix(int thread_num)
|
||||
{
|
||||
#ifdef INSECURE
|
||||
(void) thread_num;
|
||||
return "";
|
||||
#else
|
||||
if (thread_num >= 0)
|
||||
return "-T" + to_string(thread_num);
|
||||
else
|
||||
return "";
|
||||
#endif
|
||||
}
|
||||
|
||||
string PrepBase::get_filename(const string& prep_data_dir,
|
||||
|
||||
@@ -31,7 +31,7 @@ class SubProcessor
|
||||
|
||||
DataPositions bit_usage;
|
||||
|
||||
void resize(int size) { C.resize(size); S.resize(size); }
|
||||
void resize(size_t size) { C.resize(size); S.resize(size); }
|
||||
|
||||
template<class sint, class sgf2n> friend class Processor;
|
||||
template<class U> friend class SPDZ;
|
||||
@@ -64,10 +64,10 @@ public:
|
||||
void muls(const vector<int>& reg, int size);
|
||||
void mulrs(const vector<int>& reg);
|
||||
void dotprods(const vector<int>& reg, int size);
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction, int a,
|
||||
int b);
|
||||
void matmulsm(const CheckVector<T>& source, const Instruction& instruction, int a,
|
||||
int b);
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction, size_t a,
|
||||
size_t b);
|
||||
void matmulsm(const CheckVector<T>& source, const Instruction& instruction, size_t a,
|
||||
size_t b);
|
||||
void conv2ds(const Instruction& instruction);
|
||||
|
||||
void input_personal(const vector<int>& args);
|
||||
@@ -82,12 +82,12 @@ public:
|
||||
return C;
|
||||
}
|
||||
|
||||
T& get_S_ref(int i)
|
||||
T& get_S_ref(size_t i)
|
||||
{
|
||||
return S[i];
|
||||
}
|
||||
|
||||
typename T::clear& get_C_ref(int i)
|
||||
typename T::clear& get_C_ref(size_t i)
|
||||
{
|
||||
return C[i];
|
||||
}
|
||||
@@ -136,11 +136,11 @@ public:
|
||||
return thread_num;
|
||||
}
|
||||
|
||||
const long& read_Ci(int i) const
|
||||
const long& read_Ci(size_t i) const
|
||||
{ return Ci[i]; }
|
||||
long& get_Ci_ref(int i)
|
||||
long& get_Ci_ref(size_t i)
|
||||
{ return Ci[i]; }
|
||||
void write_Ci(int i,const long& x)
|
||||
void write_Ci(size_t i, const long& x)
|
||||
{ Ci[i]=x; }
|
||||
CheckVector<long>& get_Ci()
|
||||
{ return Ci; }
|
||||
@@ -190,30 +190,30 @@ class Processor : public ArithmeticProcessor
|
||||
const Program& program);
|
||||
~Processor();
|
||||
|
||||
const typename sgf2n::clear& read_C2(int i) const
|
||||
const typename sgf2n::clear& read_C2(size_t i) const
|
||||
{ return Proc2.C[i]; }
|
||||
const sgf2n& read_S2(int i) const
|
||||
const sgf2n& read_S2(size_t i) const
|
||||
{ return Proc2.S[i]; }
|
||||
typename sgf2n::clear& get_C2_ref(int i)
|
||||
typename sgf2n::clear& get_C2_ref(size_t i)
|
||||
{ return Proc2.C[i]; }
|
||||
sgf2n& get_S2_ref(int i)
|
||||
sgf2n& get_S2_ref(size_t i)
|
||||
{ return Proc2.S[i]; }
|
||||
void write_C2(int i,const typename sgf2n::clear& x)
|
||||
void write_C2(size_t i,const typename sgf2n::clear& x)
|
||||
{ Proc2.C[i]=x; }
|
||||
void write_S2(int i,const sgf2n& x)
|
||||
void write_S2(size_t i,const sgf2n& x)
|
||||
{ Proc2.S[i]=x; }
|
||||
|
||||
const typename sint::clear& read_Cp(int i) const
|
||||
const typename sint::clear& read_Cp(size_t i) const
|
||||
{ return Procp.C[i]; }
|
||||
const sint & read_Sp(int i) const
|
||||
const sint & read_Sp(size_t i) const
|
||||
{ return Procp.S[i]; }
|
||||
typename sint::clear& get_Cp_ref(int i)
|
||||
typename sint::clear& get_Cp_ref(size_t i)
|
||||
{ return Procp.C[i]; }
|
||||
sint & get_Sp_ref(int i)
|
||||
sint & get_Sp_ref(size_t i)
|
||||
{ return Procp.S[i]; }
|
||||
void write_Cp(int i,const typename sint::clear& x)
|
||||
void write_Cp(size_t i,const typename sint::clear& x)
|
||||
{ Procp.C[i]=x; }
|
||||
void write_Sp(int i,const sint & x)
|
||||
void write_Sp(size_t i,const sint & x)
|
||||
{ Procp.S[i]=x; }
|
||||
|
||||
void check();
|
||||
@@ -229,8 +229,8 @@ class Processor : public ArithmeticProcessor
|
||||
// Access to external client sockets for reading clear/shared data
|
||||
void read_socket_ints(int client_id, const vector<int>& registers, int size);
|
||||
|
||||
void write_socket(const RegType reg_type, int socket_id, int message_type,
|
||||
const vector<int>& registers, int size);
|
||||
void write_socket(const RegType reg_type, bool send_macs, int socket_id,
|
||||
int message_type, const vector<int>& registers, int size);
|
||||
|
||||
void read_socket_vector(int client_id, const vector<int>& registers,
|
||||
int size);
|
||||
|
||||
@@ -70,7 +70,7 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
const Program& program)
|
||||
: ArithmeticProcessor(machine.opts, thread_num),DataF(machine, &Procp, &Proc2),P(P),
|
||||
MC2(MC2),MCp(MCp),machine(machine),
|
||||
share_thread(machine.get_N(), machine.opts, P, machine.get_bit_mac_key(), DataF.usage),
|
||||
share_thread(DataF.DataFb, P, machine.get_bit_mac_key()),
|
||||
Procb(machine.bit_memories),
|
||||
Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P),
|
||||
privateOutput2(Proc2),privateOutputp(Procp),
|
||||
@@ -94,21 +94,8 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
secure_prng.ReSeed();
|
||||
shared_prng.SeedGlobally(P, false);
|
||||
|
||||
// only output on party 0 if not interactive
|
||||
bool always_stdout = machine.opts.cmd_private_output_file == ".";
|
||||
bool output = P.my_num() == 0 or machine.opts.interactive or always_stdout;
|
||||
out.activate(output);
|
||||
Procb.out.activate(output);
|
||||
|
||||
if (not always_stdout)
|
||||
setup_redirection(P.my_num(), thread_num, opts);
|
||||
|
||||
if (stdout_redirect_file.is_open())
|
||||
{
|
||||
out.redirect_to_file(stdout_redirect_file);
|
||||
Procb.out.redirect_to_file(stdout_redirect_file);
|
||||
}
|
||||
|
||||
setup_redirection(P.my_num(), thread_num, opts, out);
|
||||
Procb.out = out;
|
||||
}
|
||||
|
||||
|
||||
@@ -266,8 +253,9 @@ void Processor<sint, sgf2n>::split(const Instruction& instruction)
|
||||
// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to
|
||||
// determine the data structure being sent in a message.
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, int socket_id,
|
||||
int message_type, const vector<int>& registers, int size)
|
||||
void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
|
||||
bool send_macs, int socket_id, int message_type,
|
||||
const vector<int>& registers, int size)
|
||||
{
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
@@ -283,9 +271,12 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, int socket_id,
|
||||
{
|
||||
if (reg_type == SINT)
|
||||
{
|
||||
// Send vector of secret shares
|
||||
get_Sp_ref(registers[i] + j).pack(socket_stream,
|
||||
sint::get_rec_factor(P.my_num(), P.num_players()));
|
||||
// Send vector of secret shares and optionally macs
|
||||
if (send_macs)
|
||||
get_Sp_ref(registers[i] + j).pack(socket_stream);
|
||||
else
|
||||
get_Sp_ref(registers[i] + j).pack(socket_stream,
|
||||
sint::get_rec_factor(P.my_num(), P.num_players()));
|
||||
}
|
||||
else if (reg_type == CINT)
|
||||
{
|
||||
@@ -522,7 +513,7 @@ void SubProcessor<T>::dotprods(const vector<int>& reg, int size)
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmuls(const vector<T>& source,
|
||||
const Instruction& instruction, int a, int b)
|
||||
const Instruction& instruction, size_t a, size_t b)
|
||||
{
|
||||
auto& dim = instruction.get_start();
|
||||
auto A = source.begin() + a;
|
||||
@@ -549,7 +540,7 @@ void SubProcessor<T>::matmuls(const vector<T>& source,
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
const Instruction& instruction, int a, int b)
|
||||
const Instruction& instruction, size_t a, size_t b)
|
||||
{
|
||||
auto& dim = instruction.get_start();
|
||||
auto C = S.begin() + (instruction.get_r(0));
|
||||
|
||||
@@ -5,6 +5,11 @@
|
||||
|
||||
#include "ProcessorBase.hpp"
|
||||
|
||||
ProcessorBase::ProcessorBase() :
|
||||
input_counter(0), arg(0)
|
||||
{
|
||||
}
|
||||
|
||||
string ProcessorBase::get_parameterized_filename(int my_num, int thread_num, const string& prefix)
|
||||
{
|
||||
string filename = prefix + "-P" + to_string(my_num) + "-" + to_string(thread_num);
|
||||
@@ -22,12 +27,18 @@ void ProcessorBase::open_input_file(int my_num, int thread_num,
|
||||
}
|
||||
|
||||
void ProcessorBase::setup_redirection(int my_num, int thread_num,
|
||||
OnlineOptions& opts)
|
||||
OnlineOptions& opts, SwitchableOutput& out)
|
||||
{
|
||||
if (not opts.cmd_private_output_file.empty())
|
||||
// only output on party 0 if not interactive
|
||||
bool always_stdout = opts.cmd_private_output_file == ".";
|
||||
bool output = my_num == 0 or opts.interactive or always_stdout;
|
||||
out.activate(output);
|
||||
|
||||
if (not (opts.cmd_private_output_file.empty() or always_stdout))
|
||||
{
|
||||
const string stdout_filename = get_parameterized_filename(my_num,
|
||||
thread_num, opts.cmd_private_output_file);
|
||||
stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out);
|
||||
out.redirect_to_file(stdout_redirect_file);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/ExecutionStats.h"
|
||||
#include "Tools/SwitchableOutput.h"
|
||||
#include "OnlineOptions.h"
|
||||
|
||||
class ProcessorBase
|
||||
@@ -21,6 +22,7 @@ class ProcessorBase
|
||||
|
||||
ifstream input_file;
|
||||
string input_filename;
|
||||
size_t input_counter;
|
||||
|
||||
protected:
|
||||
// Optional argument to tape
|
||||
@@ -34,6 +36,8 @@ public:
|
||||
|
||||
ofstream stdout_redirect_file;
|
||||
|
||||
ProcessorBase();
|
||||
|
||||
void pushi(long x) { stacki.push(x); }
|
||||
void popi(long& x) { x = stacki.top(); stacki.pop(); }
|
||||
|
||||
@@ -55,7 +59,8 @@ public:
|
||||
template<class T>
|
||||
T get_input(istream& is, const string& input_filename, const int* params);
|
||||
|
||||
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts);
|
||||
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts,
|
||||
SwitchableOutput& out);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PROCESSORBASE_H_ */
|
||||
|
||||
@@ -42,8 +42,9 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co
|
||||
res.read(input_file, params);
|
||||
if (input_file.fail())
|
||||
{
|
||||
throw input_error(T::NAME, input_filename, input_file);
|
||||
throw input_error(T::NAME, input_filename, input_file, input_counter);
|
||||
}
|
||||
input_counter++;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
|
||||
case L: \
|
||||
machine.template run<U<L>, V<gf2n>>(); \
|
||||
break;
|
||||
X(64) X(72) X(128)
|
||||
X(64) X(72) X(128) X(192)
|
||||
#ifdef RING_SIZE
|
||||
X(RING_SIZE)
|
||||
#endif
|
||||
|
||||
@@ -8,6 +8,7 @@ import util
|
||||
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program)
|
||||
MultiArray.disable_index_checks()
|
||||
|
||||
n_examples = 11791
|
||||
n_test = 1991
|
||||
|
||||
@@ -10,6 +10,7 @@ import util
|
||||
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program, adapt_ring=True)
|
||||
MultiArray.disable_index_checks()
|
||||
|
||||
if 'profile' in program.args:
|
||||
print('Compiling for profiling')
|
||||
|
||||
@@ -8,6 +8,7 @@ import util
|
||||
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program, adapt_ring=True)
|
||||
MultiArray.disable_index_checks()
|
||||
|
||||
if 'profile' in program.args:
|
||||
print('Compiling for profiling')
|
||||
|
||||
@@ -8,6 +8,7 @@ import util
|
||||
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program, adapt_ring=True)
|
||||
MultiArray.disable_index_checks()
|
||||
|
||||
if 'profile' in program.args:
|
||||
print('Compiling for profiling')
|
||||
|
||||
@@ -8,6 +8,7 @@ import util
|
||||
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program, True)
|
||||
MultiArray.disable_index_checks()
|
||||
|
||||
if 'profile' in program.args:
|
||||
print('Compiling for profiling')
|
||||
|
||||
@@ -30,6 +30,8 @@ layers[0].X.input_from(0)
|
||||
for layer in layers:
|
||||
layer.input_from(0, raw='raw' in program.args)
|
||||
|
||||
sint(0).reveal().store_in_mem(0)
|
||||
|
||||
start_timer(1)
|
||||
opt.forward(1, keep_intermediate=False)
|
||||
stop_timer(1)
|
||||
|
||||
@@ -8,6 +8,10 @@
|
||||
|
||||
#include "Replicated.h"
|
||||
|
||||
/**
|
||||
* ATLAS protocol (simple version).
|
||||
* Uses double sharings to reduce degree of Shamir secret sharing.
|
||||
*/
|
||||
template<class T>
|
||||
class Atlas : public ProtocolBase<T>
|
||||
{
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
|
||||
#include "ReplicatedPrep.h"
|
||||
|
||||
/**
|
||||
* ATLAS preprocessing.
|
||||
*/
|
||||
template<class T>
|
||||
class AtlasPrep : public ReplicatedPrep<T>
|
||||
{
|
||||
@@ -21,6 +24,7 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
/// Input tuples from random sharings
|
||||
void buffer_inputs(int player)
|
||||
{
|
||||
assert(this->protocol and this->proc);
|
||||
|
||||
@@ -17,6 +17,9 @@ template<class T> class SubProcessor;
|
||||
template<class T> class MAC_Check_Base;
|
||||
class Player;
|
||||
|
||||
/**
|
||||
* Beaver multiplication
|
||||
*/
|
||||
template<class T>
|
||||
class Beaver : public ProtocolBase<T>
|
||||
{
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
|
||||
#include "FHEOffline/SimpleGenerator.h"
|
||||
|
||||
/**
|
||||
* HighGear/ChaiGear preprocessing
|
||||
*/
|
||||
template<class T>
|
||||
class ChaiGearPrep : public MaliciousRingPrep<T>
|
||||
{
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
class PairwiseMachine;
|
||||
template<class FD> class PairwiseGenerator;
|
||||
|
||||
/**
|
||||
* LowGear/CowGear preprocessing
|
||||
*/
|
||||
template<class T>
|
||||
class CowGearPrep : public MaliciousRingPrep<T>
|
||||
{
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user