Maintenance.

This commit is contained in:
Marcel Keller
2021-11-04 16:22:45 +11:00
parent 7b52ef9035
commit 32950fe8d4
185 changed files with 1818 additions and 654 deletions

View File

@@ -112,8 +112,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
garble_processor.reset(program);
this->processor.open_input_file(N.my_num(), 0);
T::bit_type::mac_key_type::init_field();
GC::ShareThread<typename T::bit_type> share_thread(N, online_opts, *P, 0, usage);
shared_proc = new SubProcessor<T>(dummy_proc, *MC, *prep, *P);
auto& inputter = shared_proc->input;

View File

@@ -243,6 +243,9 @@ public:
template <class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
{ return T::input(from, processor.get_input(n_bits), n_bits); }
template<class U>
static void reveal_inst(GC::Processor<U>& processor, const vector<int>& args)
{ processor.reveal(args); }
template<class T>
static void convcbit(Integer& dest, const GC::Clear& source, T&)

View File

@@ -1,5 +1,14 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.8 (Nov 4, 2021)
- Tested on Apple laptop with ARM chip
- Restore trusted client interface
- Directly accessible softmax function
- Signature in preprocessing files to reduce confusing errors
- Improved error messages for connection issues
- Documentation of low-level share types and protocol pairs
## 0.2.7 (Sep 17, 2021)
- Optimized matrix multiplication in Hemi

4
CONFIG
View File

@@ -88,7 +88,9 @@ CPPFLAGS = $(CFLAGS)
LD = $(CXX)
ifeq ($(OS), Darwin)
# for boost with OpenSSL 3
CFLAGS += -Wno-error=deprecated-declarations
ifeq ($(USE_NTL),1)
CFLAGS += -Wno-error=unused-parameter
CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy
endif
endif

View File

@@ -309,6 +309,8 @@ class cbits(bits):
res = type(self)()
inst.notcb(self.n, res, self)
return res
def __eq__(self, other):
raise CompilerError('equality not implemented')
def print_reg(self, desc=''):
inst.print_regb(self, desc)
def print_reg_plain(self):

View File

@@ -19,6 +19,8 @@ class BlockAllocator:
self.by_address = {}
def by_size(self, size):
if size >= 2 ** 32:
raise CompilerError('size exceeds addressing capability')
return self.by_logsize[int(math.log(size, 2))][size]
def push(self, address, size):

View File

@@ -290,6 +290,7 @@ def BitDecRing(a, k, m):
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
instructions_base.set_global_vector_size(a.size)
r_dprime = types.sint()
r_prime = types.sint()
c = types.cint()
@@ -298,6 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
pow2 = two_power(k + kappa)
asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
instructions_base.reset_global_vector_size()
return res
def BitDecField(a, k, m, kappa, bits_to_compute=None):

View File

@@ -1496,6 +1496,14 @@ class cond_print_plain(base.IOInstruction):
code = base.opcodes['CONDPRINTPLAIN']
arg_format = ['c', 'c', 'c']
def __init__(self, *args, **kwargs):
base.Instruction.__init__(self, *args, **kwargs)
self.size = args[1].size
args[2].set_size(self.size)
def get_code(self):
return base.Instruction.get_code(self, self.size)
class print_int(base.IOInstruction):
""" Output clear integer register.
@@ -1591,7 +1599,16 @@ class readsocketc(base.IOInstruction):
return True
class readsockets(base.IOInstruction):
"""Read a variable number of secret shares + MACs from socket for a client id and store in registers"""
""" Read a variable number of secret shares (potentially with MAC)
from a socket for a client id and store them in registers. If the
protocol uses MACs, the client should be different for every party.
:param: client id (regint)
:param: vector size (int)
:param: source (sint)
:param: (repeat source)...
"""
__slots__ = []
code = base.opcodes['READSOCKETS']
arg_format = tools.chain(['ci','int'], itertools.repeat('sw'))
@@ -1628,6 +1645,25 @@ class writesocketc(base.IOInstruction):
def has_var_args(self):
return True
class writesockets(base.IOInstruction):
""" Write a variable number of secret shares (potentially with MAC)
from registers into a socket for a specified client id. If the
protocol uses MACs, the client should be different for every party.
:param: client id (regint)
:param: message type (must be 0)
:param: vector size (int)
:param: source (sint)
:param: (repeat source)...
"""
__slots__ = []
code = base.opcodes['WRITESOCKETS']
arg_format = tools.chain(['ci', 'int', 'int'], itertools.repeat('s'))
def has_var_args(self):
return True
class writesocketshare(base.IOInstruction):
""" Write a variable number of shares (without MACs) from secret
registers into socket for a specified client id.
@@ -2384,5 +2420,76 @@ class lts(base.CISC):
subs(a, self.args[1], self.args[2])
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
# placeholder for documentation
class cisc:
""" Meta instruction for emulation. This instruction is only generated
when using ``-K`` with ``compile.py``. The header looks as follows:
:param: number of arguments after name plus one
:param: name (16 bytes, zero-padded)
Currently, the following names are supported:
LTZ
Less than zero.
:param: number of arguments in this unit (must be 6)
:param: vector size
:param: result (sint)
:param: input (sint)
:param: bit length
:param: (ignored)
:param: (repeat)...
Trunc
Truncation.
:param: number of arguments in this unit (must be 8)
:param: vector size
:param: result (sint)
:param: input (sint)
:param: bit length
:param: number of bits to truncate
:param: (ignored)
:param: 0 for unsigned or 1 for signed
:param: (repeat)...
FPDiv
Fixed-point division. Division by zero results in zero without error.
:param: number of arguments in this unit (must be at least 7)
:param: vector size
:param: result (sint)
:param: dividend (sint)
:param: divisor (sint)
:param: (ignored)
:param: fixed-point precision
:param: (repeat)...
exp2_fx
Fixed-point power of two.
:param: number of arguments in this unit (must be at least 6)
:param: vector size
:param: result (sint)
:param: exponent (sint)
:param: (ignored)
:param: fixed-point precision
:param: (repeat)...
log2_fx
Fixed-point logarithm with base 2.
:param: number of arguments in this unit (must be at least 6)
:param: vector size
:param: result (sint)
:param: input (sint)
:param: (ignored)
:param: fixed-point precision
:param: (repeat)...
"""
code = base.opcodes['CISC']
# hack for circular dependency
from Compiler import comparison

View File

@@ -481,7 +481,16 @@ def cisc(function):
def expand_merged(self, skip):
if function.__name__ in skip:
return [self], 0
good = True
for call in self.calls:
if not good:
break
for arg in call[0]:
if isinstance(arg, program.curr_tape.Register) and \
not issubclass(type(self.calls[0][0][0]), type(arg)):
good = False
if good:
return [self], 0
tape = program.curr_tape
block = tape.BasicBlock(tape, None, None)
tape.active_basicblock = block
@@ -520,6 +529,7 @@ def cisc(function):
String.check(name)
res += String.encode(name)
for call in self.calls:
call[1].pop('nearest', None)
assert not call[1]
res += int_to_bytes(len(call[0]) + 2)
res += int_to_bytes(call[0][0].size)

View File

@@ -155,6 +155,26 @@ def argmax(x):
return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1])
return tree_reduce(op, enumerate(x))[0]
def softmax(x):
""" Softmax.
:param x: vector or list of sfix
:returns: sfix vector
"""
return softmax_from_exp(exp_for_softmax(x)[0])
def exp_for_softmax(x):
m = util.max(x)
mv = m.expand_to_vector(len(x))
try:
x = x.get_vector()
except AttributeError:
x = sfix(x)
return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m
def softmax_from_exp(x):
return x / sum(x)
report_progress = False
def progress(x):
@@ -464,10 +484,7 @@ class MultiOutput(MultiOutputBase):
self.losses[i] = -sfix.dot_product(
self.Y[batch[i]].get_vector(), log_e(div))
else:
m = util.max(self.X[i])
mv = m.expand_to_vector(d_out)
x = self.X[i].get_vector()
e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0)
e, m = exp_for_softmax(self.X[i])
self.exp[i].assign_vector(e)
if self.compute_loss:
true_X = sfix.dot_product(self.Y[batch[i]], self.X[i])
@@ -532,11 +549,8 @@ class MultiOutput(MultiOutputBase):
return
@for_range_opt_multithread(self.n_threads, len(batch))
def _(i):
for j in range(d_out):
dividend = self.exp[i][j]
divisor = sum(self.exp[i])
div = (divisor > 0.1).if_else(dividend / divisor, 0)
self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div)
div = softmax_from_exp(self.exp[i])
self.nabla_X[i][:] = -self.Y[batch[i]][:] + div
self.maybe_debug_backward(batch)
def maybe_debug_backward(self, batch):
@@ -588,7 +602,7 @@ class DenseBase(Layer):
nablas = lambda self: (self.nabla_W, self.nabla_b)
def output_weights(self):
print_ln('%s', self.W.reveal_nested())
self.W.print_reveal_nested()
print_ln('%s', self.b.reveal_nested())
def backward_params(self, f_schur_Y, batch):
@@ -1316,7 +1330,7 @@ class ConvBase(BaseLayer):
self.bias.input_from(player, raw=raw)
def output_weights(self):
print_ln('%s', self.weights.reveal_nested())
self.weights.print_reveal_nested()
print_ln('%s', self.bias.reveal_nested())
def dot_product(self, iv, wv, out_y, out_x, out_c):
@@ -1942,7 +1956,12 @@ class Optimizer:
for label, X in enumerate(self.X_by_label):
indices = regint.Array(n * n_per_epoch)
indices_by_label.append(indices)
indices.assign(regint.inc(len(indices), 0, 1, 1, len(X)))
indices.assign(regint.inc(len(X)))
missing = len(indices) - len(X)
if missing:
indices.assign_vector(
regint.get_random(int(math.log2(len(X))), size=missing),
base=len(X))
if self.always_shuffle or n_per_epoch > 1:
indices.shuffle()
loss_sum = MemValue(sfix(0))
@@ -2050,6 +2069,8 @@ class Optimizer:
self.time_layers = 'time_layers' in program.args
self.revealing_correctness = not 'no_acc' in program.args
self.layers[-1].compute_loss = not 'no_loss' in program.args
if 'full_cisc' in program.args:
program.options.keep_cisc = 'FPDiv,exp2_fx,log2_fx'
model_input = 'model_input' in program.args
acc_first = model_input and not 'train_first' in program.args
if model_input:
@@ -2058,12 +2079,14 @@ class Optimizer:
else:
self.reset()
if 'one_iter' in program.args:
print_float_prec(16)
self.output_weights()
print_ln('loss')
print_ln('%s', self.eval(
self.layers[0].X.get_part(0, batch_size)).reveal_nested())
self.eval(
self.layers[0].X.get_part(0, batch_size),
batch_size=batch_size).print_reveal_nested()
for layer in self.layers:
print_ln('%s', layer.X.get_part(0, batch_size).reveal_nested())
layer.X.get_part(0, batch_size).print_reveal_nested()
print_ln('%s', self.layers[-1].Y.get_part(0, batch_size).reveal_nested())
batch = Array.create_from(regint.inc(batch_size))
self.forward(batch=batch, training=True)
@@ -2083,9 +2106,10 @@ class Optimizer:
return
N = self.layers[0].X.sizes[0]
n_trained = (N + batch_size - 1) // batch_size * batch_size
print_ln('train_acc: %s (%s/%s)',
cfix(self.n_correct, k=63, f=31) / n_trained,
self.n_correct, n_trained)
if not acc_first:
print_ln('train_acc: %s (%s/%s)',
cfix(self.n_correct, k=63, f=31) / n_trained,
self.n_correct, n_trained)
if test_X and test_Y:
n_test = len(test_Y)
n_correct, loss = self.reveal_correctness(test_X, test_Y,
@@ -2500,16 +2524,30 @@ class keras:
batch_size = min(batch_size, self.batch_size)
return self.opt.eval(x, batch_size=batch_size)
def solve_linear(A, b, n_iterations, progress=False):
""" Iterative linear solution approximation. """
def solve_linear(A, b, n_iterations, progress=False, n_threads=None,
stop=False, already_symmetric=False, precond=False):
""" Iterative linear solution approximation for :math:`Ax=b`.
:param progress: print some information on the progress (implies revealing)
:param n_threads: number of threads to use
:param stop: whether to stop when converged (implies revealing)
"""
assert len(b) == A.sizes[0]
x = sfix.Array(A.sizes[1])
x.assign_vector(sfix.get_random(-1, 1, size=len(x)))
AtA = sfix.Matrix(len(x), len(x))
AtA[:] = A.direct_trans_mul(A)
if already_symmetric:
AtA = A
r = Array.create_from(b - AtA * x)
else:
AtA = sfix.Matrix(len(x), len(x))
A.trans_mul_to(A, AtA, n_threads=n_threads)
r = Array.create_from(A.transpose() * b - AtA * x)
if precond:
return solve_linear_diag_precond(AtA, b, x, r, n_iterations,
progress, stop)
v = sfix.Array(A.sizes[1])
v.assign_all(0)
r = Array.create_from(A.transpose() * b - AtA * x)
Av = sfix.Array(len(x))
@for_range(n_iterations)
def _(i):
@@ -2523,10 +2561,43 @@ def solve_linear(A, b, n_iterations, progress=False):
if progress:
print_ln('%s alpha=%s vr=%s v_norm=%s', i, alpha.reveal(),
vr.reveal(), v_norm.reveal())
if stop:
return (alpha > 0).reveal()
return x
def mr(A, n_iterations):
""" Iterative matrix inverse approximation. """
def solve_linear_diag_precond(A, b, x, r, n_iterations, progress=False,
stop=False):
m = 1 / A.diag()
mr = Array.create_from(m * r[:])
d = Array.create_from(mr)
@for_range(n_iterations)
def _(i):
Ad = A * d
d_norm = sfix.dot_product(d, Ad)
alpha = (d_norm == 0).if_else(0, sfix.dot_product(r, mr) / d_norm)
x[:] = x[:] + alpha * d[:]
r_norm = sfix.dot_product(r, mr)
r[:] = r[:] - alpha * Ad
tmp = m * r[:]
beta = (r_norm == 0).if_else(0, sfix.dot_product(r, tmp) / r_norm)
mr[:] = tmp
d[:] = tmp + beta * d
if progress:
print_ln('%s alpha=%s beta=%s r_norm=%s d_norm=%s', i,
alpha.reveal(), beta.reveal(), r_norm.reveal(),
d_norm.reveal())
if stop:
return (alpha > 0).reveal()
return x
def mr(A, n_iterations, stop=False):
""" Iterative matrix inverse approximation.
:param A: matrix to invert
:param n_iterations: maximum number of iterations
:param stop: whether to stop when converged (implies revealing)
"""
assert len(A.sizes) == 2
assert A.sizes[0] == A.sizes[1]
M = A.same_shape()
@@ -2536,5 +2607,18 @@ def mr(A, n_iterations):
e = sfix.Array(n)
e.assign_all(0)
e[i] = 1
M[i] = solve_linear(A, e, n_iterations)
M[i] = solve_linear(A, e, n_iterations, stop=stop)
return M.transpose()
def var(x):
""" Variance. """
mean = MemValue(type(x[0])(0))
@for_range_opt(len(x))
def _(i):
mean.iadd(x[i])
mean /= len(x)
res = MemValue(type(x[0])(0))
@for_range_opt(len(x))
def _(i):
res.iadd((x[i] - mean.read()) ** 2)
return res.read()

View File

@@ -690,8 +690,9 @@ class Tape:
def expand_cisc(self):
new_instructions = []
if self.parent.program.options.keep_cisc:
if self.parent.program.options.keep_cisc != None:
skip = ['LTZ', 'Trunc']
skip += self.parent.program.options.keep_cisc.split(',')
else:
skip = []
for inst in self.instructions:

View File

@@ -1161,7 +1161,7 @@ class cint(_clear, _int):
cond_print_str(self, string)
def output_if(self, cond):
cond_print_plain(self.conv(cond), self, cint(0))
cond_print_plain(self.conv(cond), self, cint(0, size=self.size))
class cgf2n(_clear, _gf2n):
@@ -1922,14 +1922,7 @@ class _secret(_register, _secret_structure):
r = self.get_dabit()
movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit()))
elif isinstance(val, sbitvec):
assert(sum(x.n for x in val.v) == self.size)
for val_part, base in zip(val, range(0, self.size, 64)):
left = min(64, self.size - base)
r = self.get_dabit(size=left)
v = regint(size=left)
bitdecint_class(regint((r[1] ^ val_part).reveal()), *v)
part = r[0].bit_xor(v)
vmovs(left, self.get_vector(base, left), part)
movs(self, sint.bit_compose(val))
else:
self.load_clear(self.clear_type(val))
@@ -1939,6 +1932,8 @@ class _secret(_register, _secret_structure):
:param bits: iterable of any type convertible to sint """
from Compiler.GC.types import sbits, sbitintvec
if isinstance(bits, sbits):
bits = bits.bit_decompose()
bits = list(bits)
if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits):
if program.use_edabit():
@@ -2112,6 +2107,11 @@ class sint(_secret, _int):
thereof or sbits/sbitvec/sfix)
:param size: vector size (int), defaults to 1 or size of list
When converting :py:class:`~Compiler.GC.types.sbits`, the result is a
vector of bits, and when converting
:py:class:`~Compiler.GC.types.sbitvec`, the result is a vector of values
with bit length equal the length of the input.
"""
__slots__ = []
instruction_type = 'modp'
@@ -2278,6 +2278,17 @@ class sint(_secret, _int):
else:
return res
@vectorized_classmethod
def write_to_socket(cls, client_id, values,
message_type=ClientMessageType.NoType):
""" Send a list of shares and MAC shares to a client socket.
:param client_id: regint
:param values: list of sint
"""
writesockets(client_id, message_type, values[0].size, *values)
@vectorize
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
""" Send only share to socket """
@@ -2339,6 +2350,7 @@ class sint(_secret, _int):
@vectorize_init
def __init__(self, val=None, size=None):
from .GC.types import sbitvec
if isinstance(val, personal):
size = val._v.size
super(sint, self).__init__('s', size=size)
@@ -2346,6 +2358,8 @@ class sint(_secret, _int):
elif isinstance(val, _fix):
super(sint, self).__init__('s', size=val.v.size)
self.load_other(val.v.round(val.k, val.f))
elif isinstance(val, sbitvec):
super(sint, self).__init__('s', val=val, size=val[0].n)
else:
super(sint, self).__init__('s', val=val, size=size)
@@ -3747,13 +3761,14 @@ class cfix(_number, _structure):
other = parse_type(other, self.k, self.f)
return other / self
@vectorize
def print_plain(self):
""" Clear fixed-point output. """
print_float_plain(cint.conv(self.v), cint(-self.f), \
cint(0), cint(0), cint(0))
def output_if(self, cond):
cond_print_plain(cint.conv(cond), self.v, cint(-self.f))
cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size))
@vectorize
def binary_output(self, player=None):
@@ -4028,7 +4043,8 @@ class _fix(_single):
print('Nearest rounding instead of proabilistic '
'for fixed-point computation')
cls.round_nearest = True
if adapt_ring and program.options.ring:
if adapt_ring and program.options.ring \
and 'fix_ring' not in program.args:
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
if need != int(program.options.ring):
print('Changing computation modulus to 2^%d' % need)
@@ -4094,6 +4110,8 @@ class _fix(_single):
elif isinstance(_v, (MemValue, MemFix)):
#this is a memvalue object
self.v = type(self)(_v.read()).v
elif isinstance(_v, (list, tuple)):
self.v = self.int_type(list(self.conv(x).v for x in _v))
else:
raise CompilerError('cannot convert %s to sfix' % _v)
if not isinstance(self.v, self.int_type):
@@ -4250,7 +4268,7 @@ class sfix(_fix):
return cls._new(cls.int_type.get_raw_input_from(player))
@vectorized_classmethod
def get_random(cls, lower, upper):
def get_random(cls, lower, upper, symmetric=True):
""" Uniform secret random number around centre of bounds.
Actual range can be smaller but never larger.
@@ -4261,8 +4279,19 @@ class sfix(_fix):
log_range = int(math.log(upper - lower, 2))
n_bits = log_range + cls.f
average = lower + 0.5 * (upper - lower)
lower = average - 0.5 * 2 ** log_range
return cls._new(cls.int_type.get_random_int(n_bits)) + lower
real_range = (2 ** (n_bits) - 1) / 2 ** cls.f
lower = average - 0.5 * real_range
real_lower = round(lower * 2 ** cls.f) / 2 ** cls.f
r = cls._new(cls.int_type.get_random_int(n_bits)) + lower
if symmetric:
lowest = math.floor(lower * 2 ** cls.f) / 2 ** cls.f
print('randomness range [%f,%f], fringes half the probability' % \
(lowest, lowest + 2 ** log_range))
return cls.int_type.get_random_bit().if_else(r, -r + 2 * average)
else:
print('randomness range [%f,%f], %d bits' % \
(real_lower, real_lower + real_range, n_bits))
return r
@classmethod
def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None):
@@ -4985,6 +5014,7 @@ class cfloat(object):
""" Helper class for printing revealed sfloats. """
__slots__ = ['v', 'p', 'z', 's', 'nan']
@vectorize_init
def __init__(self, v, p=None, z=None, s=None, nan=0):
""" Parameters as with :py:class:`sfloat` but public. """
if s is None:
@@ -4993,6 +5023,11 @@ class cfloat(object):
parts = [cint.conv(x) for x in (v, p, z, s, nan)]
self.v, self.p, self.z, self.s, self.nan = parts
@property
def size(self):
return self.v.size
@vectorize
def print_float_plain(self):
""" Output. """
print_float_plain(self.v, self.p, self.z, self.s, self.nan)
@@ -5389,15 +5424,7 @@ class Array(_vectorizable):
def shuffle(self):
""" Insecure shuffle in place. """
if self.value_type == regint:
self.assign(self.get_vector().shuffle())
else:
@library.for_range(len(self))
def _(i):
j = regint.get_random(64) % (len(self) - i)
tmp = self[i]
self[i] = self[i + j]
self[i + j] = tmp
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
def reveal(self):
""" Reveal the whole array.
@@ -5411,6 +5438,13 @@ class Array(_vectorizable):
reveal_nested = reveal_list
def print_reveal_nested(self, end='\n'):
""" Reveal and print as list.
:param end: string to print after (default: line break)
"""
library.print_str('%s' + end, self.get_vector().reveal())
def reveal_to_binary_output(self, player=None):
""" Reveal to binary output if supported by type.
@@ -5449,6 +5483,8 @@ sgf2n.dynamic_array = Array
class SubMultiArray(_vectorizable):
""" Multidimensional array functionality. Don't construct this
directly, use :py:class:`MultiArray` instead. """
check_indices = True
def __init__(self, sizes, value_type, address, index, debug=None):
self.sizes = tuple(sizes)
self.value_type = _get_type(value_type)
@@ -5458,7 +5494,6 @@ class SubMultiArray(_vectorizable):
self.address = None
self.sub_cache = {}
self.debug = debug
self.check_indices = True
if debug:
library.print_ln_if(self.address + reduce(operator.mul, self.sizes) * self.value_type.n_elements() > program.allocated_mem[self.value_type.reg_type], 'AOF%d:' % len(self.sizes) + self.debug)
@@ -5467,8 +5502,6 @@ class SubMultiArray(_vectorizable):
:param index: public (regint/cint/int)
:return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise"""
if util.is_constant(index) and index >= self.sizes[0]:
raise StopIteration
if isinstance(index, slice) and index == slice(None):
return self.get_vector()
key = program.curr_block, str(index)
@@ -5490,7 +5523,9 @@ class SubMultiArray(_vectorizable):
self.sub_cache[key] = \
SubMultiArray(self.sizes[1:], self.value_type, \
self.address, index, debug=self.debug)
return self.sub_cache[key]
res = self.sub_cache[key]
res.check_indices = self.check_indices
return res
def __setitem__(self, index, other):
""" Part assignment.
@@ -5505,6 +5540,9 @@ class SubMultiArray(_vectorizable):
""" Size of top dimension. """
return self.sizes[0]
def __iter__(self):
return (self[i] for i in range(len(self)))
def assign_all(self, value):
""" Assign the same value to all entries.
@@ -5877,6 +5915,38 @@ class SubMultiArray(_vectorizable):
self.address, other.address, None, 1, other.sizes[1],
reduce=reduce, indices=indices)
def trans_mul_to(self, other, res, n_threads=None):
"""
Matrix multiplication with the transpose of :py:obj:`self`
in the virtual machine.
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param res: matrix of matching dimension to store result
:param n_threads: number of threads (default: single thread)
"""
@library.for_range_multithread(n_threads, 1, self.sizes[1])
def _(i):
indices = [regint(i), regint.inc(self.sizes[0])]
indices += [regint.inc(i) for i in other.sizes]
res[i] = self.direct_trans_mul(other, indices=indices)
def mul_trans_to(self, other, res, n_threads=None):
"""
Matrix multiplication with the transpose of :py:obj:`other`
in the virtual machine.
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param res: matrix of matching dimension to store result
:param n_threads: number of threads (default: single thread)
"""
@library.for_range_multithread(n_threads, 1, self.sizes[0])
def _(i):
indices = [regint(i), regint.inc(self.sizes[1])]
indices += [regint.inc(i) for i in reversed(other.sizes)]
res[i] = self.direct_mul_trans(other, indices=indices)
def direct_mul_to_matrix(self, other):
""" Matrix multiplication in the virtual machine.
@@ -5992,6 +6062,13 @@ class SubMultiArray(_vectorizable):
assert self.sizes[0] == self.sizes[1]
return sum(self[i][i] for i in range(self.sizes[0]))
def diag(self):
""" Matrix diagonal. """
assert len(self.sizes) == 2
assert self.sizes[0] == self.sizes[1]
n = self.sizes[0]
return self.array.get(regint.inc(n, 0, n + 1))
def reveal_list(self):
""" Reveal as list. """
return list(self.get_vector().reveal())
@@ -6007,6 +6084,21 @@ class SubMultiArray(_vectorizable):
return [f(sizes[1:]) for i in range(sizes[0])]
return f(self.sizes)
def print_reveal_nested(self, end='\n'):
""" Reveal and print as nested list.
:param end: string to print after (default: line break)
"""
if self.total_size() < program.options.budget:
library.print_str('%s' + end, self.reveal_nested())
else:
library.print_str('[')
@library.for_range(len(self) - 1)
def _(i):
self[i].print_reveal_nested(end=', ')
self[len(self) - 1].print_reveal_nested(end='')
library.print_str(']' + end)
def reveal_to_binary_output(self, player=None):
""" Reveal to binary output if supported by type.
@@ -6042,6 +6134,10 @@ class MultiArray(SubMultiArray):
a[2][:] = a[0][:] * a[1][:]
"""
@staticmethod
def disable_index_checks():
SubMultiArray.check_indices = False
def __init__(self, sizes, value_type, debug=None, address=None, alloc=True):
if isinstance(address, Array):
self.array = address

View File

@@ -3,6 +3,8 @@
*
*/
#define NO_MIXED_CIRCUITS
#include "Networking/Server.h"
#include "Networking/CryptoPlayer.h"
#include "Math/gfp.h"
@@ -49,8 +51,6 @@ int main(int argc, const char** argv)
ArithmeticProcessor _({}, 0);
BaseMachine machine;
machine.ot_setups.push_back({P, false});
GC::ShareThread<typename pShare::bit_type> thread(N,
OnlineOptions::singleton, P, {}, usage);
SubProcessor<pShare> proc(_, MCp, prep, P);
pShare sk, __;

View File

@@ -3,6 +3,8 @@
*
*/
#define NO_MIXED_CIRCUITS
#include "GC/TinierSecret.h"
#include "GC/TinyMC.h"
#include "GC/VectorInput.h"

View File

@@ -106,8 +106,6 @@ void run(int argc, const char** argv)
typename pShare::Direct_MC MCp(keyp);
ArithmeticProcessor _({}, 0);
typename pShare::TriplePrep sk_prep(0, usage);
GC::ShareThread<typename pShare::bit_type> thread(N,
OnlineOptions::singleton, P, {}, usage);
SubProcessor<pShare> sk_proc(_, MCp, sk_prep, P);
pShare sk, __;
// synchronize

View File

@@ -6,6 +6,7 @@
#include "FHEOffline/PairwiseMachine.h"
#include "Tools/benchmarking.h"
#include "Protocols/fake-stuff.h"
#include "Tools/Bundle.h"
#include "Protocols/fake-stuff.hpp"

View File

@@ -167,6 +167,8 @@ string open_prep_file(ofstream& outf, string data_type, int my_num, int thread_n
throw runtime_error("cannot create directory " + dir);
string file = prep_filename<T>(data_type, my_num, thread_num, initial, dir);
outf.open(file.c_str(),ios::out | ios::binary | (clear ? ios::trunc : ios::app));
if (clear)
file_signature<Share<T>>().output(outf);
if (outf.fail()) { throw file_error(file); }
return file;
}
@@ -516,6 +518,7 @@ InputProducer<FD>::InputProducer(const Player& P, int thread_num,
if (thread_num)
file << "-" << thread_num;
outf[j].open(file.str().c_str(), ios::out | ios::binary);
file_signature<Share<T>>().output(outf[j]);
if (outf[j].fail())
{
throw file_error(file.str());

View File

@@ -24,11 +24,4 @@ AtlasShare::AtlasShare(const AtlasSecret& other) :
{
}
void AtlasShare::random()
{
AtlasSecret tmp;
this->get_party().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
}
}

View File

@@ -63,8 +63,6 @@ public:
{
*this = input;
}
void random();
};
}

View File

@@ -20,17 +20,17 @@ class CcdPrep : public BufferPrep<T>
{
typename T::part_type::LivePrep part_prep;
SubProcessor<typename T::part_type>* part_proc;
ShareThread<T>& thread;
public:
CcdPrep(DataPositions& usage, ShareThread<T>& thread) :
BufferPrep<T>(usage), part_prep(usage, thread), part_proc(0),
thread(thread)
static const bool use_part = true;
CcdPrep(DataPositions& usage) :
BufferPrep<T>(usage), part_prep(usage), part_proc(0)
{
}
CcdPrep(SubProcessor<T>*, DataPositions& usage) :
CcdPrep(usage, ShareThread<T>::s())
CcdPrep(usage)
{
}

View File

@@ -23,6 +23,7 @@ CcdPrep<T>::~CcdPrep()
template<class T>
void CcdPrep<T>::set_protocol(typename T::Protocol& protocol)
{
auto& thread = ShareThread<T>::s();
assert(thread.MC);
part_proc = new SubProcessor<typename T::part_type>(
thread.MC->get_part_MC(), part_prep, protocol.get_part().P);

View File

@@ -74,13 +74,6 @@ public:
*this = input;
}
void random()
{
CcdSecret<T> tmp;
ShareThread<CcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
}
This& operator^=(const This& other)
{
*this += other;

View File

@@ -136,6 +136,8 @@ public:
void andrs(int n, const FakeSecret& x, const FakeSecret& y)
{ *this = BitVec(x.a * (y.a & 1)).mask(n); }
void xor_bit(int i, FakeSecret bit) { *this ^= bit << i; }
void invert(int n, const FakeSecret& x) { *this = BitVec(~x.a).mask(n); }
void random_bit() { a = random() % 2; }

View File

@@ -79,13 +79,6 @@ public:
*this = input;
}
void random()
{
MaliciousCcdSecret<T> tmp;
ShareThread<MaliciousCcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
}
This& operator^=(const This& other)
{
*this += other;

View File

@@ -7,12 +7,12 @@
#define GC_NOSHARE_H_
#include "Processor/DummyProtocol.h"
#include "BMR/Register.h"
#include "Tools/SwitchableOutput.h"
#include "Protocols/ShareInterface.h"
class InputArgs;
class ArithmeticProcessor;
class BlackHole;
class SwitchableOutput;
namespace GC
{
@@ -110,7 +110,7 @@ public:
typedef NoShare small_type;
typedef BlackHole out_type;
typedef SwitchableOutput out_type;
static const bool is_real = false;
@@ -124,6 +124,11 @@ public:
return "no";
}
static void specification(octetStream&)
{
fail();
}
static int size()
{
return 0;
@@ -172,6 +177,8 @@ public:
NoShare get_bit(int) const { fail(); return {}; }
void xor_bit(int, NoShare) const { fail(); }
void invert(int, NoShare) { fail(); }
NoShare mask(int) const { fail(); return {}; }

View File

@@ -27,7 +27,7 @@ public:
static int check_args(const vector<int>& args, int n);
template<class U>
static void check_input(const U& in, int n_bits);
static void check_input(const U& in, const int* params);
Machine<T>* machine;
Memories<T>& memories;

View File

@@ -89,15 +89,15 @@ U GC::Processor<T>::get_long_input(const int* params,
else
res = input_proc.get_input<FixInput_<U>>(interactive,
&params[1]).items[0];
int n_bits = *params;
check_input(res, n_bits);
check_input(res, params);
return res;
}
template<class T>
template<class U>
void GC::Processor<T>::check_input(const U& in, int n_bits)
void GC::Processor<T>::check_input(const U& in, const int* params)
{
int n_bits = *params;
auto test = in >> (n_bits - 1);
if (n_bits == 1)
{
@@ -106,9 +106,17 @@ void GC::Processor<T>::check_input(const U& in, int n_bits)
}
else if (not (test == 0 or test == -1))
{
throw runtime_error(
"input too large for a " + std::to_string(n_bits)
+ "-bit signed integer: " + to_string(in));
if (params[1] == 0)
throw runtime_error(
"input out of range for a " + std::to_string(n_bits)
+ "-bit signed integer: " + to_string(in));
else
throw runtime_error(
"input out of range for a " + to_string(n_bits)
+ "-bit fixed-point number with "
+ to_string(params[1]) + "-bit precision: "
+ to_string(
mpf_class(bigint(in)) * exp2(-params[1])));
}
}

View File

@@ -22,7 +22,6 @@ class RepPrep : public PersonalPrep<T>, ShiftableTripleBuffer<T>
ReplicatedBase* protocol;
public:
RepPrep(DataPositions& usage, ShareThread<T>& thread);
RepPrep(DataPositions& usage, int input_player = PersonalPrep<T>::SECURE);
~RepPrep();

View File

@@ -16,13 +16,6 @@
namespace GC
{
template<class T>
RepPrep<T>::RepPrep(DataPositions& usage, ShareThread<T>& thread) :
RepPrep<T>(usage)
{
(void) thread;
}
template<class T>
RepPrep<T>::RepPrep(DataPositions& usage, int input_player) :
PersonalPrep<T>(usage, input_player), protocol(0)

View File

@@ -31,6 +31,11 @@ public:
{
tainted = true;
}
bool is_tainted()
{
return tainted;
}
};
} /* namespace GC */

View File

@@ -84,7 +84,6 @@ public:
template <class U>
static void store_clear_in_dynamic(U& mem, const vector<ClearWriteAccess>& accesses)
{ T::store_clear_in_dynamic(mem, accesses); }
static void output(T& reg);
template<class U, class V>
static void load(vector< ReadAccess<V> >& accesses, const U& mem);
@@ -113,7 +112,7 @@ public:
{ T::inputbvec(processor, input_proc, args); }
template<class U>
static void reveal_inst(Processor<U>& processor, const vector<int>& args)
{ processor.reveal(args); }
{ T::reveal_inst(processor, args); }
template<class U>
static void trans(Processor<U>& processor, int n_inputs, const vector<int>& args);
@@ -148,7 +147,6 @@ public:
}
void invert(int n, const Secret<T>& x);
void and_(int n, const Secret<T>& x, const Secret<T>& y, bool repeat);
void andrs(int n, const Secret<T>& x, const Secret<T>& y) { and_(n, x, y, true); }
template <class U>
void reveal(size_t n_bits, U& x);

View File

@@ -119,12 +119,6 @@ void Secret<T>::store(U& mem,
T::store(mem, accesses);
}
template <class T>
void Secret<T>::output(T& reg)
{
reg.output();
}
template <class T>
Secret<T>::Secret()
{

View File

@@ -15,10 +15,6 @@ namespace GC
class SemiHonestRepPrep : public RepPrep<SemiHonestRepSecret>
{
public:
SemiHonestRepPrep(DataPositions& usage, ShareThread<SemiHonestRepSecret>&) :
RepPrep<SemiHonestRepSecret>(usage)
{
}
SemiHonestRepPrep(DataPositions& usage, bool = false) :
RepPrep<SemiHonestRepSecret>(usage)
{

View File

@@ -16,11 +16,6 @@
namespace GC
{
SemiPrep::SemiPrep(DataPositions& usage, ShareThread<SemiSecret>&) :
SemiPrep(usage)
{
}
SemiPrep::SemiPrep(DataPositions& usage, bool) :
BufferPrep<SemiSecret>(usage), triple_generator(0)
{

View File

@@ -26,7 +26,6 @@ class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret
SeededPRNG secure_prng;
public:
SemiPrep(DataPositions& usage, ShareThread<SemiSecret>& thread);
SemiPrep(DataPositions& usage, bool = true);
~SemiPrep();

View File

@@ -73,6 +73,9 @@ public:
void xor_(int n, const SemiSecret& x, const SemiSecret& y)
{ *this = BitVec(x ^ y).mask(n); }
void xor_bit(int i, const SemiSecret& bit)
{ *this ^= bit << i; }
void reveal(size_t n_bits, Clear& x);
SemiSecret lsb()

View File

@@ -161,6 +161,9 @@ public:
This get_bit(int i)
{ return (*this >> i) & 1; }
void xor_bit(int i, const This& bit)
{ *this ^= bit << i; }
};
template<class U>

View File

@@ -32,9 +32,9 @@ public:
Preprocessing<T>& DataF;
ShareThread(const Names& N, OnlineOptions& opts, DataPositions& usage);
ShareThread(const Names& N, OnlineOptions& opts, Player& P,
typename T::mac_key_type mac_key, DataPositions& usage);
ShareThread(Preprocessing<T>& prep);
ShareThread(Preprocessing<T>& prep, Player& P,
typename T::mac_key_type mac_key);
virtual ~ShareThread();
virtual typename T::MC* new_mc(typename T::mac_key_type mac_key)
@@ -54,6 +54,7 @@ public:
DataPositions usage;
StandaloneShareThread(int i, ThreadMaster<T>& master);
~StandaloneShareThread();
void pre_run();
void post_run() { ShareThread<T>::post_run(); }

View File

@@ -18,26 +18,28 @@ namespace GC
template<class T>
StandaloneShareThread<T>::StandaloneShareThread(int i, ThreadMaster<T>& master) :
ShareThread<T>(master.N, master.opts, usage), Thread<T>(i, master)
ShareThread<T>(*Preprocessing<T>::get_new(master.opts.live_prep,
master.N, usage)),
Thread<T>(i, master)
{
}
template<class T>
ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, DataPositions& usage) :
P(0), MC(0), protocol(0), DataF(
opts.live_prep ?
*static_cast<Preprocessing<T>*>(new typename T::LivePrep(
usage, *this)) :
*static_cast<Preprocessing<T>*>(new BitPrepFiles<T>(N,
get_prep_sub_dir<T>(PREP_DIR, N.num_players()),
usage, BaseMachine::thread_num)))
StandaloneShareThread<T>::~StandaloneShareThread()
{
delete &this->DataF;
}
template<class T>
ShareThread<T>::ShareThread(Preprocessing<T>& prep) :
P(0), MC(0), protocol(0), DataF(prep)
{
}
template<class T>
ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, Player& P,
typename T::mac_key_type mac_key, DataPositions& usage) :
ShareThread(N, opts, usage)
ShareThread<T>::ShareThread(Preprocessing<T>& prep, Player& P,
typename T::mac_key_type mac_key) :
ShareThread(prep)
{
pre_run(P, mac_key);
}
@@ -45,7 +47,6 @@ ShareThread<T>::ShareThread(const Names& N, OnlineOptions& opts, Player& P,
template<class T>
ShareThread<T>::~ShareThread()
{
delete &DataF;
if (MC)
delete MC;
if (protocol)
@@ -76,12 +77,6 @@ void ShareThread<T>::post_run()
{
protocol->check();
MC->Check(*this->P);
#ifndef INSECURE
#ifdef VERBOSE
cerr << "Removing used pre-processed data" << endl;
#endif
DataF.prune();
#endif
}
template<class T>

View File

@@ -58,10 +58,8 @@ void Thread<T>::run()
P = new PlainPlayer(N, id);
processor.open_input_file(N.my_num(), thread_num,
master.opts.cmd_private_input_file);
processor.out.activate(N.my_num() == 0 or master.opts.interactive);
processor.setup_redirection(P->my_num(), thread_num, master.opts);
if (processor.stdout_redirect_file.is_open())
processor.out.redirect_to_file(processor.stdout_redirect_file);
processor.setup_redirection(P->my_num(), thread_num, master.opts,
processor.out);
done.push(0);
pre_run();

View File

@@ -58,6 +58,16 @@ Thread<T>* ThreadMaster<T>::new_thread(int i)
template<class T>
void ThreadMaster<T>::run()
{
#ifndef INSECURE
if (not opts.live_prep)
{
cerr
<< "Preprocessing from file not supported by binary virtual machines"
<< endl;
exit(1);
}
#endif
P = new PlainPlayer(N, "main");
machine.load_schedule(progname);

View File

@@ -106,11 +106,6 @@ public:
party.MC->get_alphai());
}
void random()
{
*this = get_party().DataF.get_part().get_bit();
}
This lsb() const
{
return *this;

View File

@@ -25,7 +25,6 @@ class TinierSharePrep : public PersonalPrep<T>
MascotParams params;
typedef typename T::whole_type secret_type;
ShareThread<secret_type>& thread;
void buffer_triples();
void buffer_squares() { throw not_implemented(); }
@@ -39,8 +38,6 @@ class TinierSharePrep : public PersonalPrep<T>
void init_real(Player& P);
public:
TinierSharePrep(DataPositions& usage, ShareThread<secret_type>& thread,
int input_player = PersonalPrep<T>::SECURE);
TinierSharePrep(DataPositions& usage, int input_player =
PersonalPrep<T>::SECURE);
TinierSharePrep(SubProcessor<T>*, DataPositions& usage);

View File

@@ -15,16 +15,8 @@ namespace GC
template<class T>
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage, int input_player) :
TinierSharePrep<T>(usage, ShareThread<secret_type>::s(), input_player)
{
}
template<class T>
TinierSharePrep<T>::TinierSharePrep(DataPositions& usage,
ShareThread<secret_type>& thread, int input_player) :
PersonalPrep<T>(usage, input_player), triple_generator(0),
real_triple_generator(0),
thread(thread)
real_triple_generator(0)
{
}
@@ -87,6 +79,7 @@ void TinierSharePrep<T>::buffer_inputs(int player)
template<class T>
void GC::TinierSharePrep<T>::buffer_bits()
{
auto& thread = ShareThread<secret_type>::s();
this->bits.push_back(
BufferPrep<T>::get_random_from_inputs(thread.P->num_players()));
}

View File

@@ -14,6 +14,7 @@ template<class T>
void TinierSharePrep<T>::init_real(Player& P)
{
assert(real_triple_generator == 0);
auto& thread = ShareThread<secret_type>::s();
real_triple_generator = new typename T::whole_type::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), P.N, -1,
OnlineOptions::singleton.batch_size, 1, params,
@@ -24,6 +25,7 @@ void TinierSharePrep<T>::init_real(Player& P)
template<class T>
void TinierSharePrep<T>::buffer_secret_triples()
{
auto& thread = ShareThread<secret_type>::s();
auto& triple_generator = real_triple_generator;
assert(triple_generator != 0);
params.generateBits = false;

View File

@@ -58,6 +58,11 @@ public:
return part_type::size() * default_length;
}
static void specification(octetStream& os)
{
T::specification(os);
}
static void read_or_generate_mac_key(string directory, const Player& P,
mac_key_type& key)
{
@@ -150,6 +155,17 @@ public:
return this->get_reg(i);
}
void xor_bit(size_t i, const T& bit)
{
if (i < this->get_regs().size())
XOR(this->get_reg(i), this->get_reg(i), bit);
else
{
this->resize_regs(i + 1);
this->get_reg(i) = bit;
}
}
void output(ostream& s, bool human = true) const
{
assert(this->get_regs().size() == default_length);
@@ -179,6 +195,12 @@ public:
{
inputter.finalize(from, n_bits).mask(*this, n_bits);
}
void random_bit()
{
auto& thread = GC::ShareThread<typename T::whole_type>::s();
*this = thread.DataF.get_part().get_bit();
}
};
template<int S>

View File

@@ -74,13 +74,6 @@ public:
*this = super::constant(input, party.P->my_num(),
party.MC->get_alphai());
}
void random()
{
TinySecret<S> tmp;
this->get_party().DataF.get_one(DATA_BIT, tmp);
*this = tmp.get_reg(0);
}
};
} /* namespace GC */

View File

@@ -158,8 +158,8 @@ GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i,
{
if (output and i == 0)
{
T::clear::template generate_setup<T>(PREP_DIR, nplayers, 128);
prep_data_dir = get_prep_sub_dir<T>(PREP_DIR, nplayers);
T::clear::write_setup(prep_data_dir);
write_mac_key(prep_data_dir, my_num, nplayers, mac_key);
}

View File

@@ -59,6 +59,9 @@ int main(int argc, const char** argv)
online_opts.live_prep, online_opts).run(); \
break;
X(64) X(128) X(256) X(192) X(384) X(512)
#ifdef RING_SIZE
X(RING_SIZE)
#endif
#undef X
default:
cerr << "Not compiled for " << R << "-bit rings" << endl;

View File

@@ -3,6 +3,8 @@
*
*/
#define NO_MIXED_CIRCUITS
#include "BMR/RealProgramParty.hpp"
#include "Machines/SPDZ.hpp"

View File

@@ -97,16 +97,15 @@ replicated: rep-field rep-ring rep-bin
spdz2k: spdz2k-party.x ot-offline.x Check-Offline-Z2k.x galois-degree.x Fake-Offline.x
mascot: mascot-party.x spdz2k mama-party.x
tldr:
-echo ARCH = -march=native >> CONFIG.mine
$(MAKE) mascot-party.x
ifeq ($(OS), Darwin)
tldr: mac-setup
else
tldr: mpir
tldr: mpir linux-machine-setup
endif
tldr:
$(MAKE) mascot-party.x
ifeq ($(MACHINE), aarch64)
tldr: simde/simde
endif
@@ -144,8 +143,6 @@ static-release: static-dir $(patsubst Machines/%.cpp, static/%.x, $(wildcard Mac
Fake-ECDSA.x: ECDSA/Fake-ECDSA.cpp ECDSA/P256Element.o $(COMMON) Processor/PrepBase.o
$(CXX) -o $@ $^ $(CFLAGS) $(LDLIBS)
Check-Offline.x: $(PROCESSOR)
ot.x: $(OT) $(COMMON) Machines/OText_main.o Machines/OTMachine.o $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
@@ -285,11 +282,21 @@ mpir: mpir-setup
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
mac-setup:
mac-setup: mac-machine-setup
brew install openssl boost libsodium mpir yasm ntl
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include >> CONFIG.mine
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib >> CONFIG.mine
-echo USE_NTL = 1 >> CONFIG.mine
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I/opt/homebrew/opt/openssl/include -I/opt/homebrew/include >> CONFIG.mine
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L/opt/homebrew/lib -L/opt/homebrew/opt/openssl/lib >> CONFIG.mine
# -echo USE_NTL = 1 >> CONFIG.mine
ifeq ($(MACHINE), aarch64)
mac-machine-setup:
-echo ARCH = >> CONFIG.mine
linux-machine-setup:
-echo ARCH = -march=armv8.2-a+crypto >> CONFIG.mine
else
mac-machine-setup:
linux-machine-setup:
endif
simde/simde:
git submodule update --init simde

View File

@@ -49,6 +49,11 @@ public:
return string(1, T::type_char());
}
static void specification(octetStream& os)
{
T::specification(os);
}
template<class U, class V>
static FixedVec Mul(const FixedVec<U, L>& a, const V& b)
{

View File

@@ -35,6 +35,8 @@ public:
static int length() { return N_BITS; }
static string type_string() { return "integer"; }
static void specification(octetStream& os);
static void init_default(int lgp) { (void)lgp; }
static bool allows(Dtype type) { return type <= DATA_BIT; }
@@ -126,6 +128,8 @@ class Integer : public IntBase<long>
Integer(const bigint& x) { *this = (x > 0) ? x.get_ui() : -x.get_ui(); }
template<int K>
Integer(const Z2<K>& x) : Integer(x.get_limb(0)) {}
template<int K>
Integer(const SignedZ2<K>& x);
template<int X, int L>
Integer(const gfp_<X, L>& x);
Integer(int128 x) : Integer(x.get_lower()) {}

View File

@@ -8,6 +8,12 @@
template<class T>
const int IntBase<T>::N_BITS;
template<class T>
inline void IntBase<T>::specification(octetStream& os)
{
os.store(sizeof(T));
}
template<class T>
void IntBase<T>::output(ostream& s,bool human) const
{
@@ -42,6 +48,15 @@ void Integer::reqbl(int n)
}
}
template<int K>
Integer::Integer(const SignedZ2<K>& x)
{
if (K < N_BITS and x.negative())
a = -(-x).get_limb(0);
else
a = x.get_limb(0);
}
inline
Integer::Integer(const Integer& x, int n_bits)
{

View File

@@ -139,6 +139,14 @@ void gf2n_<U>::init_multiplication()
}
template<class U>
void gf2n_<U>::specification(octetStream& os)
{
os.store(sizeof(U));
os.store(degree());
}
/* Takes 8bit x and y and returns the 16 bit product in c1 and c0
ans = (c1<<8)^c0
where c1 and c0 are 8 bit

View File

@@ -76,6 +76,8 @@ protected:
static string type_short() { return "2"; }
static string type_string() { return "gf2n_"; }
static void specification(octetStream& os);
static int size() { return sizeof(a); }
static int size_in_bits() { return sizeof(a) * 8; }

View File

@@ -171,6 +171,29 @@ class gf2n_long : public gf2n_<int128>
}
};
#if defined(__aarch64__) && defined(__clang__)
inline __m128i my_slli(int128 x, int i)
{
if (i < 64)
return int128(x.get_upper() << i, x.get_lower() << i).a;
else
return int128().a;
}
inline __m128i my_srli(int128 x, int i)
{
if (i < 64)
return int128(x.get_upper() >> i, x.get_lower() >> i).a;
else
return int128().a;
}
#undef _mm_slli_epi64
#undef _mm_srli_epi64
#define _mm_slli_epi64 my_slli
#define _mm_srli_epi64 my_srli
#endif
inline int128 int128::operator<<(const int& other) const
{
int128 res(_mm_slli_epi64(a, other));

View File

@@ -15,7 +15,7 @@ Zp_Data gfpvar_<X, L>::ZpD;
template<int X, int L>
string gfpvar_<X, L>::type_string()
{
return "gfpvar";
return "gfp";
}
template<int X, int L>
@@ -30,6 +30,12 @@ char gfpvar_<X, L>::type_char()
return 'p';
}
template<int X, int L>
void gfpvar_<X, L>::specification(octetStream& os)
{
os.store(pr());
}
template<int X, int L>
int gfpvar_<X, L>::length()
{

View File

@@ -44,6 +44,8 @@ public:
static string type_short();
static char type_char();
static void specification(octetStream& os);
static int length();
static int size();
static int size_in_bits();

View File

@@ -14,21 +14,43 @@ void check_ssl_file(string filename)
"You can use `Scripts/setup-ssl.sh <nparties>`.");
}
void ssl_error(string side, string pronoun, string other, string server)
void ssl_error(string side, string other, string me)
{
cerr << side << "-side handshake with " << other
<< " failed. Make sure " << pronoun
<< " have the necessary certificate (" << PREP_DIR << server
<< ".pem in the default configuration),"
<< " failed. Make sure both sides "
<< " have the necessary certificate (" << PREP_DIR << me
<< ".pem in the default configuration on their side and "
<< PREP_DIR << other << ".pem on ours),"
<< " and run `c_rehash <directory>` on its location." << endl
<< "The certificates should be the same on every host. "
<< "Also make sure that it's still valid. Certificates generated "
<< "with `Scripts/setup-ssl.sh` expire after a month." << endl;
cerr << "See also "
"https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html"
"#handshake-failures" << endl;
string ids[2];
ids[side == "Client"] = other;
ids[side != "Client"] = me;
cerr << "Signature (should match the other side): ";
for (int i = 0; i < 2; i++)
{
auto filename = PREP_DIR + ids[i] + ".pem";
ifstream cert(filename);
stringstream buffer;
buffer << cert.rdbuf();
if (buffer.str().empty())
cerr << "<'" << filename << "' not found>";
else
cerr << octetStream(buffer.str()).hash();
if (i == 0)
cerr << "/";
}
cerr << endl;
}
CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
MultiPlayer<ssl_socket*>(Nms), plaintext_player(Nms, id_base),
other_player(Nms, id_base + "recv"),
MultiPlayer<ssl_socket*>(Nms),
ctx("P" + to_string(my_num()))
{
sockets.resize(num_players());
@@ -36,6 +58,16 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
senders.resize(num_players());
receivers.resize(num_players());
vector<int> plaintext_sockets[2];
for (int i = 0; i < 2; i++)
{
PlainPlayer player(Nms, id_base + (i ? "recv" : ""));
plaintext_sockets[i] = player.sockets;
close_client_socket(player.socket(my_num()));
player.sockets.clear();
}
for (int i = 0; i < (int)sockets.size(); i++)
{
if (i == my_num())
@@ -47,9 +79,9 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
continue;
}
sockets[i] = new ssl_socket(io_service, ctx, plaintext_player.socket(i),
sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[0][i],
"P" + to_string(i), "P" + to_string(my_num()), i < my_num());
other_sockets[i] = new ssl_socket(io_service, ctx, other_player.socket(i),
other_sockets[i] = new ssl_socket(io_service, ctx, plaintext_sockets[1][i],
"P" + to_string(i), "P" + to_string(my_num()), i < my_num());
senders[i] = new Sender<ssl_socket*>(i < my_num() ? sockets[i] : other_sockets[i]);
@@ -64,10 +96,6 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
CryptoPlayer::~CryptoPlayer()
{
close_client_socket(plaintext_player.socket(my_num()));
close_client_socket(other_player.socket(my_num()));
plaintext_player.sockets.clear();
other_player.sockets.clear();
for (int i = 0; i < num_players(); i++)
{
delete sockets[i];

View File

@@ -20,7 +20,6 @@
*/
class CryptoPlayer : public MultiPlayer<ssl_socket*>
{
PlainPlayer plaintext_player, other_player;
ssl_ctx ctx;
boost::asio::io_service io_service;

View File

@@ -373,6 +373,7 @@ protected:
T send_to_self_socket;
T socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; }
T socket(int i) const { return sockets[i]; }
friend class CryptoPlayer;
@@ -381,8 +382,6 @@ public:
virtual ~MultiPlayer();
T socket(int i) const { return sockets[i]; }
// Send/Receive data to/from player i
void send_long(int i, long a) const;
long receive_long(int i) const;

View File

@@ -109,7 +109,9 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
throw runtime_error(
string() + "cannot connect from " + my_name + " to " + hostname + ":"
+ to_string(Portnum) + " after " + to_string(attempts)
+ " attempts in one minute because " + strerror(connect_errno));
+ " attempts in one minute because " + strerror(connect_errno) + ". "
"https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#"
"connection-failures has more information on port requirements.");
}
freeaddrinfo(ai);

View File

@@ -16,7 +16,7 @@
typedef boost::asio::io_service ssl_service;
void check_ssl_file(string filename);
void ssl_error(string side, string pronoun, string other, string server);
void ssl_error(string side, string other, string server);
class ssl_ctx : public boost::asio::ssl::context
{
@@ -55,7 +55,7 @@ public:
handshake(ssl_socket::client);
} catch (...)
{
ssl_error("Client", "we", other, other);
ssl_error("Client", other, me);
throw;
}
else
@@ -65,7 +65,7 @@ public:
handshake(ssl_socket::server);
} catch (...)
{
ssl_error("Server", "they", other, me);
ssl_error("Server", other, me);
throw;
}

View File

@@ -187,7 +187,13 @@ void NPartyTripleGenerator<T>::generate()
if (thread_num != 0)
ss << "-" << thread_num;
if (machine.output)
{
outputFile.open(ss.str().c_str());
if (machine.generateMACs or not T::clear::invertible)
file_signature<T>().output(outputFile);
else
file_signature<typename T::clear>().output(outputFile);
}
if (machine.generateBits)
generateBits();

View File

@@ -89,6 +89,13 @@ DataPositions DataPositions::operator-(const DataPositions& other) const
return res;
}
DataPositions DataPositions::operator+(const DataPositions& other) const
{
DataPositions res = *this;
res.increase(other);
return res;
}
void DataPositions::print_cost() const
{
ifstream file("cost");

View File

@@ -19,6 +19,11 @@ using namespace std;
template<class T> class dabit;
namespace GC
{
template<class T> class ShareThread;
}
class DataTag
{
int t[4];
@@ -74,6 +79,7 @@ public:
void increase(const DataPositions& delta);
DataPositions& operator-=(const DataPositions& delta);
DataPositions operator-(const DataPositions& delta) const;
DataPositions operator+(const DataPositions& delta) const;
void print_cost() const;
bool empty() const;
bool any_more(const DataPositions& other) const;
@@ -84,10 +90,15 @@ template<class sint, class sgf2n> class Data_Files;
template<class sint, class sgf2n> class Machine;
template<class T> class SubProcessor;
/**
* Abstract base class for preprocessing
*/
template<class T>
class Preprocessing : public PrepBase
{
protected:
static const bool use_part = false;
DataPositions& usage;
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
@@ -114,6 +125,8 @@ public:
template<class U, class V>
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
SubProcessor<T>* proc);
static Preprocessing<T>* get_new(bool live_prep, const Names& N,
DataPositions& usage);
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
DataPositions& usage);
@@ -144,11 +157,15 @@ public:
void get_input(T& a, typename T::open_type& x, int i);
void get(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
/// Get fresh random multiplication triple
virtual array<T, 3> get_triple(int n_bits);
virtual array<T, 3> get_triple_no_count(int n_bits);
/// Get fresh random bit
virtual T get_bit();
/// Get fresh random value in domain
virtual T get_random();
virtual void get_dabit(T&, typename T::bit_type&);
/// Store fresh daBit in ``a`` (arithmetic part) and ``b`` (binary part)
virtual void get_dabit(T& a, typename T::bit_type& b);
virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
virtual void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs)
@@ -156,6 +173,7 @@ public:
template<int>
void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
template<int>
/// Get fresh edaBit chunk
edabitvec<T> get_edabitvec(bool strict, int n_bits);
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
@@ -270,13 +288,14 @@ class Data_Files
Preprocessing<sint>& DataFp;
Preprocessing<sgf2n>& DataF2;
Preprocessing<typename sint::bit_type>& DataFb;
Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<sint>* procp = 0,
SubProcessor<sgf2n>* proc2 = 0);
Data_Files(const Names& N);
~Data_Files();
DataPositions tellg();
DataPositions tellg() { return usage; }
void seekg(DataPositions& pos);
void skip(const DataPositions& pos);
void prune();
@@ -289,7 +308,7 @@ class Data_Files
void reset_usage() { usage.reset(); skipped.reset(); }
NamedCommStats comm_stats() { return DataFp.comm_stats() + DataF2.comm_stats(); }
NamedCommStats comm_stats();
};
template<class T> inline
@@ -407,6 +426,13 @@ inline void Data_Files<sint, sgf2n>::purge()
{
DataFp.purge();
DataF2.purge();
DataFb.purge();
}
template<class sint, class sgf2n>
NamedCommStats Data_Files<sint, sgf2n>::comm_stats()
{
return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats();
}
#endif

View File

@@ -5,6 +5,7 @@
#include "Processor/Processor.h"
#include "Protocols/dabit.h"
#include "Math/Setup.h"
#include "GC/BitPrepFiles.h"
#include "Protocols/MascotPrep.hpp"
@@ -28,6 +29,19 @@ Preprocessing<T>* Preprocessing<T>::get_new(
machine.template prep_dir_prefix<T>(), usage);
}
template<class T>
Preprocessing<T>* Preprocessing<T>::get_new(
bool live_prep, const Names& N,
DataPositions& usage)
{
if (live_prep)
return new typename T::LivePrep(usage);
else
return new GC::BitPrepFiles<T>(N,
get_prep_sub_dir<T>(PREP_DIR, N.num_players()), usage,
BaseMachine::thread_num);
}
template<class T>
Sub_Data_Files<T>::Sub_Data_Files(const Names& N, DataPositions& usage,
int thread_num) :
@@ -96,7 +110,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
dabit_buffer.setup(
PrepBase::get_filename(prep_data_dir, DATA_DABIT,
type_short, my_num, thread_num), 1, type_string,
type_short, my_num, thread_num), dabit<T>::size(), type_string,
DataPositions::dtype_names[DATA_DABIT]);
input_buffers.resize(num_players);
@@ -106,7 +120,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
type_short, i, my_num, thread_num);
if (i == my_num)
my_input_buffers.setup(filename,
T::size() * 3 / 2, type_string);
T::size() + T::clear::size(), type_string);
else
input_buffers[i].setup(filename,
T::size(), type_string);
@@ -122,7 +136,10 @@ Data_Files<sint, sgf2n>::Data_Files(Machine<sint, sgf2n>& machine, SubProcessor<
SubProcessor<sgf2n>* proc2) :
usage(machine.get_N().num_players()),
DataFp(*Preprocessing<sint>::get_new(machine, usage, procp)),
DataF2(*Preprocessing<sgf2n>::get_new(machine, usage, proc2))
DataF2(*Preprocessing<sgf2n>::get_new(machine, usage, proc2)),
DataFb(
*Preprocessing<typename sint::bit_type>::get_new(machine.live_prep,
machine.get_N(), usage))
{
}
@@ -130,7 +147,8 @@ template<class sint, class sgf2n>
Data_Files<sint, sgf2n>::Data_Files(const Names& N) :
usage(N.num_players()),
DataFp(*new Sub_Data_Files<sint>(N, usage)),
DataF2(*new Sub_Data_Files<sgf2n>(N, usage))
DataF2(*new Sub_Data_Files<sgf2n>(N, usage)),
DataFb(*new Sub_Data_Files<typename sint::bit_type>(N, usage))
{
}
@@ -150,6 +168,7 @@ Data_Files<sint, sgf2n>::~Data_Files()
DataF2.data_sent() * 1e-6 << " MB" << endl;
#endif
delete &DataF2;
delete &DataFb;
}
template<class T>
@@ -166,6 +185,12 @@ Sub_Data_Files<T>::~Sub_Data_Files()
template<class T>
void Sub_Data_Files<T>::seekg(DataPositions& pos)
{
if (T::LivePrep::use_part)
{
get_part().seekg(pos);
return;
}
DataFieldType field_type = T::clear::field_type();
for (int dtype = 0; dtype < N_DTYPE; dtype++)
if (T::clear::allows(Dtype(dtype)))
@@ -181,6 +206,7 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
setup_extended(it->first);
extended[it->first].seekg(it->second);
}
dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]);
}
template<class sint, class sgf2n>
@@ -188,6 +214,7 @@ void Data_Files<sint, sgf2n>::seekg(DataPositions& pos)
{
DataFp.seekg(pos);
DataF2.seekg(pos);
DataFb.seekg(pos);
usage = pos;
}
@@ -210,6 +237,9 @@ void Sub_Data_Files<T>::prune()
input_buffers[j].prune();
for (auto it : extended)
it.second.prune();
dabit_buffer.prune();
if (part != 0)
part->prune();
}
template<class sint, class sgf2n>
@@ -217,6 +247,7 @@ void Data_Files<sint, sgf2n>::prune()
{
DataFp.prune();
DataF2.prune();
DataFb.prune();
}
template<class T>
@@ -229,6 +260,7 @@ void Sub_Data_Files<T>::purge()
input_buffers[j].purge();
for (auto it : extended)
it.second.purge();
dabit_buffer.purge();
}
template<class T>
@@ -280,11 +312,12 @@ void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
ifstream* f = new ifstream(filename);
if (f->fail())
throw runtime_error("cannot open " + filename);
check_file_signature<T>(*f, filename);
edabit_buffers[n_bits] = f;
}
auto& buffer = *edabit_buffers[n_bits];
if (buffer.peek() == EOF)
buffer.seekg(0);
buffer.seekg(file_signature<T>().get_length());
edabitvec<T> eb;
eb.input(n_bits, buffer);
this->edabits[{strict, n_bits}].push_back(eb);

View File

@@ -15,6 +15,9 @@ using namespace std;
class ArithmeticProcessor;
/**
* Abstract base for input protocols
*/
template<class T>
class InputBase
{
@@ -45,18 +48,28 @@ public:
InputBase(SubProcessor<T>* proc);
virtual ~InputBase();
/// Initialize input round for ``player``
virtual void reset(int player) = 0;
/// Initialize input round for all players
void reset_all(Player& P);
/// Schedule input from me
virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0;
/// Schedule input from other player
virtual void add_other(int player, int n_bits = -1) = 0;
/// Schedule input from all players
void add_from_all(const clear& input);
/// Send my inputs
virtual void send_mine() = 0;
/// Run input protocol for all players
virtual void exchange();
/// Get share for next input of mine
virtual T finalize_mine() = 0;
/// Store share for next input from ``player`` from buffer ``o`` in ``target``
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
/// Get share for next input from ``player`
virtual T finalize(int player, int n_bits = -1);
void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size);

View File

@@ -19,6 +19,11 @@ struct InputTuple
static string type_string()
{ return T::type_string(); }
static void specification(octetStream& os)
{
T::specification(os);
}
InputTuple() {}
InputTuple(const T& share, const typename T::open_type& value) : share(share), value(value) {}

View File

@@ -14,6 +14,7 @@ using namespace std;
template<class sint, class sgf2n> class Machine;
template<class sint, class sgf2n> class Processor;
class ArithmeticProcessor;
class SwitchableOutput;
/*
* Opcode constants
@@ -306,12 +307,6 @@ enum RegType {
MAX_REG_TYPE,
};
enum SecrecyType {
SECRET,
CLEAR,
MAX_SECRECY_TYPE
};
template<class sint, class sgf2n>
struct TempVars {
typename sgf2n::clear ans2;
@@ -387,6 +382,10 @@ public:
void shuffle(ArithmeticProcessor& Proc) const;
void bitdecint(ArithmeticProcessor& Proc) const;
template<class T>
void print(SwitchableOutput& out, T* v, T* p = 0, T* s = 0, T* z = 0,
T* nan = 0) const;
};
#endif

View File

@@ -328,6 +328,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
// write to external client, input is : opcode num_args, client_id, message_type, var1, var2 ...
case WRITESOCKETC:
case WRITESOCKETS:
case WRITESOCKETSHARE:
case WRITESOCKETINT:
num_var_args = get_int(s) - 3;
@@ -336,8 +337,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
n = get_int(s);
get_vector(num_var_args, start, s);
break;
case WRITESOCKETS:
throw runtime_error("sending MACs to client not supported any more");
case READCLIENTPUBLICKEY:
case INITSECURESOCKET:
case RESPSECURESOCKET:
@@ -1070,31 +1069,19 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
}
break;
case PRINTREGPLAIN:
{
Proc.out << Proc.read_Cp(r[0]) << flush;
}
break;
print(Proc.out, &Proc.read_Cp(r[0]));
return;
case CONDPRINTPLAIN:
if (not Proc.read_Cp(r[0]).is_zero())
{
auto v = Proc.read_Cp(r[1]);
auto p = Proc.read_Cp(r[2]);
if (p.is_zero())
Proc.out << v << flush;
else
Proc.out << bigint::get_float(v, p, {}, {}) << flush;
print(Proc.out, &Proc.read_Cp(r[1]), &Proc.read_Cp(r[2]));
}
break;
return;
case PRINTFLOATPLAIN:
{
auto nan = Proc.read_Cp(start[4]);
typename sint::clear v = Proc.read_Cp(start[0]);
typename sint::clear p = Proc.read_Cp(start[1]);
typename sint::clear z = Proc.read_Cp(start[2]);
typename sint::clear s = Proc.read_Cp(start[3]);
bigint::output_float(Proc.out, bigint::get_float(v, p, z, s), nan);
}
break;
print(Proc.out, &Proc.read_Cp(start[0]), &Proc.read_Cp(start[1]),
&Proc.read_Cp(start[2]), &Proc.read_Cp(start[3]),
&Proc.read_Cp(start[4]));
return;
case CONDPRINTSTR:
if (not Proc.read_Cp(r[0]).is_zero())
{
@@ -1124,9 +1111,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.machine.stop(n);
break;
case RUN_TAPE:
Proc.DataF.skip(
Proc.machine.run_tapes(start, &Proc.DataF.DataFp,
&Proc.share_thread.DataF));
Proc.machine.run_tapes(start, Proc.DataF);
break;
case JOIN_TAPE:
Proc.machine.join_tape(r[0]);
@@ -1186,15 +1171,19 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.read_socket_private(Proc.read_Ci(r[0]), start, n, true);
break;
case WRITESOCKETINT:
Proc.write_socket(INT, Proc.read_Ci(r[0]), r[1], start, n);
Proc.write_socket(INT, false, Proc.read_Ci(r[0]), r[1], start, n);
break;
case WRITESOCKETC:
Proc.write_socket(CINT, Proc.read_Ci(r[0]), r[1], start, n);
Proc.write_socket(CINT, false, Proc.read_Ci(r[0]), r[1], start, n);
break;
case WRITESOCKETS:
// Send shares + MACs
Proc.write_socket(SINT, true, Proc.read_Ci(r[0]), r[1], start, n);
break;
case WRITESOCKETSHARE:
// Send only shares, no MACs
// N.B. doesn't make sense to have a corresponding read instruction for this
Proc.write_socket(SINT, Proc.read_Ci(r[0]), r[1], start, n);
Proc.write_socket(SINT, false, Proc.read_Ci(r[0]), r[1], start, n);
break;
case WRITEFILESHARE:
// Write shares to file system
@@ -1323,4 +1312,29 @@ void Program::execute(Processor<sint, sgf2n>& Proc) const
}
}
template<class T>
void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) const
{
if (size > 1)
out << "[";
for (int i = 0; i < size; i++)
{
if (p == 0)
out << v[i];
else if (s == 0)
out << bigint::get_float(v[i], p[i], {}, {});
else
{
assert(z != 0);
assert(nan != 0);
bigint::output_float(out, bigint::get_float(v[i], p[i], s[i], z[i]),
nan[i]);
}
if (i < size - 1)
out << ", ";
}
if (size > 1)
out << "]";
}
#endif

View File

@@ -42,9 +42,6 @@ class Machine : public BaseMachine
typename sgf2n::mac_key_type alpha2i;
typename sint::bit_type::mac_key_type alphabi;
// Keep record of used offline data
DataPositions pos;
Player* P;
void load_program(const string& threadname, const string& filename);
@@ -83,8 +80,8 @@ class Machine : public BaseMachine
const Names& get_N() { return N; }
DataPositions run_tapes(const vector<int> &args, Preprocessing<sint> *prep,
Preprocessing<typename sint::bit_type> *bit_prep);
DataPositions run_tapes(const vector<int> &args,
Data_Files<sint, sgf2n>& DataF);
void fill_buffers(int thread_number, int tape_number,
Preprocessing<sint> *prep,
Preprocessing<typename sint::bit_type> *bit_prep);
@@ -93,7 +90,8 @@ class Machine : public BaseMachine
Preprocessing<sint> *prep, true_type);
template<int = 0>
void fill_matmul(int, int, Preprocessing<sint>*, false_type) {}
DataPositions run_tape(int thread_number, int tape_number, int arg);
DataPositions run_tape(int thread_number, int tape_number, int arg,
const DataPositions& pos);
DataPositions join_tape(int thread_number);
void run();

View File

@@ -92,9 +92,6 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
exit(1);
}
// Keep record of used offline data
pos.set_num_players(N.num_players());
load_schedule(progname_str);
// remove persistence if necessary
@@ -161,14 +158,16 @@ void Machine<sint, sgf2n>::load_program(const string& threadname,
template<class sint, class sgf2n>
DataPositions Machine<sint, sgf2n>::run_tapes(const vector<int>& args,
Preprocessing<sint>* prep, Preprocessing<typename sint::bit_type>* bit_prep)
Data_Files<sint, sgf2n>& DataF)
{
assert(args.size() % 3 == 0);
for (unsigned i = 0; i < args.size(); i += 3)
fill_buffers(args[i], args[i + 1], prep, bit_prep);
fill_buffers(args[i], args[i + 1], &DataF.DataFp, &DataF.DataFb);
DataPositions res(N.num_players());
for (unsigned i = 0; i < args.size(); i += 3)
res.increase(run_tape(args[i], args[i + 1], args[i + 2]));
res.increase(
run_tape(args[i], args[i + 1], args[i + 2], DataF.tellg() + res));
DataF.skip(res);
return res;
}
@@ -281,7 +280,7 @@ void Machine<sint, sgf2n>::fill_matmul(int thread_number, int tape_number,
template<class sint, class sgf2n>
DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
int arg)
int arg, const DataPositions& pos)
{
if (size_t(thread_number) >= tinfo.size())
throw overflow("invalid thread number", thread_number, tinfo.size());
@@ -294,7 +293,7 @@ DataPositions Machine<sint, sgf2n>::run_tape(int thread_number, int tape_number,
if (progs[tape_number].usage_unknown())
{
#ifndef INSECURE
if (not opts.live_prep)
if (not opts.live_prep and thread_number != 0)
{
cerr << "Internally called tape " << tape_number <<
" has unknown offline data usage" << endl;
@@ -328,7 +327,7 @@ void Machine<sint, sgf2n>::run()
timer[0].start();
// run main tape
pos.increase(run_tape(0, 0, 0));
run_tape(0, 0, 0, N.num_players());
join_tape(0);
print_compiler();
@@ -341,8 +340,8 @@ void Machine<sint, sgf2n>::run()
queues[i]->schedule(-1);
}
// reset to sum actual usage
pos.reset();
// sum actual usage
DataPositions pos(N.num_players());
#ifdef DEBUG_THREADS
cerr << "Waiting for all clients to finish" << endl;

View File

@@ -15,22 +15,29 @@ template<class T> istream& operator>>(istream& s,Memory<T>& M);
#include "Processor/Program.h"
#include "Tools/CheckVector.h"
template<class T>
class MemoryPart : public CheckVector<T>
{
public:
void minimum_size(size_t size);
};
template<class T>
class Memory
{
public:
CheckVector<T> MS;
CheckVector<typename T::clear> MC;
MemoryPart<T> MS;
MemoryPart<typename T::clear> MC;
void resize_s(int sz)
void resize_s(size_t sz)
{ MS.resize(sz); }
void resize_c(int sz)
void resize_c(size_t sz)
{ MC.resize(sz); }
unsigned size_s()
size_t size_s()
{ return MS.size(); }
unsigned size_c()
size_t size_c()
{ return MC.size(); }
template<class U>
@@ -40,23 +47,23 @@ class Memory
throw overflow("memory", i, M.size());
}
const typename T::clear& read_C(int i) const
const typename T::clear& read_C(size_t i) const
{
check_index(MC, i);
return MC[i];
}
const T& read_S(int i) const
const T& read_S(size_t i) const
{
check_index(MS, i);
return MS[i];
}
void write_C(unsigned int i,const typename T::clear& x)
void write_C(size_t i,const typename T::clear& x)
{
check_index(MC, i);
MC[i]=x;
}
void write_S(unsigned int i,const T& x)
void write_S(size_t i,const T& x)
{
check_index(MS, i);
MS[i]=x;

View File

@@ -8,27 +8,23 @@ void Memory<T>::minimum_size(RegType secret_type, RegType clear_type,
const Program &program, const string& threadname)
{
(void) threadname;
unsigned sizes[MAX_SECRECY_TYPE];
sizes[SECRET]= program.direct_mem(secret_type);
sizes[CLEAR] = program.direct_mem(clear_type);
if (sizes[SECRET] > size_s())
{
#ifdef DEBUG_MEMORY
cerr << threadname << " needs more secret " << T::type_string() << " memory, resizing to "
<< sizes[SECRET] << endl;
#endif
resize_s(sizes[SECRET]);
}
if (sizes[CLEAR] > size_c())
{
#ifdef DEBUG_MEMORY
cerr << threadname << " needs more clear " << T::type_string() << " memory, resizing to "
<< sizes[CLEAR] << endl;
#endif
resize_c(sizes[CLEAR]);
}
MS.minimum_size(program.direct_mem(secret_type));
MC.minimum_size(program.direct_mem(clear_type));
}
template<class T>
void MemoryPart<T>::minimum_size(size_t size)
{
try
{
if (size > this->size())
this->resize(size);
}
catch (bad_alloc&)
{
throw insufficient_memory(size, T::type_string());
}
}
template<class T>
ostream& operator<<(ostream& s,const Memory<T>& M)

View File

@@ -8,6 +8,7 @@
#include "OfflineMachine.h"
#include "Protocols/mac_key.hpp"
#include "Tools/Buffer.h"
template<class W>
template<class V>
@@ -39,8 +40,8 @@ int OfflineMachine<W>::run()
T::bit_type::mac_key_type::init_field();
auto binary_mac_key = read_generate_write_mac_key<
typename T::bit_type::part_type>(P);
GC::ShareThread<typename T::bit_type> thread(playerNames,
OnlineOptions::singleton, P, binary_mac_key, usage);
typename T::bit_type::LivePrep bit_prep(usage);
GC::ShareThread<typename T::bit_type> thread(bit_prep, P, binary_mac_key);
generate<T>();
generate<typename T::bit_type::part_type>();
@@ -74,6 +75,7 @@ void OfflineMachine<W>::generate()
if (my_usage > 0)
{
ofstream out(filename, iostream::out | iostream::binary);
file_signature<T>().output(out);
if (i == DATA_DABIT)
{
for (long long j = 0;
@@ -108,6 +110,7 @@ void OfflineMachine<W>::generate()
if (n_inputs > 0)
{
ofstream out(filename, iostream::out | iostream::binary);
file_signature<T>().output(out);
InputTuple<T> tuple;
for (long long j = 0;
j < DIV_CEIL(n_inputs, BUFFER_SIZE) * BUFFER_SIZE; j++)
@@ -138,6 +141,7 @@ void OfflineMachine<W>::generate()
if (total > 0)
{
ofstream out(filename, ios::binary);
file_signature<T>().output(out);
for (int i = 0; i < DIV_CEIL(total, batch) * batch; i++)
preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits,
out);

View File

@@ -26,7 +26,7 @@ class thread_info
static void* Main_Func(void *ptr);
static void purge_preprocessing(const Names& N);
static void purge_preprocessing(const Names& N, int thread_num);
template<class T>
static void print_usage(ostream& o, const vector<T>& regs,

View File

@@ -352,7 +352,7 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
catch (...)
{
thread_info<sint, sgf2n>* ti = (thread_info<sint, sgf2n>*)ptr;
ti->purge_preprocessing(ti->machine->get_N());
ti->purge_preprocessing(ti->machine->get_N(), ti->thread_num);
throw;
}
#endif
@@ -361,13 +361,17 @@ void* thread_info<sint, sgf2n>::Main_Func(void* ptr)
template<class sint, class sgf2n>
void thread_info<sint, sgf2n>::purge_preprocessing(const Names& N)
void thread_info<sint, sgf2n>::purge_preprocessing(const Names& N, int thread_num)
{
cerr << "Purging preprocessed data because something is wrong" << endl;
try
{
Data_Files<sint, sgf2n> df(N);
df.purge();
DataPositions pos;
Sub_Data_Files<typename sint::bit_type> bit_df(N, pos, thread_num);
bit_df.get_part();
bit_df.purge();
}
catch(...)
{

View File

@@ -249,7 +249,7 @@ int OnlineMachine::run()
catch(...)
{
if (not online_opts.live_prep)
thread_info<T, U>::purge_preprocessing(playerNames);
thread_info<T, U>::purge_preprocessing(playerNames, 0);
throw;
}
#endif

View File

@@ -36,13 +36,9 @@ OnlineOptions::OnlineOptions() : playerno(-1)
}
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
const char** argv, int default_batch_size, bool default_live_prep,
bool variable_prime_length) :
const char** argv, false_type) :
OnlineOptions()
{
if (default_batch_size <= 0)
default_batch_size = batch_size;
opt.syntax = std::string(argv[0]) + " [OPTIONS] [<playerno>] <progname>";
opt.add(
@@ -78,6 +74,58 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"--output-file" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"This player's number (required if not given before program name)", // Help description.
"-p", // Flag token.
"--player" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Verbose output", // Help description.
"-v", // Flag token.
"--verbose" // Flag token.
);
opt.add(
"4", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Batch size for sacrifice (3-5, default: 4)", // Help description.
"-B", // Flag token.
"--bucket-size" // Flag token.
);
opt.parse(argc, argv);
interactive = opt.isSet("-I");
opt.get("-IF")->getString(cmd_private_input_file);
opt.get("-OF")->getString(cmd_private_output_file);
opt.get("--bucket-size")->getInt(bucket_size);
#ifndef VERBOSE
verbose = opt.isSet("--verbose");
#endif
opt.resetArgs();
}
OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
const char** argv, int default_batch_size, bool default_live_prep,
bool variable_prime_length) :
OnlineOptions(opt, argc, argv, false_type())
{
if (default_batch_size <= 0)
default_batch_size = batch_size;
string default_lgp = to_string(lgp);
if (variable_prime_length)
{
@@ -121,15 +169,6 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"-L", // Flag token.
"--live-preprocessing" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"This player's number (required if not given before program name)", // Help description.
"-p", // Flag token.
"--player" // Flag token.
);
opt.add(
to_string(default_batch_size).c_str(), // Default.
@@ -170,28 +209,9 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
"-d", // Flag token.
"--direct" // Flag token.
);
opt.add(
"4", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Batch size for sacrifice (3-5, default: 4)", // Help description.
"-B", // Flag token.
"--bucket-size" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Verbose output", // Help description.
"-v", // Flag token.
"--verbose" // Flag token.
);
opt.parse(argc, argv);
interactive = opt.isSet("-I");
if (variable_prime_length)
{
opt.get("--lgp")->getInt(lgp);
@@ -208,17 +228,8 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc,
opt.get("--memory")->getString(memtype);
bits_from_squares = opt.isSet("-Q");
opt.get("-IF")->getString(cmd_private_input_file);
opt.get("-OF")->getString(cmd_private_output_file);
direct = opt.isSet("--direct");
opt.get("--bucket-size")->getInt(bucket_size);
#ifndef VERBOSE
verbose = opt.isSet("--verbose");
#endif
opt.resetArgs();
}

View File

@@ -31,6 +31,8 @@ public:
bool verbose;
OnlineOptions();
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
false_type);
OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv,
int default_batch_size = 0, bool default_live_prep = true,
bool variable_prime_length = false);

View File

@@ -9,15 +9,8 @@
string PrepBase::get_suffix(int thread_num)
{
#ifdef INSECURE
(void) thread_num;
return "";
#else
if (thread_num >= 0)
return "-T" + to_string(thread_num);
else
return "";
#endif
}
string PrepBase::get_filename(const string& prep_data_dir,

View File

@@ -31,7 +31,7 @@ class SubProcessor
DataPositions bit_usage;
void resize(int size) { C.resize(size); S.resize(size); }
void resize(size_t size) { C.resize(size); S.resize(size); }
template<class sint, class sgf2n> friend class Processor;
template<class U> friend class SPDZ;
@@ -64,10 +64,10 @@ public:
void muls(const vector<int>& reg, int size);
void mulrs(const vector<int>& reg);
void dotprods(const vector<int>& reg, int size);
void matmuls(const vector<T>& source, const Instruction& instruction, int a,
int b);
void matmulsm(const CheckVector<T>& source, const Instruction& instruction, int a,
int b);
void matmuls(const vector<T>& source, const Instruction& instruction, size_t a,
size_t b);
void matmulsm(const CheckVector<T>& source, const Instruction& instruction, size_t a,
size_t b);
void conv2ds(const Instruction& instruction);
void input_personal(const vector<int>& args);
@@ -82,12 +82,12 @@ public:
return C;
}
T& get_S_ref(int i)
T& get_S_ref(size_t i)
{
return S[i];
}
typename T::clear& get_C_ref(int i)
typename T::clear& get_C_ref(size_t i)
{
return C[i];
}
@@ -136,11 +136,11 @@ public:
return thread_num;
}
const long& read_Ci(int i) const
const long& read_Ci(size_t i) const
{ return Ci[i]; }
long& get_Ci_ref(int i)
long& get_Ci_ref(size_t i)
{ return Ci[i]; }
void write_Ci(int i,const long& x)
void write_Ci(size_t i, const long& x)
{ Ci[i]=x; }
CheckVector<long>& get_Ci()
{ return Ci; }
@@ -190,30 +190,30 @@ class Processor : public ArithmeticProcessor
const Program& program);
~Processor();
const typename sgf2n::clear& read_C2(int i) const
const typename sgf2n::clear& read_C2(size_t i) const
{ return Proc2.C[i]; }
const sgf2n& read_S2(int i) const
const sgf2n& read_S2(size_t i) const
{ return Proc2.S[i]; }
typename sgf2n::clear& get_C2_ref(int i)
typename sgf2n::clear& get_C2_ref(size_t i)
{ return Proc2.C[i]; }
sgf2n& get_S2_ref(int i)
sgf2n& get_S2_ref(size_t i)
{ return Proc2.S[i]; }
void write_C2(int i,const typename sgf2n::clear& x)
void write_C2(size_t i,const typename sgf2n::clear& x)
{ Proc2.C[i]=x; }
void write_S2(int i,const sgf2n& x)
void write_S2(size_t i,const sgf2n& x)
{ Proc2.S[i]=x; }
const typename sint::clear& read_Cp(int i) const
const typename sint::clear& read_Cp(size_t i) const
{ return Procp.C[i]; }
const sint & read_Sp(int i) const
const sint & read_Sp(size_t i) const
{ return Procp.S[i]; }
typename sint::clear& get_Cp_ref(int i)
typename sint::clear& get_Cp_ref(size_t i)
{ return Procp.C[i]; }
sint & get_Sp_ref(int i)
sint & get_Sp_ref(size_t i)
{ return Procp.S[i]; }
void write_Cp(int i,const typename sint::clear& x)
void write_Cp(size_t i,const typename sint::clear& x)
{ Procp.C[i]=x; }
void write_Sp(int i,const sint & x)
void write_Sp(size_t i,const sint & x)
{ Procp.S[i]=x; }
void check();
@@ -229,8 +229,8 @@ class Processor : public ArithmeticProcessor
// Access to external client sockets for reading clear/shared data
void read_socket_ints(int client_id, const vector<int>& registers, int size);
void write_socket(const RegType reg_type, int socket_id, int message_type,
const vector<int>& registers, int size);
void write_socket(const RegType reg_type, bool send_macs, int socket_id,
int message_type, const vector<int>& registers, int size);
void read_socket_vector(int client_id, const vector<int>& registers,
int size);

View File

@@ -70,7 +70,7 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
const Program& program)
: ArithmeticProcessor(machine.opts, thread_num),DataF(machine, &Procp, &Proc2),P(P),
MC2(MC2),MCp(MCp),machine(machine),
share_thread(machine.get_N(), machine.opts, P, machine.get_bit_mac_key(), DataF.usage),
share_thread(DataF.DataFb, P, machine.get_bit_mac_key()),
Procb(machine.bit_memories),
Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P),
privateOutput2(Proc2),privateOutputp(Procp),
@@ -94,21 +94,8 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
secure_prng.ReSeed();
shared_prng.SeedGlobally(P, false);
// only output on party 0 if not interactive
bool always_stdout = machine.opts.cmd_private_output_file == ".";
bool output = P.my_num() == 0 or machine.opts.interactive or always_stdout;
out.activate(output);
Procb.out.activate(output);
if (not always_stdout)
setup_redirection(P.my_num(), thread_num, opts);
if (stdout_redirect_file.is_open())
{
out.redirect_to_file(stdout_redirect_file);
Procb.out.redirect_to_file(stdout_redirect_file);
}
setup_redirection(P.my_num(), thread_num, opts, out);
Procb.out = out;
}
@@ -266,8 +253,9 @@ void Processor<sint, sgf2n>::split(const Instruction& instruction)
// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to
// determine the data structure being sent in a message.
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, int socket_id,
int message_type, const vector<int>& registers, int size)
void Processor<sint, sgf2n>::write_socket(const RegType reg_type,
bool send_macs, int socket_id, int message_type,
const vector<int>& registers, int size)
{
int m = registers.size();
socket_stream.reset_write_head();
@@ -283,9 +271,12 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, int socket_id,
{
if (reg_type == SINT)
{
// Send vector of secret shares
get_Sp_ref(registers[i] + j).pack(socket_stream,
sint::get_rec_factor(P.my_num(), P.num_players()));
// Send vector of secret shares and optionally macs
if (send_macs)
get_Sp_ref(registers[i] + j).pack(socket_stream);
else
get_Sp_ref(registers[i] + j).pack(socket_stream,
sint::get_rec_factor(P.my_num(), P.num_players()));
}
else if (reg_type == CINT)
{
@@ -522,7 +513,7 @@ void SubProcessor<T>::dotprods(const vector<int>& reg, int size)
template<class T>
void SubProcessor<T>::matmuls(const vector<T>& source,
const Instruction& instruction, int a, int b)
const Instruction& instruction, size_t a, size_t b)
{
auto& dim = instruction.get_start();
auto A = source.begin() + a;
@@ -549,7 +540,7 @@ void SubProcessor<T>::matmuls(const vector<T>& source,
template<class T>
void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
const Instruction& instruction, int a, int b)
const Instruction& instruction, size_t a, size_t b)
{
auto& dim = instruction.get_start();
auto C = S.begin() + (instruction.get_r(0));

View File

@@ -5,6 +5,11 @@
#include "ProcessorBase.hpp"
ProcessorBase::ProcessorBase() :
input_counter(0), arg(0)
{
}
string ProcessorBase::get_parameterized_filename(int my_num, int thread_num, const string& prefix)
{
string filename = prefix + "-P" + to_string(my_num) + "-" + to_string(thread_num);
@@ -22,12 +27,18 @@ void ProcessorBase::open_input_file(int my_num, int thread_num,
}
void ProcessorBase::setup_redirection(int my_num, int thread_num,
OnlineOptions& opts)
OnlineOptions& opts, SwitchableOutput& out)
{
if (not opts.cmd_private_output_file.empty())
// only output on party 0 if not interactive
bool always_stdout = opts.cmd_private_output_file == ".";
bool output = my_num == 0 or opts.interactive or always_stdout;
out.activate(output);
if (not (opts.cmd_private_output_file.empty() or always_stdout))
{
const string stdout_filename = get_parameterized_filename(my_num,
thread_num, opts.cmd_private_output_file);
stdout_redirect_file.open(stdout_filename.c_str(), ios_base::out);
out.redirect_to_file(stdout_redirect_file);
}
}

View File

@@ -12,6 +12,7 @@
using namespace std;
#include "Tools/ExecutionStats.h"
#include "Tools/SwitchableOutput.h"
#include "OnlineOptions.h"
class ProcessorBase
@@ -21,6 +22,7 @@ class ProcessorBase
ifstream input_file;
string input_filename;
size_t input_counter;
protected:
// Optional argument to tape
@@ -34,6 +36,8 @@ public:
ofstream stdout_redirect_file;
ProcessorBase();
void pushi(long x) { stacki.push(x); }
void popi(long& x) { x = stacki.top(); stacki.pop(); }
@@ -55,7 +59,8 @@ public:
template<class T>
T get_input(istream& is, const string& input_filename, const int* params);
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts);
void setup_redirection(int my_nu, int thread_num, OnlineOptions& opts,
SwitchableOutput& out);
};
#endif /* PROCESSOR_PROCESSORBASE_H_ */

View File

@@ -42,8 +42,9 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co
res.read(input_file, params);
if (input_file.fail())
{
throw input_error(T::NAME, input_filename, input_file);
throw input_error(T::NAME, input_filename, input_file, input_counter);
}
input_counter++;
return res;
}

View File

@@ -42,7 +42,7 @@ RingMachine<U, V, W>::RingMachine(int argc, const char** argv,
case L: \
machine.template run<U<L>, V<gf2n>>(); \
break;
X(64) X(72) X(128)
X(64) X(72) X(128) X(192)
#ifdef RING_SIZE
X(RING_SIZE)
#endif

View File

@@ -8,6 +8,7 @@ import util
program.options_from_args()
sfix.set_precision_from_args(program)
MultiArray.disable_index_checks()
n_examples = 11791
n_test = 1991

View File

@@ -10,6 +10,7 @@ import util
program.options_from_args()
sfix.set_precision_from_args(program, adapt_ring=True)
MultiArray.disable_index_checks()
if 'profile' in program.args:
print('Compiling for profiling')

View File

@@ -8,6 +8,7 @@ import util
program.options_from_args()
sfix.set_precision_from_args(program, adapt_ring=True)
MultiArray.disable_index_checks()
if 'profile' in program.args:
print('Compiling for profiling')

View File

@@ -8,6 +8,7 @@ import util
program.options_from_args()
sfix.set_precision_from_args(program, adapt_ring=True)
MultiArray.disable_index_checks()
if 'profile' in program.args:
print('Compiling for profiling')

View File

@@ -8,6 +8,7 @@ import util
program.options_from_args()
sfix.set_precision_from_args(program, True)
MultiArray.disable_index_checks()
if 'profile' in program.args:
print('Compiling for profiling')

View File

@@ -30,6 +30,8 @@ layers[0].X.input_from(0)
for layer in layers:
layer.input_from(0, raw='raw' in program.args)
sint(0).reveal().store_in_mem(0)
start_timer(1)
opt.forward(1, keep_intermediate=False)
stop_timer(1)

View File

@@ -8,6 +8,10 @@
#include "Replicated.h"
/**
* ATLAS protocol (simple version).
* Uses double sharings to reduce degree of Shamir secret sharing.
*/
template<class T>
class Atlas : public ProtocolBase<T>
{

View File

@@ -8,6 +8,9 @@
#include "ReplicatedPrep.h"
/**
* ATLAS preprocessing.
*/
template<class T>
class AtlasPrep : public ReplicatedPrep<T>
{
@@ -21,6 +24,7 @@ public:
{
}
/// Input tuples from random sharings
void buffer_inputs(int player)
{
assert(this->protocol and this->proc);

View File

@@ -17,6 +17,9 @@ template<class T> class SubProcessor;
template<class T> class MAC_Check_Base;
class Player;
/**
* Beaver multiplication
*/
template<class T>
class Beaver : public ProtocolBase<T>
{

View File

@@ -8,6 +8,9 @@
#include "FHEOffline/SimpleGenerator.h"
/**
* HighGear/ChaiGear preprocessing
*/
template<class T>
class ChaiGearPrep : public MaliciousRingPrep<T>
{

View File

@@ -11,6 +11,9 @@
class PairwiseMachine;
template<class FD> class PairwiseGenerator;
/**
* LowGear/CowGear preprocessing
*/
template<class T>
class CowGearPrep : public MaliciousRingPrep<T>
{

Some files were not shown because too many files have changed in this diff Show More