Maintenance.

This commit is contained in:
Marcel Keller
2022-01-11 16:04:59 +11:00
parent cdb0c0f898
commit e07d9bf2a3
216 changed files with 2410 additions and 1117 deletions

2
.gitmodules vendored
View File

@@ -3,7 +3,7 @@
url = https://github.com/mkskeller/SimpleOT
[submodule "mpir"]
path = mpir
url = git://github.com/wbhart/mpir.git
url = https://github.com/wbhart/mpir
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion

View File

@@ -259,7 +259,7 @@ ProgramParty::~ProgramParty()
reset();
if (P)
{
cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl;
cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl;
delete P;
}
delete[] eval_threads;

View File

@@ -175,7 +175,7 @@ void GarbleInputter<T>::exchange()
assert(party.P != 0);
assert(party.MC != 0);
auto& protocol = party.shared_proc->protocol;
protocol.init_mul(party.shared_proc);
protocol.init_mul();
for (auto& tuple : tuples)
protocol.prepare_mul(tuple.first->mask,
T::constant(1, party.P->my_num(), party.mac_key)

View File

@@ -155,7 +155,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK);
MC->Check(*P);
data_sent = P->comm_stats.total_data() + prep->data_sent();
data_sent = P->total_comm().sent;
this->machine.write_memory(this->N.my_num());
}
@@ -173,7 +173,8 @@ void RealProgramParty<T>::garble()
garble_jobs.clear();
garble_inputter->reset_all(*P);
auto& protocol = *garble_protocol;
protocol.init_mul(shared_proc);
protocol.init(*prep, shared_proc->MC);
protocol.init_mul();
next = this->first_phase(program, garble_processor, this->garble_machine);
@@ -181,7 +182,8 @@ void RealProgramParty<T>::garble()
protocol.exchange();
typename T::Protocol second_protocol(*P);
second_protocol.init_mul(shared_proc);
second_protocol.init(*prep, shared_proc->MC);
second_protocol.init_mul();
for (auto& job : garble_jobs)
job.middle_round(*this, second_protocol);

View File

@@ -293,6 +293,9 @@ public:
template<class U>
static void convcbit2s(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("convcbit2s not implemented"); }
template<class U>
static void andm(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("andm not implemented"); }
// most BMR phases don't need actual input
template<class T>

View File

@@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty()
_received_gc_received = 0;
n_received = 0;
randomfd = open("/dev/urandom", O_RDONLY);
done_filling = false;
}
BaseTrustedParty::~BaseTrustedParty()
{
close(randomfd);
}
TrustedProgramParty::TrustedProgramParty(int argc, char** argv) :

View File

@@ -20,7 +20,7 @@ public:
vector<SendBuffer> msg_input_masks;
BaseTrustedParty();
virtual ~BaseTrustedParty() {}
virtual ~BaseTrustedParty();
/* From NodeUpdatable class */
virtual void NodeReady();
@@ -104,7 +104,6 @@ private:
void add_all_keys(const Register& reg, bool external);
};
inline void BaseTrustedParty::add_keys(const Register& reg)
{
for(int p = 0; p < get_n_parties(); p++)

View File

@@ -1,5 +1,18 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.9 (Jan 11, 2021)
- Disassembler
- Run-time parameter for probabilistic truncation error
- Probabilistic truncation for some protocols computing modulo a prime
- Simplified C++ interface
- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757)
- More general scalar-vector multiplication
- Complete memory support for clear bits
- Extended clear bit functionality with Yao's garbled circuits
- Allow preprocessing information to be supplied via named pipes
- In-place operations for containers
## 0.2.8 (Nov 4, 2021)
- Tested on Apple laptop with ARM chip

View File

@@ -112,10 +112,16 @@ class bits(Tape.Register, _structure, _bit):
return cls.load_dynamic_mem(address)
else:
for i in range(res.size):
cls.load_inst[util.is_constant(address)](res[i], address + i)
cls.mem_op(cls.load_inst, res[i], address + i)
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address)
self.mem_op(self.store_inst, self, address)
@staticmethod
def mem_op(inst, reg, address):
direct = isinstance(address, int)
if not direct:
address = regint.conv(address)
inst[direct](reg, address)
@classmethod
def new(cls, value=None, n=None):
if util.is_constant(value):

View File

@@ -77,13 +77,16 @@ def LTZ(s, a, k, kappa):
k: bit length of a
"""
movs(s, program.non_linear.ltz(a, k, kappa))
def LtzRing(a, k):
from .types import sint, _bitint
from .GC.types import sbitvec
if program.use_split():
summands = a.split_to_two_summands(k)
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
msb = carry ^ summands[0][-1] ^ summands[1][-1]
movs(s, sint.conv(msb))
return sint.conv(msb)
return
elif program.options.ring:
from . import floatingpoint
@@ -96,11 +99,7 @@ def LTZ(s, a, k, kappa):
a = r_bin[0].bit_decompose_clear(c_prime, m)
b = r_bin[:m]
u = CarryOutRaw(a[::-1], b[::-1])
movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)))
return
t = sint()
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))
def LessThanZero(a, k, kappa):
from . import types

View File

@@ -82,7 +82,7 @@ def run(args, options):
prog.finalize()
if prog.req_num:
print('Program requires:')
print('Program requires at most:')
for x in prog.req_num.pretty():
print(x)

View File

@@ -12,4 +12,7 @@ class ArgumentError(CompilerError):
""" Exception raised for errors in instruction argument parsing. """
def __init__(self, arg, msg):
self.arg = arg
self.msg = msg
self.msg = msg
class VectorMismatch(CompilerError):
pass

View File

@@ -392,7 +392,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
for i in range(1,l):
ci[i] = c % two_power(i)
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
lts(d, c_dprime, r_prime, l, kappa)
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
if compute_modulo:
b = c_dprime - r_prime + pow2m * d
return b, pow2m
@@ -629,12 +629,14 @@ def BITLT(a, b, bit_length):
# - From the paper
# Multiparty Computation for Interval, Equality, and Comparison without
# Bit-Decomposition Protocol
def BitDecFull(a, maybe_mixed=False):
def BitDecFull(a, n_bits=None, maybe_mixed=False):
from .library import get_program, do_while, if_, break_point
from .types import sint, regint, longint, cint
p = get_program().prime
assert p
bit_length = p.bit_length()
n_bits = n_bits or bit_length
assert n_bits <= bit_length
logp = int(round(math.log(p, 2)))
if abs(p - 2 ** logp) / p < 2 ** -get_program().security:
# inspired by Rabbit (https://eprint.iacr.org/2021/119)
@@ -677,12 +679,12 @@ def BitDecFull(a, maybe_mixed=False):
czero = (c==0)
q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t)
fbar = [bbits[0].clear_type.conv(cint(x))
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
for x in ((1<<bit_length)+c-p).bit_decompose(n_bits)]
fbard = bbits[0].bit_decompose_clear(cmodp, n_bits)
g = [q.if_else(fbar[i], fbard[i]) for i in range(n_bits)]
h = bbits[0].bit_adder(bbits, g)
abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
for i in range(bit_length)]
for i in range(n_bits)]
if maybe_mixed:
return abits
else:

View File

@@ -442,7 +442,7 @@ class join_tape(base.Instruction):
arg_format = ['int']
class crash(base.IOInstruction):
""" Crash runtime if the register's value is > 0.
""" Crash runtime if the value in the register is not zero.
:param: Crash condition (regint)"""
code = base.opcodes['CRASH']
@@ -1275,7 +1275,7 @@ class prep(base.Instruction):
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, self.args[0]), 1)
req_node.increment((self.field_type, self.args[0]), self.get_size())
def has_var_args(self):
return True
@@ -2407,19 +2407,6 @@ class sqrs(base.CISC):
subml(self.args[0], s[5], c[1])
@base.gf2n
@base.vectorize
class lts(base.CISC):
""" Secret comparison $s_i = (s_j < s_k)$. """
__slots__ = []
arg_format = ['sw', 's', 's', 'int', 'int']
def expand(self):
from .types import sint
a = sint()
subs(a, self.args[1], self.args[2])
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
# placeholder for documentation
class cisc:
""" Meta instruction for emulation. This instruction is only generated

View File

@@ -4,6 +4,8 @@ import time
import inspect
import functools
import copy
import sys
import struct
from Compiler.exceptions import *
from Compiler.config import *
from Compiler import util
@@ -299,11 +301,12 @@ def vectorize(instruction, global_dict=None):
vectorized_name = 'v' + instruction.__name__
Vectorized_Instruction.__name__ = vectorized_name
global_dict[vectorized_name] = Vectorized_Instruction
if 'sphinx.extension' in sys.modules:
return instruction
global_dict[instruction.__name__ + '_class'] = instruction
instruction.__doc__ = ''
# exclude GF(2^n) instructions from documentation
if instruction.code and instruction.code >> 8 == 1:
maybe_vectorized_instruction.__doc__ = ''
maybe_vectorized_instruction.arg_format = instruction.arg_format
return maybe_vectorized_instruction
@@ -389,8 +392,11 @@ def gf2n(instruction):
else:
global_dict[GF2N_Instruction.__name__] = GF2N_Instruction
if 'sphinx.extension' in sys.modules:
return instruction
global_dict[instruction.__name__ + '_class'] = instruction_cls
instruction_cls.__doc__ = ''
maybe_gf2n_instruction.arg_format = instruction.arg_format
return maybe_gf2n_instruction
#return instruction
@@ -661,6 +667,12 @@ class RegisterArgFormat(ArgFormat):
assert arg.i >= 0
return int_to_bytes(arg.i)
def __init__(self, f):
self.i = struct.unpack('>I', f.read(4))[0]
def __str__(self):
return self.reg_type + str(self.i)
class ClearModpAF(RegisterArgFormat):
reg_type = RegType.ClearModp
@@ -686,6 +698,12 @@ class IntArgFormat(ArgFormat):
def encode(cls, arg):
return int_to_bytes(arg)
def __init__(self, f):
self.i = struct.unpack('>i', f.read(4))[0]
def __str__(self):
return str(self.i)
class ImmediateModpAF(IntArgFormat):
@classmethod
def check(cls, arg):
@@ -722,6 +740,13 @@ class String(ArgFormat):
def encode(cls, arg):
return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))
def __init__(self, f):
tmp = f.read(16)
self.str = str(tmp[0:tmp.find(b'\0')], 'ascii')
def __str__(self):
return self.str
ArgFormats = {
'c': ClearModpAF,
's': SecretModpAF,
@@ -890,6 +915,54 @@ class Instruction(object):
def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
class ParsedInstruction:
reverse_opcodes = {}
def __init__(self, f):
cls = type(self)
from Compiler import instructions
from Compiler.GC import instructions as gc_inst
if not cls.reverse_opcodes:
for module in instructions, gc_inst:
for x, y in inspect.getmodule(module).__dict__.items():
if inspect.isclass(y) and y.__name__[0] != 'v':
try:
cls.reverse_opcodes[y.code] = y
except AttributeError:
pass
read = lambda: struct.unpack('>I', f.read(4))[0]
full_code = read()
code = full_code % (1 << Instruction.code_length)
self.size = full_code >> Instruction.code_length
self.type = cls.reverse_opcodes[code]
t = self.type
name = t.__name__
try:
n_args = len(t.arg_format)
self.var_args = False
except:
n_args = read()
self.var_args = True
try:
arg_format = iter(t.arg_format)
except:
if name == 'cisc':
arg_format = itertools.chain(['str'], itertools.repeat('int'))
else:
arg_format = itertools.repeat('int')
self.args = [ArgFormats[next(arg_format)](f)
for i in range(n_args)]
def __str__(self):
name = self.type.__name__
res = name + ' '
if self.size > 1:
res = 'v' + res + str(self.size) + ', '
if self.var_args:
res += str(len(self.args)) + ', '
res += ', '.join(str(arg) for arg in self.args)
return res
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True

View File

@@ -219,6 +219,9 @@ def crash(condition=None):
:param condition: crash if true (default: true)
"""
if isinstance(condition, localint):
# allow crash on local values
condition = condition._v
if condition == None:
condition = regint(1)
instructions.crash(regint.conv(condition))
@@ -1347,6 +1350,8 @@ def while_loop(loop_body, condition, arg, g=None):
arg = regint(arg)
def loop_fn():
result = loop_body(arg)
if isinstance(result, MemValue):
result = result.read()
result.link(arg)
cont = condition(result)
return cont
@@ -1531,6 +1536,8 @@ def if_(condition):
def if_e(condition):
"""
Conditional execution with else block.
Use :py:class:`~Compiler.types.MemValue` to assign values that
live beyond.
:param condition: regint/cint/int
@@ -1538,12 +1545,13 @@ def if_e(condition):
.. code::
y = MemValue(0)
@if_e(x > 0)
def _():
...
y.write(1)
@else_
def _():
...
y.write(0)
"""
try:
condition = bool(condition)
@@ -1647,11 +1655,18 @@ def get_player_id():
return res
def listen_for_clients(port):
""" Listen for clients on specific port. """
""" Listen for clients on specific port base.
:param port: port base (int/regint/cint)
"""
instructions.listen(regint.conv(port))
def accept_client_connection(port):
""" Listen for clients on specific port. """
""" Accept client connection on specific port base.
:param port: port base (int/regint/cint)
:returns: client id
"""
res = regint()
instructions.acceptclientconnection(res, regint.conv(port))
return res

View File

@@ -1810,6 +1810,7 @@ class Optimizer:
self.print_loss_reduction = False
self.i_epoch = MemValue(0)
self.stopped_on_loss = MemValue(0)
self.stopped_on_low_loss = MemValue(0)
@property
def layers(self):
@@ -1932,6 +1933,7 @@ class Optimizer:
""" Run training.
:param batch_size: batch size (defaults to example size of first layer)
:param stop_on_loss: stop when loss falls below this (default: 0)
"""
if self.n_epochs == 0:
return
@@ -2013,6 +2015,7 @@ class Optimizer:
if self.tol > 0:
res *= (1 - (loss_sum >= 0) * \
(loss_sum < self.tol * n_per_epoch)).reveal()
self.stopped_on_low_loss.write(1 - res)
return res
def reveal_correctness(self, data, truth, batch_size):
@@ -2138,6 +2141,7 @@ class Optimizer:
if depreciation:
self.gamma.imul(depreciation)
print_ln('reducing learning rate to %s', self.gamma)
return 1 - self.stopped_on_low_loss
if 'model_output' in program.args:
self.output_weights()
@@ -2386,6 +2390,7 @@ class keras:
return list(self.opt.thetas)
def build(self, input_shape, batch_size=128):
data_input_shape = input_shape
if self.opt != None and \
input_shape == self.opt.layers[0].X.sizes and \
batch_size <= self.batch_size and \
@@ -2458,9 +2463,10 @@ class keras:
else:
raise Exception(layer[0] + ' not supported')
if layers[-1].d_out == 1:
layers.append(Output(input_shape[0]))
layers.append(Output(data_input_shape[0]))
else:
layers.append(MultiOutput(input_shape[0], layers[-1].d_out))
layers.append(
MultiOutput(data_input_shape[0], layers[-1].d_out))
if self.optimizer[1]:
raise Exception('use keyword arguments for optimizer')
opt = self.optimizer[0]
@@ -2504,7 +2510,7 @@ class keras:
if x.total_size() != self.opt.layers[0].X.total_size():
raise Exception('sample data size mismatch')
if y.total_size() != self.opt.layers[-1].Y.total_size():
print (y, layers[-1].Y)
print (y, self.opt.layers[-1].Y)
raise Exception('label size mismatch')
if validation_data == None:
validation_data = None, None

View File

@@ -1,7 +1,7 @@
from .comparison import *
from .floatingpoint import *
from .types import *
from . import comparison
from . import comparison, program
class NonLinear:
kappa = None
@@ -30,6 +30,15 @@ class NonLinear:
def trunc_pr(self, a, k, m, signed=True):
if isinstance(a, types.cint):
return shift_two(a, m)
prog = program.Program.prog
if prog.use_trunc_pr:
if signed and prog.use_trunc_pr != -1:
a += (1 << (k - 1))
res = sint()
trunc_pr(res, a, k, m)
if signed and prog.use_trunc_pr != -1:
res -= (1 << (k - m - 1))
return res
return self._trunc_pr(a, k, m, signed)
def trunc_round_nearest(self, a, k, m, signed):
@@ -44,6 +53,9 @@ class NonLinear:
return a
return self._trunc(a, k, m, signed)
def ltz(self, a, k, kappa=None):
return -self.trunc(a, k, k - 1, kappa, True)
class Masking(NonLinear):
def eqz(self, a, k):
c, r = self._mask(a, k)
@@ -100,42 +112,44 @@ class KnownPrime(NonLinear):
def _mod2m(self, a, k, m, signed):
if signed:
a += cint(1) << (k - 1)
return sint.bit_compose(self.bit_dec(a, k, k, True)[:m])
return sint.bit_compose(self.bit_dec(a, k, m, True))
def _trunc_pr(self, a, k, m, signed):
# nearest truncation
return self.trunc_round_nearest(a, k, m, signed)
def _trunc(self, a, k, m, signed=None):
if signed:
a += cint(1) << (k - 1)
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
if signed:
res -= cint(1) << (k - 1 - m)
return res
return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed)
def trunc_round_nearest(self, a, k, m, signed):
a += cint(1) << (m - 1)
if signed:
a += cint(1) << (k - 1)
k += 1
res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:])
res = self._trunc(a, k, m, False)
if signed:
res -= cint(1) << (k - m - 2)
return res
def bit_dec(self, a, k, m, maybe_mixed=False):
assert k < self.prime.bit_length()
bits = BitDecFull(a, maybe_mixed=maybe_mixed)
if len(bits) < m:
raise CompilerError('%d has fewer than %d bits' % (self.prime, m))
return bits[:m]
bits = BitDecFull(a, m, maybe_mixed=maybe_mixed)
assert len(bits) == m
return bits
def eqz(self, a, k):
# always signed
a += two_power(k)
return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True)))
def ltz(self, a, k, kappa=None):
if k + 1 < self.prime.bit_length():
# https://dl.acm.org/doi/10.1145/3474123.3486757
# "negative" values wrap around when doubling, thus becoming odd
return self.mod2m(2 * a, k + 1, 1, False)
else:
return super(KnownPrime, self).ltz(a, k, kappa)
class Ring(Masking):
""" Non-linear functionality modulo a power of two known at compile time.
"""
@@ -172,3 +186,6 @@ class Ring(Masking):
return TruncRing(None, tmp + 1, k - m + 1, 1, signed)
else:
return super(Ring, self).trunc_round_nearest(a, k, m, signed)
def ltz(self, a, k, kappa=None):
return LtzRing(a, k)

View File

@@ -578,6 +578,15 @@ class Program(object):
self.warn_about_mem.append(False)
self.curr_block.warn_about_mem = False
@staticmethod
def read_tapes(schedule):
if not os.path.exists(schedule):
schedule = 'Programs/Schedules/%s.sch' % schedule
lines = open(schedule).readlines()
for tapename in lines[2].split(' '):
yield tapename.strip()
class Tape:
""" A tape contains a list of basic blocks, onto which instructions are added. """
def __init__(self, name, program):
@@ -1109,7 +1118,20 @@ class Tape:
else:
self.req_bit_length[t] = max(bit_length, self.req_bit_length)
class Register(object):
@staticmethod
def read_instructions(tapename):
tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb')
while tape.peek():
yield inst_base.ParsedInstruction(tape)
class _no_truth(object):
__slots__ = []
def __bool__(self):
raise CompilerError('Cannot derive truth value from register, '
"consider using 'compile.py -l'")
class Register(_no_truth):
"""
Class for creating new registers. The register's index is automatically assigned
based on the block's reg_counter dictionary.
@@ -1233,10 +1255,6 @@ class Tape:
self.reg_type == RegType.ClearGF2N or \
self.reg_type == RegType.ClearInt
def __bool__(self):
raise CompilerError('Cannot derive truth value from register, '
"consider using 'compile.py -l'")
def __str__(self):
return self.reg_type + str(self.i)

View File

@@ -127,7 +127,7 @@ def vectorize(operation):
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
and not isinstance(args[0], bits) \
and args[0].size != self.size:
raise CompilerError('Different vector sizes of operands: %d/%d'
raise VectorMismatch('Different vector sizes of operands: %d/%d'
% (self.size, args[0].size))
set_global_vector_size(self.size)
try:
@@ -221,7 +221,7 @@ def inputmixed(*args):
else:
instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),)))
class _number(object):
class _number(Tape._no_truth):
""" Number functionality. """
def square(self):
@@ -246,7 +246,11 @@ class _number(object):
elif is_one(other):
return self
else:
return self.mul(other)
try:
return self.mul(other)
except VectorMismatch:
# try reverse multiplication
return NotImplemented
__radd__ = __add__
__rmul__ = __mul__
@@ -320,7 +324,7 @@ class _number(object):
def popcnt_bits(bits):
return sum(bits)
class _int(object):
class _int(Tape._no_truth):
""" Integer functionality. """
@staticmethod
@@ -408,7 +412,7 @@ class _int(object):
def long_one():
return 1
class _bit(object):
class _bit(Tape._no_truth):
""" Binary functionality. """
def bit_xor(self, other):
@@ -474,7 +478,7 @@ class _gf2n(_bit):
def bit_not(self):
return self ^ 1
class _structure(object):
class _structure(Tape._no_truth):
""" Interface for type-dependent container types. """
MemValue = classmethod(lambda cls, value: MemValue(cls.conv(value)))
@@ -591,7 +595,7 @@ class _secret_structure(_structure):
res.input_from(player)
return res
class _vec(object):
class _vec(Tape._no_truth):
def link(self, other):
assert len(self.v) == len(other.v)
for x, y in zip(self.v, other.v):
@@ -726,7 +730,7 @@ class _register(Tape.Register, _number, _structure):
assert self.size == 1
res = type(self)(size=size)
for i in range(size):
movs(res[i], self)
self.mov(res[i], self)
return res
class _clear(_register):
@@ -1010,9 +1014,10 @@ class cint(_clear, _int):
if bit_length <= 64:
return regint(self) < regint(other)
else:
sint.require_bit_length(bit_length + 1)
diff = self - other
diff += (1 << (bit_length - 1))
shifted = diff >> (bit_length - 1)
diff += 1 << bit_length
shifted = diff >> bit_length
res = 1 - regint(shifted & 1)
return res
@@ -1646,7 +1651,7 @@ class regint(_register, _int):
player = -1
intoutput(player, self)
class localint(object):
class localint(Tape._no_truth):
""" Local integer that must prevented from leaking into the secure
computation. Uses regint internally.
@@ -1669,7 +1674,7 @@ class localint(object):
__eq__ = lambda self, other: localint(self._v == other)
__ne__ = lambda self, other: localint(self._v != other)
class personal(object):
class personal(Tape._no_truth):
def __init__(self, player, value):
assert value is not NotImplemented
assert not isinstance(value, _secret)
@@ -2003,9 +2008,11 @@ class _secret(_register, _secret_structure):
size or one size 1 for a value-vector multiplication.
:param other: any compatible type """
if isinstance(other, _secret) and (1 in (self.size, other.size)) \
if isinstance(other, _register) and (1 in (self.size, other.size)) \
and (self.size, other.size) != (1, 1):
x, y = (other, self) if self.size < other.size else (self, other)
if not isinstance(other, _secret):
return y.expand_to_vector(x.size) * x
res = type(self)(size=x.size)
mulrs(res, x, y)
return res
@@ -2221,11 +2228,13 @@ class sint(_secret, _int):
@vectorized_classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of values input by a client.
This uses the triple-based input protocol introduced by
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_
:param n: number of inputs (int)
:param client_id: regint
:param size: vector size (default 1)
:returns: list of sint
"""
# send shares of a triple to client
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
@@ -2910,7 +2919,7 @@ for t in (sint, sgf2n):
sint.bit_type = sintbit
sgf2n.bit_type = sgf2n
class _bitint(object):
class _bitint(Tape._no_truth):
bits = None
log_rounds = False
linear_rounds = False
@@ -3521,6 +3530,7 @@ class cfix(_number, _structure):
@classmethod
def _new(cls, other, k=None, f=None):
assert not isinstance(other, (list, tuple))
res = cls(k=k, f=f)
res.v = cint.conv(other)
return res
@@ -3567,6 +3577,8 @@ class cfix(_number, _structure):
return len(self.v)
def __getitem__(self, index):
if isinstance(index, slice):
return [self._new(x, k=self.k, f=self.f) for x in self.v[index]]
return self._new(self.v[index], k=self.k, f=self.f)
@vectorize
@@ -3608,7 +3620,6 @@ class cfix(_number, _structure):
else:
return NotImplemented
@vectorize
def mul(self, other):
""" Clear fixed-point multiplication.
@@ -4045,7 +4056,8 @@ class _fix(_single):
'for fixed-point computation')
cls.round_nearest = True
if adapt_ring and program.options.ring \
and 'fix_ring' not in program.args:
and 'fix_ring' not in program.args \
and 2 * cls.k > int(program.options.ring):
need = 2 ** int(math.ceil(math.log(2 * cls.k, 2)))
if need != int(program.options.ring):
print('Changing computation modulus to 2^%d' % need)
@@ -4489,7 +4501,7 @@ class squant(_single):
def __neg__(self):
return self._new(-self.v + 2 * util.expand(self.Z, self.v.size))
class _unreduced_squant(object):
class _unreduced_squant(Tape._no_truth):
def __init__(self, v, params, res_params=None, n_summands=1):
self.v = v
self.params = params
@@ -5011,7 +5023,7 @@ class sfloat(_number, _secret_structure):
:return: cfloat """
return cfloat(self.v.reveal(), self.p.reveal(), self.z.reveal(), self.s.reveal())
class cfloat(object):
class cfloat(Tape._no_truth):
""" Helper class for printing revealed sfloats. """
__slots__ = ['v', 'p', 'z', 's', 'nan']

View File

@@ -52,10 +52,10 @@ void run(int argc, const char** argv)
P.unchecked_broadcast(bundle);
Timer timer;
timer.start();
auto stats = P.comm_stats;
auto stats = P.total_comm();
pShare sk = typename T<P256Element::Scalar>::Honest::Protocol(P).get_random();
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
(P.total_comm() - stats).print(true);
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
DataPositions usage;

View File

@@ -5,6 +5,8 @@
#define NO_MIXED_CIRCUITS
#define NO_SECURITY_CHECK
#include "GC/TinierSecret.h"
#include "GC/TinyMC.h"
#include "GC/VectorInput.h"

View File

@@ -113,10 +113,10 @@ void run(int argc, const char** argv)
P.unchecked_broadcast(bundle);
Timer timer;
timer.start();
auto stats = P.comm_stats;
auto stats = P.total_comm();
sk_prep.get_two(DATA_INVERSE, sk, __);
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
(P.total_comm() - stats).print(true);
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
typename pShare::TriplePrep prep(0, usage);

View File

@@ -41,8 +41,8 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
timer.start();
Player& P = proc.P;
auto& prep = proc.DataF;
size_t start = P.sent + prep.data_sent();
auto stats = P.comm_stats + prep.comm_stats();
size_t start = P.total_comm().sent;
auto stats = P.total_comm();
auto& extra_player = P;
auto& protocol = proc.protocol;
@@ -77,7 +77,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
if (prep_mul)
{
protocol.init_mul(&proc);
protocol.init_mul();
for (int i = 0; i < buffer_size; i++)
protocol.prepare_mul(inv_ks[i], sk);
protocol.start_exchange();
@@ -106,9 +106,9 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
timer.stop();
cout << "Generated " << buffer_size << " tuples in " << timer.elapsed()
<< " seconds, throughput " << buffer_size / timer.elapsed() << ", "
<< 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size
<< 1e-3 * (P.total_comm().sent - start) / buffer_size
<< " kbytes per tuple" << endl;
(P.comm_stats + prep.comm_stats() - stats).print(true);
(P.total_comm() - stats).print(true);
}
template<template<class U> class T>

View File

@@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length,
(void) pk;
Timer timer;
timer.start();
size_t start = P.sent;
auto stats = P.comm_stats;
auto stats = P.total_comm();
EcSignature signature;
vector<P256Element> opened_R;
if (opts.R_after_msg)
@@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length,
auto& protocol = proc->protocol;
if (proc)
{
protocol.init_mul(proc);
protocol.init_mul();
protocol.prepare_mul(sk, tuple.a);
protocol.start_exchange();
}
@@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length,
auto rx = tuple.R.x();
signature.s = MC.open(
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
auto diff = (P.total_comm() - stats);
cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending "
<< (P.sent - start) << " bytes" << endl;
auto diff = (P.comm_stats - stats);
<< diff.sent << " bytes" << endl;
diff.print(true);
return signature;
}
@@ -139,11 +138,11 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
P.unchecked_broadcast(bundle);
Timer timer;
timer.start();
auto stats = P.comm_stats;
auto stats = P.total_comm();
P256Element pk = MCc.open(sk, P);
MCc.Check(P);
cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
(P.total_comm() - stats).print(true);
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
{
@@ -154,13 +153,12 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
Timer timer;
timer.start();
auto& check_player = MCp.get_check_player(P);
auto stats = check_player.comm_stats;
auto start = check_player.sent;
auto stats = check_player.total_comm();
MCp.Check(P);
MCc.Check(P);
auto diff = (check_player.total_comm() - stats);
cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending "
<< (check_player.sent - start) << " bytes" << endl;
auto diff = (check_player.comm_stats - stats);
<< diff.sent << " bytes" << endl;
diff.print();
}
}

View File

@@ -8,6 +8,9 @@
#include "Networking/ssl_sockets.h"
/**
* Client-side interface
*/
class Client
{
vector<int> plain_sockets;
@@ -15,15 +18,37 @@ class Client
ssl_service io_service;
public:
/**
* Sockets for cleartext communication
*/
vector<ssl_socket*> sockets;
/**
* Specification of computation domain
*/
octetStream specification;
/**
* Start a new set of connections to computing parties.
* @param hostnames location of computing parties
* @param port_base port base
* @param my_client_id client identifier
*/
Client(const vector<string>& hostnames, int port_base, int my_client_id);
~Client();
/**
* Securely input private values.
* @param values vector of integer-like values
*/
template<class T>
void send_private_inputs(const vector<T>& values);
/**
* Securely receive output values.
* @param n number of values
* @returns vector of integer-like values
*/
template<class T>
vector<T> receive_outputs(int n);
};

View File

@@ -19,6 +19,8 @@ Scripts/<protocol>.sh bankers_bonus-1 &
./bankers-bonus-client.x 1 <nparties> 200 0 &
./bankers-bonus-client.x 2 <nparties> 50 1
```
`<protocol>` can be any arithmetic protocol (e.g., `mascot`) but not a
binary protocol (e.g., `yao`).
This should output that the winning id is 1. Note that the ids have to
be incremental, and the client with the highest id has to input 1 as
the last argument while the others have to input 0 there. Furthermore,
@@ -32,54 +34,21 @@ different hosts, you will have to distribute the `*.pem` files.
### Connection Setup
**listen**(*int port_num*)
Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background.
*port_num* - the port number to listen on.
**acceptclientconnection**(*regint client_socket_id*, *int port_num*)
Picks the first available client socket connection. Blocks if none available.
*client_socket_id* - an identifier used to refer to the client socket.
*port_num* - the port number identifies the socket server to accept connections on.
1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients)
2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection)
3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection)
### Data Exchange
Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py).
Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types).
*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*)
1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket)
2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client)
3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients)
Read a share of an input from a client, blocking on the client send.
## Client-Side Interface
*client_socket_id* - an identifier used to refer to the client socket.
*number_of_inputs* - the number of inputs expected
*[inputs]* - returned list of shares of private input.
**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*)
Write shares of values including macs to an external client.
*client_socket_id* - an identifier used to refer to the client socket.
*[values]* - list of shares of values to send to client.
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message.
*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*)
Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf)
*number_of_inputs* - the number of inputs expected
*client_socket_id* - an identifier used to refer to the client socket.
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
*[inputs]* - returned list of shares of private input.
The example uses the `Client` class implemented in
`ExternalIO/Client.hpp` to handle the communication, see
https://mp-spdz.readthedocs.io/en/latest/io.html#reference for
documentation.

View File

@@ -33,9 +33,6 @@ class FHE_Params
int n_mults() const { return FFTData.size() - 1; }
// Rely on default copy assignment/constructor (not that they should
// ever be needed)
void set(const Ring& R,const vector<bigint>& primes);
void set(const vector<bigint>& primes);
void set_sec(int sec);

View File

@@ -178,12 +178,6 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi,
return extra_slack;
}
/******************************************************************************
* Here onwards needs NTL
******************************************************************************/
@@ -345,6 +339,7 @@ ZZX Cyclotomic(int N)
return F;
}
#else
// simplified version powers of two
int phi_N(int N)
{
if (((N - 1) & N) != 0)

View File

@@ -1,8 +1,6 @@
#ifndef _NTL_Subs
#define _NTL_Subs
/* All these routines use NTL on the inside */
#include "FHE/Ring.h"
#include "FHE/FFT_Data.h"
#include "FHE/P2Data.h"
@@ -47,7 +45,7 @@ public:
};
// Main setup routine (need NTL if online_only is false)
// Main setup routine
void generate_setup(int nparties, int lgp, int lg2,
int sec, bool skip_2 = false, int slack = 0, bool round_up = false);
@@ -60,7 +58,6 @@ int generate_semi_setup(int plaintext_length, int sec,
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,
bool round_up);
// Everything else needs NTL
void init(Ring& Rg, int m, bool generate_poly);
void init(P2Data& P2D,const Ring& Rg);

View File

@@ -114,7 +114,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
cout << "n: " << n << endl;
cout << "sec: " << sec << endl;
cout << "sigma: " << this->sigma << endl;
cout << "h: " << h << endl;
cout << "B_clean size: " << numBits(B_clean) << endl;
cout << "B_scale size: " << numBits(B_scale) << endl;
cout << "B_KS size: " << numBits(B_KS) << endl;

View File

@@ -401,19 +401,29 @@ void Ring_Element::change_rep(RepType r)
bool Ring_Element::equals(const Ring_Element& a) const
{
if (element.empty() and a.element.empty())
return true;
else if (element.empty() or a.element.empty())
throw not_implemented();
if (rep!=a.rep) { throw rep_mismatch(); }
if (*FFTD!=*a.FFTD) { throw pr_mismatch(); }
if (is_zero() or a.is_zero())
return is_zero() and a.is_zero();
for (int i=0; i<(*FFTD).phi_m(); i++)
{ if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } }
return true;
}
bool Ring_Element::is_zero() const
{
if (element.empty())
return true;
for (auto& x : element)
if (not ::isZero(x, FFTD->get_prD()))
return false;
return true;
}
ConversionIterator Ring_Element::get_iterator() const
{
if (rep != polynomial)
@@ -560,6 +570,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const
{
if (&FFTD != this->FFTD)
throw params_mismatch();
if (is_zero())
throw runtime_error("element is zero");
}

View File

@@ -95,6 +95,7 @@ class Ring_Element
void randomize(PRNG& G,bool Diag=false);
bool equals(const Ring_Element& a) const;
bool is_zero() const;
// This is a NOP in cases where we cannot do a FFT
void change_rep(RepType r);

View File

@@ -175,7 +175,7 @@ size_t PairwiseGenerator<FD>::report_size(ReportType type)
template <class FD>
size_t PairwiseGenerator<FD>::report_sent()
{
return P.sent;
return P.total_comm().sent;
}
template <class FD>

View File

@@ -71,7 +71,7 @@ public:
void run(bool exhaust);
size_t report_size(ReportType type);
void report_size(ReportType type, MemoryUsage& res);
size_t report_sent() { return P.sent; }
size_t report_sent() { return P.total_comm().sent; }
};
#endif /* FHEOFFLINE_SIMPLEGENERATOR_H_ */

View File

@@ -96,7 +96,7 @@ void BitAdder::add(vector<vector<T> >& res,
b[j] = summands[i][1][input_begin + j];
}
protocol.init_mul(&proc);
protocol.init_mul();
for (size_t j = 0; j < n_items; j++)
{
res[begin + j][i] = a[j] + b[j] + carries[j];

View File

@@ -91,11 +91,6 @@ public:
(typename T::clear(tmp.get_bit(0)) << i);
}
}
NamedCommStats comm_stats()
{
return part_prep.comm_stats();
}
};
} /* namespace GC */

View File

@@ -25,6 +25,14 @@ void CcdPrep<T>::set_protocol(typename T::Protocol& protocol)
{
auto& thread = ShareThread<T>::s();
assert(thread.MC);
if (part_proc)
{
assert(&part_proc->MC == &thread.MC->get_part_MC());
assert(&part_proc->P == &protocol.get_part().P);
return;
}
part_proc = new SubProcessor<typename T::part_type>(
thread.MC->get_part_MC(), part_prep, protocol.get_part().P);
}

View File

@@ -27,6 +27,7 @@ public:
typedef ShamirInput<This> Input;
typedef ShamirMC<This> MAC_Check;
typedef Shamir<This> Protocol;
typedef This small_type;

View File

@@ -108,6 +108,9 @@ public:
template<class U>
static void convcbit2s(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("convcbit2s not implemented"); }
template<class U>
static void andm(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("andm not implemented"); }
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
static FakeSecret input(int from, word input, int n_bits);

View File

@@ -84,7 +84,7 @@ void Instruction::parse(istream& s, int pos)
ostringstream os;
os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl;
os << "This virtual machine executes binary circuits only." << endl;
os << "Try compiling with '-B' or use only sbit* types." << endl;
os << "Use 'compile.py -B'." << endl;
throw Invalid_Instruction(os.str());
break;
}

View File

@@ -7,6 +7,7 @@
#define GC_NOSHARE_H_
#include "Processor/DummyProtocol.h"
#include "Processor/Instruction.h"
#include "Protocols/ShareInterface.h"
class InputArgs;
@@ -148,11 +149,14 @@ public:
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
static void andm(GC::Processor<NoShare>&, const BaseInstruction&) { fail(); }
static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; }
NoShare() {}
NoShare(int) { fail(); }
template<class T>
NoShare(T) { fail(); }
void load_clear(Integer, Integer) { fail(); }
void random_bit() { fail(); }

View File

@@ -9,6 +9,7 @@
#include "Protocols/Replicated.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "ShareSecret.hpp"
namespace GC
@@ -28,24 +29,19 @@ PostSacriBin::~PostSacriBin()
}
}
void PostSacriBin::init_mul(SubProcessor<T>* proc)
{
assert(proc != 0);
init_mul(proc->DataF, proc->MC);
}
void PostSacriBin::init_mul(Preprocessing<T>&, T::MC&)
void PostSacriBin::init_mul()
{
if ((int) inputs.size() >= OnlineOptions::singleton.batch_size)
check();
honest.init_mul();
}
PostSacriBin::T::clear PostSacriBin::prepare_mul(const T& x, const T& y, int n)
void PostSacriBin::prepare_mul(const T& x, const T& y, int n)
{
if (n == -1)
n = T::default_length;
honest.prepare_mul(x, y, n);
inputs.push_back({{x.mask(n), y.mask(n)}});
return {};
}
void PostSacriBin::exchange()
@@ -55,6 +51,8 @@ void PostSacriBin::exchange()
PostSacriBin::T PostSacriBin::finalize_mul(int n)
{
if (n == -1)
n = T::default_length;
auto res = honest.finalize_mul(n);
outputs.push_back({res, n});
return res;

View File

@@ -38,9 +38,8 @@ public:
PostSacriBin(Player& P);
~PostSacriBin();
void init_mul(Preprocessing<T>&, T::MC&);
void init_mul(SubProcessor<T>* proc);
T::clear prepare_mul(const T& x, const T& y, int n = -1);
void init_mul();
void prepare_mul(const T& x, const T& y, int n = -1);
void exchange();
T finalize_mul(int n = -1);

View File

@@ -3,6 +3,9 @@
*
*/
#ifndef GC_REPPREP_HPP_
#define GC_REPPREP_HPP_
#include "RepPrep.h"
#include "ShareThread.h"
#include "Processor/OnlineOptions.h"
@@ -98,3 +101,5 @@ void RepPrep<T>::buffer_inputs(int player)
}
} /* namespace GC */
#endif

View File

@@ -126,6 +126,9 @@ public:
template<class U>
static void convcbit2s(Processor<U>& processor, const BaseInstruction& instruction)
{ T::convcbit2s(processor, instruction); }
template<class U>
static void andm(Processor<U>& processor, const BaseInstruction& instruction)
{ T::andm(processor, instruction); }
Secret();
Secret(const Integer& x) { *this = x; }

View File

@@ -24,12 +24,15 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) :
void SemiPrep::set_protocol(Beaver<SemiSecret>& protocol)
{
if (triple_generator)
{
assert(&triple_generator->get_player() == &protocol.P);
return;
}
(void) protocol;
params.set_passive();
triple_generator = new SemiSecret::TripleGenerator(
BaseMachine::s().fresh_ot_setup(),
BaseMachine::fresh_ot_setup(protocol.P),
protocol.P.N, -1, OnlineOptions::singleton.batch_size,
1, params, {}, &protocol.P);
triple_generator->multi_threaded = false;
@@ -61,12 +64,4 @@ void SemiPrep::buffer_bits()
}
}
NamedCommStats SemiPrep::comm_stats()
{
if (triple_generator)
return triple_generator->comm_stats();
else
return {};
}
} /* namespace GC */

View File

@@ -44,6 +44,8 @@ public:
array<SemiSecret, 3> get_triple_no_count(int n_bits)
{
if (n_bits == -1)
n_bits = SemiSecret::default_length;
return ShiftableTripleBuffer<SemiSecret>::get_triple_no_count(n_bits);
}
@@ -51,8 +53,6 @@ public:
{
throw not_implemented();
}
NamedCommStats comm_stats();
};
} /* namespace GC */

View File

@@ -78,6 +78,8 @@ public:
template<class T>
static void convcbit2s(Processor<T>& processor, const BaseInstruction& instruction)
{ processor.convcbit2s(instruction); }
static void andm(Processor<U>& processor, const BaseInstruction& instruction)
{ processor.andm(instruction); }
static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); }

View File

@@ -47,7 +47,7 @@ void ShareSecret<U>::invert(int n, const U& x)
{
U ones;
ones.load_clear(64, -1);
static_cast<U&>(*this) = U(x ^ ones) & get_mask(n);
reinterpret_cast<U&>(*this) = U(x + ones) & get_mask(n);
}
template<class U>
@@ -92,8 +92,12 @@ template<class U>
void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
const vector<ClearWriteAccess>& accesses)
{
auto& thread = ShareThread<U>::s();
assert(thread.P);
assert(thread.MC);
for (auto access : accesses)
mem[access.address] = access.value;
mem[access.address] = U::constant(access.value, thread.P->my_num(),
thread.MC->get_alphai());
}
template<class U>
@@ -330,7 +334,7 @@ void ShareSecret<U>::random_bit()
template<class U>
U& GC::ShareSecret<U>::operator=(const U& other)
{
U& real_this = static_cast<U&>(*this);
U& real_this = reinterpret_cast<U&>(*this);
real_this = other;
return real_this;
}

View File

@@ -58,9 +58,6 @@ public:
void pre_run();
void post_run() { ShareThread<T>::post_run(); }
NamedCommStats comm_stats()
{ return Thread<T>::comm_stats() + this->DataF.comm_stats(); }
};
template<class T>

View File

@@ -63,6 +63,7 @@ void ShareThread<T>::pre_run(Player& P, typename T::mac_key_type mac_key)
protocol = new typename T::Protocol(*this->P);
MC = this->new_mc(mac_key);
DataF.set_protocol(*this->protocol);
this->protocol->init(DataF, *MC);
}
template<class T>
@@ -85,7 +86,7 @@ void ShareThread<T>::and_(Processor<T>& processor,
{
auto& protocol = this->protocol;
processor.check_args(args, 4);
protocol->init_mul(DataF, *this->MC);
protocol->init_mul();
T x_ext, y_ext;
for (size_t i = 0; i < args.size(); i += 4)
{

View File

@@ -55,8 +55,6 @@ public:
void join_tape();
void finish();
virtual NamedCommStats comm_stats();
};
template<class T>

View File

@@ -96,13 +96,6 @@ void Thread<T>::finish()
pthread_join(thread, 0);
}
template<class T>
NamedCommStats Thread<T>::comm_stats()
{
assert(P);
return P->comm_stats;
}
} /* namespace GC */

View File

@@ -95,11 +95,11 @@ void ThreadMaster<T>::run()
post_run();
NamedCommStats stats = P->comm_stats;
NamedCommStats stats = P->total_comm();
ExecutionStats exe_stats;
for (auto thread : threads)
{
stats += thread->P->comm_stats;
stats += thread->P->total_comm();
exe_stats += thread->processor.stats;
delete thread;
}

View File

@@ -44,8 +44,6 @@ public:
~TinierSharePrep();
void set_protocol(typename T::Protocol& protocol);
NamedCommStats comm_stats();
};
}

View File

@@ -8,7 +8,7 @@
#include "TinierSharePrep.h"
#include "PersonalPrep.hpp"
#include "PersonalPrep.h"
namespace GC
{
@@ -39,14 +39,17 @@ template<class T>
void TinierSharePrep<T>::set_protocol(typename T::Protocol& protocol)
{
if (triple_generator)
{
assert(&triple_generator->get_player() == &protocol.P);
return;
}
params.generateMACs = true;
params.amplify = false;
params.check = false;
auto& thread = ShareThread<typename T::whole_type>::s();
triple_generator = new typename T::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1,
BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1,
OnlineOptions::singleton.batch_size, 1,
params, thread.MC->get_alphai(), &protocol.P);
triple_generator->multi_threaded = false;
@@ -84,17 +87,6 @@ void GC::TinierSharePrep<T>::buffer_bits()
BufferPrep<T>::get_random_from_inputs(thread.P->num_players()));
}
template<class T>
NamedCommStats TinierSharePrep<T>::comm_stats()
{
NamedCommStats res;
if (triple_generator)
res += triple_generator->comm_stats();
if (real_triple_generator)
res += real_triple_generator->comm_stats();
return res;
}
}
#endif

View File

@@ -16,7 +16,7 @@ void TinierSharePrep<T>::init_real(Player& P)
assert(real_triple_generator == 0);
auto& thread = ShareThread<secret_type>::s();
real_triple_generator = new typename T::whole_type::TripleGenerator(
BaseMachine::s().fresh_ot_setup(), P.N, -1,
BaseMachine::fresh_ot_setup(P), P.N, -1,
OnlineOptions::singleton.batch_size, 1, params,
thread.MC->get_alphai(), &P);
real_triple_generator->multi_threaded = false;

View File

@@ -36,6 +36,8 @@ public:
void add_mine(const typename T::open_type& input, int n_bits)
{
if (n_bits == -1)
n_bits = T::default_length;
for (int i = 0; i < n_bits; i++)
part_input.add_mine(input.get_bit(i));
input_lengths.push_back(n_bits);
@@ -43,6 +45,8 @@ public:
void add_other(int player, int n_bits)
{
if (n_bits == -1)
n_bits = T::default_length;
for (int i = 0; i < n_bits; i++)
part_input.add_other(player);
}
@@ -69,6 +73,8 @@ public:
void finalize_other(int player, T& target, octetStream&, int n_bits)
{
if (n_bits == -1)
n_bits = T::default_length;
target.resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
part_input.finalize_other(player, target.get_reg(i),

View File

@@ -21,9 +21,10 @@ public:
VectorProtocol(Player& P);
void init_mul(SubProcessor<T>* proc);
void init_mul(Preprocessing<T>& prep, typename T::MAC_Check& MC);
typename T::clear prepare_mul(const T& x, const T& y, int n = -1);
void init(Preprocessing<T>& prep, typename T::MAC_Check& MC);
void init_mul();
void prepare_mul(const T& x, const T& y, int n = -1);
void exchange();
void finalize_mult(T& res, int n = -1);
T finalize_mul(int n = -1);

View File

@@ -18,26 +18,26 @@ VectorProtocol<T>::VectorProtocol(Player& P) :
}
template<class T>
void VectorProtocol<T>::init_mul(SubProcessor<T>* proc)
{
assert(proc);
init_mul(proc->DataF, proc->MC);
}
template<class T>
void VectorProtocol<T>::init_mul(Preprocessing<T>& prep,
void VectorProtocol<T>::init(Preprocessing<T>& prep,
typename T::MAC_Check& MC)
{
part_protocol.init_mul(prep.get_part(), MC.get_part_MC());
part_protocol.init(prep.get_part(), MC.get_part_MC());
}
template<class T>
typename T::clear VectorProtocol<T>::prepare_mul(const T& x,
void VectorProtocol<T>::init_mul()
{
part_protocol.init_mul();
}
template<class T>
void VectorProtocol<T>::prepare_mul(const T& x,
const T& y, int n)
{
if (n == -1)
n = T::default_length;
for (int i = 0; i < n; i++)
part_protocol.prepare_mul(x.get_reg(i), y.get_reg(i), 1);
return {};
}
template<class T>
@@ -57,6 +57,8 @@ T VectorProtocol<T>::finalize_mul(int n)
template<class T>
void VectorProtocol<T>::finalize_mult(T& res, int n)
{
if (n == -1)
n = T::default_length;
res.resize_regs(n);
for (int i = 0; i < n; i++)
res.get_reg(i) = part_protocol.finalize_mul(1);

View File

@@ -46,6 +46,7 @@
X(NOTCB, processor.notcb(INST)) \
X(ANDRS, T::andrs(PROC, EXTRA)) \
X(ANDS, T::ands(PROC, EXTRA)) \
X(ANDM, T::andm(PROC, instruction)) \
X(ADDCB, C0 = PC1 + PC2) \
X(ADDCBI, C0 = PC1 + int(IMM)) \
X(MULCBI, C0 = PC1 * int(IMM)) \
@@ -76,7 +77,6 @@
#define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \
X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \
X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \
X(ANDM, processor.andm(instruction)) \
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \

View File

@@ -1,5 +1,5 @@
CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License)
Copyright (c) 2021, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material.
Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

16
Machines/Atlas.hpp Normal file
View File

@@ -0,0 +1,16 @@
/*
* Atlas.hpp
*
*/
#ifndef MACHINES_ATLAS_HPP_
#define MACHINES_ATLAS_HPP_
#include "Protocols/AtlasShare.h"
#include "Protocols/AtlasPrep.h"
#include "GC/AtlasSecret.h"
#include "ShamirMachine.hpp"
#include "Protocols/Atlas.hpp"
#endif /* MACHINES_ATLAS_HPP_ */

View File

@@ -4,6 +4,7 @@
*/
#include "Protocols/MalRepRingPrep.h"
#include "Protocols/ReplicatedPrep2k.h"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"

17
Machines/Rep4.hpp Normal file
View File

@@ -0,0 +1,17 @@
/*
* Rep4.hpp
*
*/
#ifndef MACHINES_REP4_HPP_
#define MACHINES_REP4_HPP_
#include "GC/Rep4Secret.h"
#include "Protocols/Rep4Share2k.h"
#include "Protocols/Rep4Prep.h"
#include "Protocols/Rep4.hpp"
#include "Protocols/Rep4MC.hpp"
#include "Protocols/Rep4Input.hpp"
#include "Protocols/Rep4Prep.hpp"
#endif /* MACHINES_REP4_HPP_ */

View File

@@ -21,13 +21,15 @@
#include "GC/TinierSecret.h"
#include "GC/TinyMC.h"
#include "GC/VectorInput.h"
#include "GC/VectorProtocol.h"
#include "GC/ShareParty.hpp"
#include "GC/ShareParty.h"
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"
#include "GC/ShareSecret.h"
#include "GC/TinierSharePrep.h"
#include "GC/CcdPrep.h"
#include "GC/VectorProtocol.hpp"
#include "Math/gfp.hpp"

View File

@@ -23,9 +23,10 @@
#include "Protocols/MascotPrep.hpp"
#include "Protocols/Spdz2kPrep.hpp"
#include "GC/ShareParty.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/ShareParty.h"
#include "GC/ShareSecret.h"
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"
#include "GC/TinierSharePrep.h"
#include "GC/CcdPrep.h"
#include "GC/VectorProtocol.hpp"

View File

@@ -18,3 +18,4 @@
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/MalRepRingPrep.hpp"

15
Machines/Semi2k.hpp Normal file
View File

@@ -0,0 +1,15 @@
/*
* Semi2.hpp
*
*/
#ifndef MACHINES_SEMI2K_HPP_
#define MACHINES_SEMI2K_HPP_
#include "Protocols/Semi2kShare.h"
#include "Protocols/SemiPrep2k.h"
#include "Semi.hpp"
#include "Protocols/RepRingOnlyEdabitPrep.hpp"
#endif /* MACHINES_SEMI2K_HPP_ */

View File

@@ -27,6 +27,7 @@
#include "Protocols/Beaver.hpp"
#include "Protocols/Spdz2kPrep.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/VectorProtocol.hpp"
#include "GC/Secret.hpp"

23
Machines/Tinier.cpp Normal file
View File

@@ -0,0 +1,23 @@
/*
* Tinier.cpp
*
*/
#include "GC/TinyMC.h"
#include "GC/TinierSecret.h"
#include "GC/VectorInput.h"
#include "GC/ShareParty.hpp"
#include "GC/Secret.hpp"
#include "GC/TinyPrep.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/TinierSharePrep.hpp"
#include "GC/CcdPrep.hpp"
#include "GC/PersonalPrep.hpp"
//template class GC::ShareParty<GC::TinierSecret<gf2n_short>>;
template class GC::CcdPrep<GC::TinierSecret<gf2n_short>>;
template class Preprocessing<GC::TinierSecret<gf2n_short>>;
template class GC::TinierSharePrep<GC::TinierShare<gf2n_short>>;
template class GC::ShareSecret<GC::TinierSecret<gf2n_short>>;
template class TripleShuffleSacrifice<GC::TinierSecret<gf2n_short>>;

View File

@@ -3,12 +3,7 @@
*
*/
#include "Protocols/AtlasShare.h"
#include "Protocols/AtlasPrep.h"
#include "GC/AtlasSecret.h"
#include "ShamirMachine.hpp"
#include "Protocols/Atlas.hpp"
#include "Atlas.hpp"
int main(int argc, const char** argv)
{

View File

@@ -10,11 +10,13 @@
#include "Processor/RingOptions.h"
#include "Processor/Machine.hpp"
#include "Processor/OnlineOptions.hpp"
#include "Math/Z2k.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/ShuffleSacrifice.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/FakeShare.hpp"
#include "Protocols/MalRepRingPrep.hpp"
int main(int argc, const char** argv)
{
@@ -22,7 +24,7 @@ int main(int argc, const char** argv)
Names N;
ez::ezOptionParser opt;
RingOptions ring_opts(opt, argc, argv);
online_opts = {opt, argc, argv};
online_opts = {opt, argc, argv, FakeShare<SignedZ2<64>>()};
opt.parse(argc, argv);
opt.syntax = string(argv[0]) + " <progname>";
@@ -44,9 +46,7 @@ int main(int argc, const char** argv)
#ifdef ROUND_NEAREST_IN_EMULATION
cerr << "Using nearest rounding instead of probabilistic truncation" << endl;
#else
#ifdef RISKY_TRUNCATION_IN_EMULATION
cerr << "Using risky truncation" << endl;
#endif
online_opts.set_trunc_error(opt);
#endif
int R = ring_opts.ring_size_from_opts_or_schedule(progname);

View File

@@ -24,6 +24,7 @@
#include "Protocols/SemiMC.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/Hemi.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/SemiHonestRepPrep.h"
#include "Math/gfp.hpp"

View File

@@ -8,6 +8,7 @@
#include "Processor/OnlineMachine.hpp"
#include "Processor/Machine.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Math/gfp.hpp"
#include "Math/Z2k.hpp"

View File

@@ -22,6 +22,7 @@
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/SemiHonestRepPrep.h"
#include "Math/gfp.hpp"

View File

@@ -25,6 +25,8 @@ MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
VM = $(MINI_OT) $(SHAREDLIB)
COMMON = $(SHAREDLIB)
TINIER = Machines/Tinier.o $(OT)
SPDZ = Machines/SPDZ.o $(TINIER)
LIB = libSPDZ.a
@@ -117,7 +119,7 @@ sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x
ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) Fake-ECDSA.x
ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp))
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(OT) $(GC)
$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(TINIER) $(GC)
$(AR) -csr $@ $^
CFLAGS += -fPIC
@@ -203,16 +205,16 @@ ps-rep-bin-party.x: GC/PostSacriBin.o
semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
tiny-party.x: $(OT)
tinier-party.x: $(OT)
spdz2k-party.x: $(OT) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT)
lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
atlas-party.x: GC/AtlasSecret.o
static/hemi-party.x: $(FHEOBJS)
static/soho-party.x: $(FHEOBJS)
@@ -220,10 +222,10 @@ static/cowgear-party.x: $(FHEOBJS)
static/chaigear-party.x: $(FHEOBJS)
static/lowgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o
static/highgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
mascot-party.x: Machines/SPDZ.o $(OT)
static/mascot-party.x: Machines/SPDZ.o
Player-Online.x: Machines/SPDZ.o $(OT)
mama-party.x: $(OT)
mascot-party.x: $(SPDZ)
static/mascot-party.x: $(SPDZ)
Player-Online.x: $(SPDZ)
mama-party.x: $(TINIER)
ps-rep-ring-party.x: Protocols/MalRepRingOptions.o
malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o
sy-rep-ring-party.x: Protocols/MalRepRingOptions.o
@@ -236,8 +238,10 @@ emulate.x: GC/FakeSecret.o
semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(OT)
real-bmr-party.x: $(OT)
paper-example.x: $(VM) $(OT) $(FHEOFFLINE)
mascot-offline.x: $(VM) $(OT)
cowgear-offline.x: $(OT) $(FHEOFFLINE)
binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o
mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o Machines/Tinier.o
mascot-offline.x: $(VM) $(TINIER)
cowgear-offline.x: $(TINIER) $(FHEOFFLINE)
static/rep-bmr-party.x: $(BMR)
static/mal-rep-bmr-party.x: $(BMR)
static/shamir-bmr-party.x: $(BMR)

View File

@@ -26,6 +26,7 @@ public:
static const false_type invertible;
static const true_type characteristic_two;
static const true_type binary;
static char type_char() { return 'B'; }
static string type_short() { return "B"; }
@@ -64,8 +65,21 @@ public:
void pack(octetStream& os) const { os.store_int<sizeof(T)>(this->a); }
void unpack(octetStream& os) { this->a = os.get_int<sizeof(T)>(); }
void pack(octetStream& os, int n) const { os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); }
void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); }
void pack(octetStream& os, int n) const
{
if (n == -1)
pack(os);
else
os.store_int(super::mask(n).get(), DIV_CEIL(n, 8));
}
void unpack(octetStream& os, int n)
{
if (n == -1)
unpack(os);
else
this->a = os.get_int(DIV_CEIL(n, 8));
}
static BitVec_ unpack_new(octetStream& os, int n = n_bits)
{
@@ -81,5 +95,7 @@ template<class T>
const false_type BitVec_<T>::invertible;
template<class T>
const true_type BitVec_<T>::characteristic_two;
template<class T>
const true_type BitVec_<T>::binary;
#endif /* MATH_BITVEC_H_ */

View File

@@ -36,8 +36,9 @@ void read_setup(const string& dir_prefix, int lgp = -1)
{
if (lgp > 0)
{
cerr << "No modulus found in " << filename << ", generating " << lgp
<< "-bit prime" << endl;
if (OnlineOptions::singleton.verbose)
cerr << "No modulus found in " << filename << ", generating "
<< lgp << "-bit prime" << endl;
T::init_default(lgp);
}
else

View File

@@ -20,6 +20,7 @@ public:
static const false_type characteristic_two;
static const false_type prime_field;
static const false_type invertible;
static const false_type binary;
template<class T>
static void init(bool mont = true) { (void) mont; }

View File

@@ -47,6 +47,7 @@ public:
static int size_in_limbs() { return N_WORDS; }
static int size_in_bits() { return size() * 8; }
static int length() { return size_in_bits(); }
static int n_bits() { return N_BITS; }
static int t() { return 0; }
static char type_char() { return 'R'; }
@@ -100,6 +101,8 @@ public:
int bit_length() const;
Z2 mask(int) const { return *this; }
Z2<K> operator+(const Z2<K>& other) const;
Z2<K> operator-(const Z2<K>& other) const;

View File

@@ -86,6 +86,42 @@ void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y,int t
{ inline_mpn_copyi(z,ans+t,t); }
}
void Zp_Data::Mont_Mult_switch(mp_limb_t* z, const mp_limb_t* x,
const mp_limb_t* y) const
{
switch (t)
{
#ifdef __BMI2__
#define CASE(N) \
case N: \
Mont_Mult_<N>(z, x, y); \
break;
CASE(1)
CASE(2)
#if MAX_MOD_SZ >= 4
CASE(3)
CASE(4)
#endif
#if MAX_MOD_SZ >= 5
CASE(5)
#endif
#if MAX_MOD_SZ >= 6
CASE(6)
#endif
#if MAX_MOD_SZ >= 10
CASE(7)
CASE(8)
CASE(9)
CASE(10)
#endif
#undef CASE
#endif
default:
Mont_Mult_variable(z, x, y);
break;
}
}
ostream& operator<<(ostream& s,const Zp_Data& ZpD)

View File

@@ -40,6 +40,7 @@ class Zp_Data
template <int T>
void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const;
void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const
{ Mont_Mult(z, x, y, t); }
@@ -242,37 +243,11 @@ inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
{
if (not cpu_has_bmi2())
return Mont_Mult_variable(z, x, y);
switch (t)
{
#ifdef __BMI2__
#define CASE(N) \
case N: \
Mont_Mult_<N>(z, x, y); \
break;
CASE(1)
CASE(2)
#if MAX_MOD_SZ >= 4
CASE(3)
CASE(4)
return Mont_Mult_switch(z, x, y);
#else
return Mont_Mult_variable(z, x, y);
#endif
#if MAX_MOD_SZ >= 5
CASE(5)
#endif
#if MAX_MOD_SZ >= 6
CASE(6)
#endif
#if MAX_MOD_SZ >= 10
CASE(7)
CASE(8)
CASE(9)
CASE(10)
#endif
#undef CASE
#endif
default:
Mont_Mult_variable(z, x, y);
break;
}
}
inline void Zp_Data::Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x,

View File

@@ -11,7 +11,6 @@ using namespace std;
#include "Math/Bit.h"
#include "Math/Setup.h"
#include "Tools/random.h"
#include "GC/NoShare.h"
#include "Processor/OnlineOptions.h"
#include "Math/modp.hpp"
@@ -101,6 +100,7 @@ class gfp_ : public ValueInterface
static int size() { return t() * sizeof(mp_limb_t); }
static int size_in_bits() { return 8 * size(); }
static int length() { return ZpD.pr_bit_length; }
static int n_bits() { return length() - 1; }
static void reqbl(int n);

View File

@@ -5,6 +5,7 @@
#include "CryptoPlayer.h"
#include "Math/Setup.h"
#include "Tools/Bundle.h"
void check_ssl_file(string filename)
{
@@ -124,12 +125,14 @@ CryptoPlayer::~CryptoPlayer()
void CryptoPlayer::send_to_no_stats(int other, const octetStream& o) const
{
assert(other != my_num());
senders[other]->request(o);
senders[other]->wait(o);
}
void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const
{
assert(other != my_num());
receivers[other]->request(o);
receivers[other]->wait(o);
}
@@ -137,6 +140,7 @@ void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const
void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send,
octetStream& to_receive) const
{
assert(other != my_num());
if (&to_send == &to_receive)
{
MultiPlayer<ssl_socket*>::exchange_no_stats(other, to_send, to_receive);
@@ -153,6 +157,7 @@ void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send,
void CryptoPlayer::pass_around_no_stats(const octetStream& to_send,
octetStream& to_receive, int offset) const
{
assert(get_player(offset) != my_num());
if (&to_send == &to_receive)
{
MultiPlayer<ssl_socket*>::pass_around_no_stats(to_send, to_receive, offset);

View File

@@ -14,12 +14,14 @@
using namespace std;
void Names::init(int player,int pnb,int my_port,const char* servername)
void Names::init(int player, int pnb, int my_port, const char* servername,
bool setup_socket)
{
player_no=player;
portnum_base=pnb;
setup_names(servername, my_port);
setup_server();
if (setup_socket)
setup_server();
}
Names::Names(int player, int nplayers, const string& servername, int pnb,
@@ -124,7 +126,7 @@ void Names::setup_names(const char *servername, int my_port)
my_port = default_port(player_no);
int socket_num;
int pn = portnum_base - 1;
int pn = portnum_base;
set_up_client_socket(socket_num, servername, pn);
octetStream("P" + to_string(player_no)).Send(socket_num);
#ifdef DEBUG_NETWORKING
@@ -132,15 +134,11 @@ void Names::setup_names(const char *servername, int my_port)
#endif
// Send my name
octet my_name[512];
memset(my_name,0,512*sizeof(octet));
sockaddr_in address;
socklen_t size = sizeof address;
getsockname(socket_num, (sockaddr*)&address, &size);
char* name = inet_ntoa(address.sin_addr);
// max length of IP address with ending 0
strncpy((char*)my_name, name, 16);
send(socket_num,my_name,512);
char* my_name = inet_ntoa(address.sin_addr);
octetStream(my_name).Send(socket_num);
send(socket_num,(octet*)&my_port,4);
#ifdef DEBUG_NETWORKING
fprintf(stderr, "My Name = %s\n",my_name);
@@ -158,9 +156,10 @@ void Names::setup_names(const char *servername, int my_port)
names.resize(nplayers);
ports.resize(nplayers);
for (i=0; i<nplayers; i++)
{ octet tmp[512];
receive(socket_num,tmp,512);
names[i]=(char*)tmp;
{
octetStream os;
os.Receive(socket_num);
names[i] = os.str();
receive(socket_num, (octet*)&ports[i], 4);
#ifdef VERBOSE
cerr << "Player " << i << " is running on machine " << names[i] << endl;
@@ -176,6 +175,12 @@ void Names::setup_server()
server->init();
}
void Names::set_server(ServerSocket* socket)
{
assert(not server);
server = socket;
}
Names::Names(const Names& other)
{
@@ -201,6 +206,7 @@ Player::Player(const Names& Nms) :
{
nplayers=Nms.nplayers;
player_no=Nms.player_no;
thread_stats.resize(nplayers);
}
@@ -243,6 +249,10 @@ MultiPlayer<T>::~MultiPlayer()
Player::~Player()
{
#ifdef VERBOSE
for (auto& x : thread_stats)
x.print();
#endif
}
PlayerBase::~PlayerBase()
@@ -685,7 +695,7 @@ void VirtualTwoPartyPlayer::send(octetStream& o) const
{
TimeScope ts(comm_stats["Sending one-to-one"].add(o));
P.send_to_no_stats(other_player, o);
sent += o.get_length();
comm_stats.sent += o.get_length();
}
void RealTwoPartyPlayer::receive(octetStream& o) const
@@ -729,12 +739,13 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const
void VirtualTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
{
TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0]));
sent += o[0].get_length();
comm_stats.sent += o[0].get_length();
P.exchange_no_stats(other_player, o[0], o[1]);
}
VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) :
TwoPartyPlayer(P.my_num()), P(P), other_player(other_player)
TwoPartyPlayer(P.my_num()), P(P), other_player(other_player), comm_stats(
P.thread_stats.at(other_player))
{
}
@@ -814,5 +825,13 @@ void NamedCommStats::print(bool newline)
cerr << endl;
}
NamedCommStats Player::total_comm() const
{
auto res = comm_stats;
for (auto& x : thread_stats)
res += x;
return res;
}
template class MultiPlayer<int>;
template class MultiPlayer<ssl_socket*> ;

View File

@@ -35,6 +35,7 @@ class Names
friend class Player;
friend class PlainPlayer;
friend class RealTwoPartyPlayer;
friend class Server;
vector<string> names;
vector<int> ports;
@@ -51,6 +52,8 @@ class Names
void setup_server();
void set_server(ServerSocket* socket);
public:
static const int DEFAULT_PORT = -1;
@@ -62,8 +65,10 @@ class Names
* @param my_port my port number (`DEFAULT_PORT` for default,
* which is base port number plus player number)
* @param servername location of server
* @param setup_socket whether to start listening
*/
void init(int player,int pnb,int my_port,const char* servername);
void init(int player, int pnb, int my_port, const char* servername,
bool setup_socket = true);
Names(int player,int pnb,int my_port,const char* servername) : Names()
{ init(player,pnb,my_port,servername); }
@@ -172,11 +177,12 @@ class PlayerBase
protected:
int player_no;
public:
size_t& sent;
mutable Timer timer;
mutable NamedCommStats comm_stats;
public:
mutable Timer timer;
PlayerBase(int player_no) : player_no(player_no), sent(comm_stats.sent) {}
virtual ~PlayerBase();
@@ -205,6 +211,8 @@ protected:
public:
const Names& N;
mutable vector<NamedCommStats> thread_stats;
Player(const Names& Nms);
virtual ~Player();
@@ -358,6 +366,8 @@ public:
virtual void request_receive(int i, octetStream& o) const { (void)i; (void)o; }
virtual void wait_receive(int i, octetStream& o) const
{ receive_player(i, o); }
NamedCommStats total_comm() const;
};
/**
@@ -500,6 +510,7 @@ class VirtualTwoPartyPlayer : public TwoPartyPlayer
{
Player& P;
int other_player;
NamedCommStats& comm_stats;
public:
VirtualTwoPartyPlayer(Player& P, int other_player);

View File

@@ -51,9 +51,17 @@ void Receiver<T>::run()
while (in.pop(os))
{
os->reset_write_head();
#ifdef VERBOSE_SSL
timer.start();
RunningTimer mytimer;
#endif
os->Receive(socket);
#ifdef VERBOSE_SSL
cout << "receiving " << os->get_length() * 1e-6 << " MB on " << socket
<< " took " << mytimer.elapsed() << ", total "
<< timer.elapsed() << endl;
timer.stop();
#endif
out.push(os);
}
}

View File

@@ -47,9 +47,17 @@ void Sender<T>::run()
const octetStream* os = 0;
while (in.pop(os))
{
// timer.start();
#ifdef VERBOSE_SSL
timer.start();
RunningTimer mytimer;
#endif
os->Send(socket);
// timer.stop();
#ifdef VERBOSE_SSL
cout << "sending " << os->get_length() * 1e-6 << " MB on " << socket
<< " took " << mytimer.elapsed() << ", total "
<< timer.elapsed() << endl;
timer.stop();
#endif
out.push(os);
}
}

View File

@@ -28,9 +28,7 @@ void Server::get_ip(int num)
inet_ntop(AF_INET6, &s->sin6_addr, ipstr, sizeof ipstr);
}
names[num]=new octet[512];
memset(names[num], 0, 512);
strncpy((char*)names[num], ipstr, INET6_ADDRSTRLEN);
names[num] = ipstr;
#ifdef DEBUG_NETWORKING
cerr << "Client IP address: " << names[num] << endl;
@@ -45,11 +43,11 @@ void Server::get_name(int num)
#endif
// Receive name sent by client (legacy) - not used here
octet my_name[512];
receive(socket_num[num],my_name,512);
octetStream os;
os.Receive(socket_num[num]);
receive(socket_num[num],(octet*)&ports[num],4);
#ifdef DEBUG_NETWORKING
cerr << "Player " << num << " sent (IP for info only) " << my_name << ":"
cerr << "Player " << num << " sent (IP for info only) " << os.str() << ":"
<< ports[num] << endl;
#endif
@@ -66,7 +64,7 @@ void Server::send_names(int num)
send(socket_num[num],nmachines,4);
for (int i=0; i<nmachines; i++)
{
send(socket_num[num],names[i],512);
octetStream(names[i]).Send(socket_num[num]);
send(socket_num[num],(octet*)&ports[i],4);
}
}
@@ -88,13 +86,20 @@ Server::Server(int argc,char **argv)
}
nmachines=atoi(argv[1]);
PortnumBase=atoi(argv[2]);
server_socket = 0;
}
Server::Server(int nmachines, int PortnumBase) :
nmachines(nmachines), PortnumBase(PortnumBase)
nmachines(nmachines), PortnumBase(PortnumBase), server_socket(0)
{
}
Server::~Server()
{
if (server_socket)
delete server_socket;
}
void Server::start()
{
int i;
@@ -107,7 +112,8 @@ void Server::start()
for (i=0; i<nmachines; i++) { socket_num[i]=-1; }
// port number one lower to avoid conflict with players
ServerSocket server(PortnumBase - 1);
server_socket = new ServerSocket(PortnumBase);
auto& server = *server_socket;
server.init();
// set up connections
@@ -130,7 +136,7 @@ void Server::start()
bool all_on_local = true, none_on_local = true;
for (i = 1; i < nmachines; i++)
{
bool on_local = string((char*)names[i]).compare("127.0.0.1");
bool on_local = names[i].compare("127.0.0.1");
all_on_local &= on_local;
none_on_local &= not on_local;
}
@@ -144,9 +150,6 @@ void Server::start()
for (i=0; i<nmachines; i++)
send_names(i);
for (i=0; i<nmachines; i++)
{ delete[] names[i]; }
for (int i = 0; i < nmachines; i++)
close(socket_num[i]);
}
@@ -162,7 +165,7 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers,
{
#ifdef DEBUG_NETWORKING
cerr << "Starting networking for " << my_num << "/" << nplayers
<< " with server on " << hostname << ":" << (portnum - 1) << endl;
<< " with server on " << hostname << ":" << (portnum) << endl;
#endif
assert(my_num >= 0);
assert(my_num < nplayers);
@@ -172,12 +175,19 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers,
{
pthread_create(&thread, 0, Server::start_in_thread,
server = new Server(nplayers, portnum));
}
N.init(my_num, portnum, my_port, hostname.c_str());
if (my_num == 0)
{
N.init(my_num, portnum, my_port, hostname.c_str(), false);
pthread_join(thread, 0);
N.set_server(server->get_socket());
delete server;
}
else
N.init(my_num, portnum, my_port, hostname.c_str());
return 0;
}
ServerSocket* Server::get_socket()
{
auto res = server_socket;
server_socket = 0;
return res;
}

View File

@@ -14,10 +14,11 @@ using namespace std;
class Server
{
vector<int> socket_num;
vector<octet*> names;
vector<string> names;
vector<int> ports;
int nmachines;
int PortnumBase;
ServerSocket* server_socket;
void get_ip(int num);
void get_name(int num);
@@ -31,7 +32,11 @@ public:
Server(int argc, char** argv);
Server(int nmachines, int PortnumBase);
~Server();
void start();
ServerSocket* get_socket();
};
#endif /* NETWORKING_SERVER_H_ */

View File

@@ -7,6 +7,7 @@
#define CRYPTO_SSL_SOCKETS_H_
#include "Tools/int.h"
#include "Tools/time-func.h"
#include "sockets.h"
#include "Math/Setup.h"
@@ -46,6 +47,10 @@ public:
string me, bool client) :
parent(io_service, ctx)
{
#ifdef DEBUG_NETWORKING
cerr << me << " setting up SSL to " << other << " as " <<
(client ? "client" : "server") << endl;
#endif
lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket);
set_verify_mode(boost::asio::ssl::verify_peer);
set_verify_callback(boost::asio::ssl::rfc2818_verification(other));
@@ -82,8 +87,16 @@ template<>
inline void send(ssl_socket* socket, octet* data, size_t length)
{
size_t sent = 0;
#ifdef VERBOSE_SSL
RunningTimer timer;
#endif
while (sent < length)
{
sent += send_non_blocking(socket, data + sent, length - sent);
#ifdef VERBOSE_SSL
cout << "sent " << sent * 1e-6 << " MB at " << timer.elapsed() << endl;
#endif
}
}
template<>

View File

@@ -1,6 +1,7 @@
#include "OT/BaseOT.h"
#include "Tools/random.h"
#include "Tools/benchmarking.h"
#include "Tools/Bundle.h"
#include <stdio.h>
#include <iostream>
@@ -78,6 +79,23 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE rol
void BaseOT::exec_base(bool new_receiver_inputs)
{
Bundle<octetStream> bundle(*P);
#ifdef NO_AVX_OT
bundle.mine = string("OT without AVX");
#else
bundle.mine = string("OT with AVX");
#endif
try
{
bundle.compare(*P);
}
catch (mismatch_among_parties&)
{
cerr << "Parties compiled with different base OT algorithms" << endl;
cerr << "Set \"AVX_OT\" to the same value on all parties" << endl;
exit(1);
}
#ifdef NO_AVX_OT
#ifdef USE_RISTRETTO
typedef CurveElement Element;

View File

@@ -116,7 +116,7 @@ public:
mac_key_type get_mac_key() const { return mac_key; }
NamedCommStats comm_stats();
Player& get_player() { return globalPlayer; }
};
template<class T>
@@ -209,15 +209,4 @@ public:
void generateTriples();
};
template<class T>
NamedCommStats OTTripleGenerator<T>::comm_stats()
{
NamedCommStats res;
if (parentPlayer != &globalPlayer)
res = globalPlayer.comm_stats;
for (auto& player : players)
res += player->comm_stats;
return res;
}
#endif

View File

@@ -110,22 +110,31 @@ void BaseMachine::time()
void BaseMachine::start(int n)
{
cout << "Starting timer " << n << " at " << timer[n].elapsed()
<< " (" << timer[n].mb_sent() << " MB)"
<< " after " << timer[n].idle() << endl;
timer[n].start();
timer[n].start(total_comm());
}
void BaseMachine::stop(int n)
{
timer[n].stop();
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl;
timer[n].stop(total_comm());
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " ("
<< timer[n].mb_sent() << " MB)" << endl;
}
void BaseMachine::print_timers()
{
cerr << "The following timing is ";
if (OnlineOptions::singleton.live_prep)
cerr << "in";
else
cerr << "ex";
cerr << "clusive preprocessing." << endl;
cerr << "Time = " << timer[0].elapsed() << " seconds " << endl;
timer.erase(0);
for (map<int,Timer>::iterator it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl;
for (auto it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds ("
<< it->second.mb_sent() << " MB)" << endl;
}
string BaseMachine::memory_filename(const string& type_short, int my_number)
@@ -170,3 +179,18 @@ bigint BaseMachine::prime_from_schedule(string progname)
else
return 0;
}
NamedCommStats BaseMachine::total_comm()
{
NamedCommStats res;
for (auto& queue : queues)
res += queue->get_comm_stats();
return res;
}
void BaseMachine::set_thread_comm(const NamedCommStats& stats)
{
auto queue = queues.at(BaseMachine::thread_num);
assert(queue);
queue->set_comm_stats(stats);
}

View File

@@ -7,6 +7,7 @@
#define PROCESSOR_BASEMACHINE_H_
#include "Tools/time-func.h"
#include "Tools/TimerWithComm.h"
#include "OT/OTTripleSetup.h"
#include "ThreadJob.h"
#include "ThreadQueues.h"
@@ -22,7 +23,7 @@ class BaseMachine
protected:
static BaseMachine* singleton;
std::map<int,Timer> timer;
std::map<int,TimerWithComm> timer;
string compiler;
string domain;
@@ -66,12 +67,18 @@ public:
virtual void reqbl(int) {}
OTTripleSetup fresh_ot_setup();
static OTTripleSetup fresh_ot_setup(Player& P);
NamedCommStats total_comm();
void set_thread_comm(const NamedCommStats& stats);
};
inline OTTripleSetup BaseMachine::fresh_ot_setup()
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
{
return ot_setups.at(thread_num).get_fresh();
if (singleton and size_t(thread_num) < s().ot_setups.size())
return s().ot_setups.at(thread_num).get_fresh();
else
return OTTripleSetup(P, true);
}
#endif /* PROCESSOR_BASEMACHINE_H_ */

View File

@@ -38,7 +38,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
int size_in_bytes = T::size() * buffer.size();
int n_read = 0;
char * read_buffer = new char[size_in_bytes];
char read_buffer[size_in_bytes];
inf.seekg(start_posn * T::size());
do
{

View File

@@ -89,6 +89,7 @@ template<class sint, class sgf2n> class Processor;
template<class sint, class sgf2n> class Data_Files;
template<class sint, class sgf2n> class Machine;
template<class T> class SubProcessor;
template<class T> class NoFilePrep;
/**
* Abstract base class for preprocessing
@@ -125,6 +126,7 @@ public:
template<class U, class V>
static Preprocessing<T>* get_new(Machine<U, V>& machine, DataPositions& usage,
SubProcessor<T>* proc);
template<int = 0>
static Preprocessing<T>* get_new(bool live_prep, const Names& N,
DataPositions& usage);
static Preprocessing<T>* get_live_prep(SubProcessor<T>* proc,
@@ -133,22 +135,21 @@ public:
Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {}
virtual ~Preprocessing() {}
virtual void set_protocol(typename T::Protocol& protocol) = 0;
virtual void set_protocol(typename T::Protocol&) {};
virtual void set_proc(SubProcessor<T>* proc) { (void) proc; }
virtual void seekg(DataPositions& pos) { (void) pos; }
virtual void prune() {}
virtual void purge() {}
virtual size_t data_sent() { return comm_stats().sent; }
virtual NamedCommStats comm_stats() { return {}; }
virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0;
virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0;
virtual void get_one_no_count(Dtype dtype, T& a) = 0;
virtual void get_input_no_count(T& a, typename T::open_type& x, int i) = 0;
virtual void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs,
int vector_size) = 0;
virtual void get_three_no_count(Dtype, T&, T&, T&)
{ throw not_implemented(); }
virtual void get_two_no_count(Dtype, T&, T&) { throw not_implemented(); }
virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); }
virtual void get_input_no_count(T&, typename T::open_type&, int)
{ throw not_implemented() ; }
virtual void get_no_count(vector<T>&, DataTag, const vector<int>&, int)
{ throw not_implemented(); }
void get(Dtype dtype, T* a);
void get_three(Dtype dtype, T& a, T& b, T& c);
@@ -191,6 +192,9 @@ class Sub_Data_Files : public Preprocessing<T>
{
template<class U> friend class Sub_Data_Files;
typedef typename conditional<T::LivePrep::use_part,
Sub_Data_Files<typename T::part_type>, NoFilePrep<typename T::part_type>>::type part_type;
static int tuple_length(int dtype);
BufferOwner<T, T> buffers[N_DTYPE];
@@ -205,7 +209,7 @@ class Sub_Data_Files : public Preprocessing<T>
const string prep_data_dir;
int thread_num;
Sub_Data_Files<typename T::part_type>* part;
part_type* part;
void buffer_edabits_with_queues(bool strict, int n_bits)
{ buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
@@ -274,7 +278,7 @@ public:
void get_no_count(vector<T>& S, DataTag tag, const vector<int>& regs, int vector_size);
void get_dabit_no_count(T& a, typename T::bit_type& b);
Preprocessing<typename T::part_type>& get_part();
part_type& get_part();
};
template<class sint, class sgf2n>
@@ -307,8 +311,6 @@ class Data_Files
}
void reset_usage() { usage.reset(); skipped.reset(); }
NamedCommStats comm_stats();
};
template<class T> inline
@@ -418,6 +420,7 @@ T Preprocessing<T>::get_bit()
template<class T>
T Preprocessing<T>::get_random()
{
assert(not usage.inputs.empty());
return get_random_from_inputs(usage.inputs.size());
}
@@ -429,10 +432,4 @@ inline void Data_Files<sint, sgf2n>::purge()
DataFb.purge();
}
template<class sint, class sgf2n>
NamedCommStats Data_Files<sint, sgf2n>::comm_stats()
{
return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats();
}
#endif

View File

@@ -3,6 +3,7 @@
#include "Processor/Data_Files.h"
#include "Processor/Processor.h"
#include "Processor/NoFilePrep.h"
#include "Protocols/dabit.h"
#include "Math/Setup.h"
#include "GC/BitPrepFiles.h"
@@ -30,6 +31,7 @@ Preprocessing<T>* Preprocessing<T>::get_new(
}
template<class T>
template<int>
Preprocessing<T>* Preprocessing<T>::get_new(
bool live_prep, const Names& N,
DataPositions& usage)
@@ -156,17 +158,7 @@ Data_Files<sint, sgf2n>::Data_Files(const Names& N) :
template<class sint, class sgf2n>
Data_Files<sint, sgf2n>::~Data_Files()
{
#ifdef VERBOSE
if (DataFp.data_sent())
cerr << "Sent for " << sint::type_string() << " preprocessing threads: " <<
DataFp.data_sent() * 1e-6 << " MB" << endl;
#endif
delete &DataFp;
#ifdef VERBOSE
if (DataF2.data_sent())
cerr << "Sent for " << sgf2n::type_string() << " preprocessing threads: " <<
DataF2.data_sent() * 1e-6 << " MB" << endl;
#endif
delete &DataF2;
delete &DataFb;
}
@@ -264,6 +256,8 @@ void Sub_Data_Files<T>::purge()
for (auto it : extended)
it.second.purge();
dabit_buffer.purge();
if (part != 0)
part->purge();
}
template<class T>
@@ -329,10 +323,10 @@ void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
}
template<class T>
Preprocessing<typename T::part_type>& Sub_Data_Files<T>::get_part()
typename Sub_Data_Files<T>::part_type& Sub_Data_Files<T>::get_part()
{
if (part == 0)
part = new Sub_Data_Files<typename T::part_type>(my_num, num_players,
part = new part_type(my_num, num_players,
get_prep_sub_dir<typename T::part_type>(num_players), this->usage,
thread_num);
return *part;

Some files were not shown because too many files have changed in this diff Show More