Maintenance.

This commit is contained in:
Marcel Keller
2023-05-09 14:49:52 +10:00
parent c62ab2ca1e
commit 6cc3fccef0
135 changed files with 1658 additions and 1062 deletions

3
.gitmodules vendored
View File

@@ -1,9 +1,6 @@
[submodule "SimpleOT"]
path = deps/SimpleOT
url = https://github.com/mkskeller/SimpleOT
[submodule "mpir"]
path = deps/mpir
url = https://github.com/wbhart/mpir
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion

View File

@@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec<T>
bool one_shot;
size_t data_sent;
public:
static RealProgramParty& s();

View File

@@ -154,7 +154,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK);
MC->Check(*P);
data_sent = P->total_comm().sent;
if (online_opts.verbose)
P->total_comm().print();
@@ -216,7 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
delete prep;
delete garble_inputter;
delete garble_protocol;
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
garble_machine.print_comm(*this->P, this->P->total_comm());
T::MAC_Check::teardown();
}

View File

@@ -62,11 +62,13 @@ private:
#endif
};
#else
class BaseKeyVector : public vector<Key>
class BaseKeyVector : public CheckVector<Key>
{
typedef CheckVector<Key> super;
public:
BaseKeyVector(int size = 0) : vector<Key>(size, Key(0)) {}
void resize(int size) { vector<Key>::resize(size, Key(0)); }
BaseKeyVector(int size = 0) : super(size, Key(0)) {}
void resize(int size) { super::resize(size, Key(0)); }
};
#endif
@@ -296,7 +298,8 @@ public:
static void andm(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("andm not implemented"); }
static void run_tapes(const vector<int>&) { throw not_implemented(); }
static void run_tapes(const vector<int>&)
{ throw runtime_error("multi-threading not implemented"); }
// most BMR phases don't need actual input
template<class T>

View File

@@ -1,5 +1,20 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.6 (May 9, 2023)
- More extensive benchmarking outputs
- Replace MPIR by GMP
- Secure reading of edaBits from files
- Semi-honest client communication
- Back-propagation for average pooling
- Parallelized convolution
- Probabilistic truncation as in ABY3
- More balanced communication in Shamir secret sharing
- Avoid unnecessary communication in Dealer protocol
- Linear solver using Cholesky decomposition
- Accept .py files for compilation
- Fixed security bug: proper accounting for random elements
## 0.3.5 (Feb 16, 2023)
- Easier-to-use machine learning interface

21
CONFIG
View File

@@ -35,15 +35,32 @@ ARM := $(shell uname -m | grep x86; echo $$?)
OS := $(shell uname -s)
ifeq ($(MACHINE), x86_64)
ifeq ($(OS), Linux)
ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0)
AVX_OT = 1
else
AVX_OT = 0
endif
else
AVX_OT = 0
endif
else
ARCH =
AVX_OT = 0
endif
ifeq ($(OS), Darwin)
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
endif
ifeq ($(OS), Linux)
ifeq ($(ARM), 1)
ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0)
ARCH = -march=armv8.2-a+crypto
endif
endif
endif
USE_KOS = 0
# allow to set compiler in CONFIG.mine
@@ -66,7 +83,8 @@ endif
# Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
LDLIBS += $(BREW_LDLIBS)
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
LDLIBS += -lboost_system -lssl -lcrypto
@@ -88,6 +106,7 @@ BOOST = -lboost_thread $(MY_BOOST)
endif
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(BREW_CFLAGS)
CPPFLAGS = $(CFLAGS)
LD = $(CXX)

View File

@@ -17,8 +17,10 @@ import math
class SecretBitsAF(base.RegisterArgFormat):
reg_type = 'sb'
name = 'sbit'
class ClearBitsAF(base.RegisterArgFormat):
reg_type = 'cb'
name = 'cbit'
base.ArgFormats['sb'] = SecretBitsAF
base.ArgFormats['sbw'] = SecretBitsAF

View File

@@ -338,16 +338,19 @@ class Merger:
d[j] = d[i]
def read(reg, n):
last_read[reg] = n
for dup in reg.duplicates:
if last_def[dup] != -1:
if last_def[dup] not in (-1, n):
add_edge(last_def[dup], n)
last_read[reg] = n
def write(reg, n):
last_def[reg] = n
for dup in reg.duplicates:
if last_read[dup] not in (-1, n):
add_edge(last_read[dup], n)
if id(dup) in [id(x) for x in block.instructions[n].get_used()] and \
last_read[dup] not in (-1, n):
add_edge(last_read[dup], n)
last_def[reg] = n
def handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind):
@@ -434,13 +437,6 @@ class Merger:
# if options.debug:
# col = colordict[instr.__class__.__name__]
# G.add_node(n, color=col, label=str(instr))
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
for reg in outputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
@@ -448,6 +444,13 @@ class Merger:
else:
write(reg, n)
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
# will be merged
if isinstance(instr, TextInputInstruction):
keep_text_order(instr, n)
@@ -556,18 +559,6 @@ class Merger:
if unused_result:
eliminate(i)
count += 1
# remove unnecessary stack instructions
# left by optimization with budget
if isinstance(inst, popint_class) and \
(not G.degree(i) or (G.degree(i) == 1 and
isinstance(instructions[list(G[i])[0]], StackInstruction))) \
and \
inst.args[0].can_eliminate and \
len(G.pred[i]) == 1 and \
isinstance(instructions[list(G.pred[i])[0]], pushint_class):
eliminate(list(G.pred[i])[0])
eliminate(i)
count += 2
if count > 0 and self.block.parent.program.verbose:
print('Eliminated %d dead instructions, among which %d opens: %s' \
% (count, open_count, dict(stats)))

View File

@@ -50,6 +50,9 @@ def set_variant(options):
do_precomp = False
elif variant is not None:
raise CompilerError('Unknown comparison variant: %s' % variant)
if const_rounds and instructions_base.program.options.binary:
raise CompilerError(
'Comparison variant choice incompatible with binary circuits')
def ld2i(c, n):
""" Load immediate 2^n into clear GF(p) register c """

View File

@@ -22,6 +22,7 @@ class Compiler:
self.custom_args = custom_args
self.build_option_parser()
self.VARS = {}
self.root = os.path.dirname(__file__) + '/..'
def build_option_parser(self):
parser = OptionParser(usage=self.usage)
@@ -269,7 +270,7 @@ class Compiler:
self.prog = Program(self.args, self.options, name=name)
if self.execute:
if self.options.execute in \
("emulate", "ring", "rep-field", "semi2k"):
("emulate", "ring", "rep-field"):
self.prog.use_trunc_pr = True
if self.options.execute in ("ring",):
self.prog.use_split(3)
@@ -405,7 +406,7 @@ class Compiler:
infile = open(self.prog.infile)
# make compiler modules directly accessible
sys.path.insert(0, "Compiler")
sys.path.insert(0, "%s/Compiler" % self.root)
# create the tapes
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
@@ -477,15 +478,15 @@ class Compiler:
def local_execution(self, args=[]):
executable = self.executable_from_protocol(self.options.execute)
if not os.path.exists(executable):
if not os.path.exists("%s/%s" % (self.root, executable)):
print("Creating binary for virtual machine...")
try:
subprocess.run(["make", executable], check=True)
subprocess.run(["make", executable], check=True, cwd=self.root)
except:
raise CompilerError(
"Cannot produce %s. " % executable + \
"Note that compilation requires a few GB of RAM.")
vm = 'Scripts/%s.sh' % self.options.execute
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
os.execl(vm, vm, self.prog.name, *args)
def remote_execution(self, args=[]):
@@ -496,7 +497,7 @@ class Compiler:
from fabric import Connection
import subprocess
print("Creating static binary for virtual machine...")
subprocess.run(["make", "static/%s" % vm], check=True)
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
# transfer files
import glob
@@ -519,7 +520,7 @@ class Compiler:
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
dest)
# executable
connection.put("static/%s" % vm, dest)
connection.put("%s/static/%s" % (self.root, vm), dest)
# program
dest += "/"
connection.put("Programs/Schedules/%s.sch" % self.prog.name,

View File

@@ -289,7 +289,7 @@ def BitDecRingRaw(a, k, m):
def BitDecRing(a, k, m):
bits = BitDecRingRaw(a, k, m)
# reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
instructions_base.set_global_vector_size(a.size)
@@ -306,7 +306,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
def BitDecField(a, k, m, kappa, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
return [types.sint.conv(bit) for bit in res]
return [types.sintbit.conv(bit) for bit in res]
@instructions_base.ret_cisc

View File

@@ -356,7 +356,17 @@ class reqbl(base.Instruction):
code = base.opcodes['REQBL']
arg_format = ['int']
class active(base.Instruction):
""" Indicate whether program is compatible with malicious-security
protocols.
:param: 0 for no, 1 for yes
"""
code = base.opcodes['ACTIVE']
arg_format = ['int']
class time(base.IOInstruction):
""" Output time since start of computation. """
code = base.opcodes['TIME']
arg_format = []
@@ -2418,9 +2428,10 @@ class matmulsm(matmul_base):
super(matmulsm, self).add_usage(req_node)
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
class conv2ds(base.DataInstruction):
class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
""" Secret 2D convolution.
:param: number of arguments to follow (int)
:param: result (sint vector in row-first order)
:param: inputs (sint vector in row-first order)
:param: weights (sint vector in row-first order)
@@ -2436,10 +2447,12 @@ class conv2ds(base.DataInstruction):
:param: padding height (int)
:param: padding width (int)
:param: batch size (int)
:param: repeat from result...
"""
code = base.opcodes['CONV2DS']
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
'int','int','int','int']
arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
'int','int','int','int','int','int','int'])
data_type = 'triple'
is_vec = lambda self: True
@@ -2450,14 +2463,16 @@ class conv2ds(base.DataInstruction):
assert args[2].size == args[7] * args[8] * args[11]
def get_repeat(self):
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
self.args[11] * self.args[14]
args = self.args
return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
args[i+11] * args[i+14] for i in range(0, len(args), 15))
def add_usage(self, req_node):
super(conv2ds, self).add_usage(req_node)
args = self.args
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)
for i in range(0, len(self.args), 15):
args = self.args[i:i + 15]
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1)
@base.vectorize
class trunc_pr(base.VarArgsInstruction):

View File

@@ -66,6 +66,7 @@ opcodes = dict(
PLAYERID = 0xE4,
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -700,18 +701,23 @@ class RegisterArgFormat(ArgFormat):
class ClearModpAF(RegisterArgFormat):
reg_type = RegType.ClearModp
name = 'cint'
class SecretModpAF(RegisterArgFormat):
reg_type = RegType.SecretModp
name = 'sint'
class ClearGF2NAF(RegisterArgFormat):
reg_type = RegType.ClearGF2N
name = 'cgf2n'
class SecretGF2NAF(RegisterArgFormat):
reg_type = RegType.SecretGF2N
name = 'sgf2n'
class ClearIntAF(RegisterArgFormat):
reg_type = RegType.ClearInt
name = 'regint'
class IntArgFormat(ArgFormat):
n_bits = 32

View File

@@ -1226,7 +1226,7 @@ def while_loop(loop_body, condition, arg=None, g=None):
result = loop_body(arg)
if isinstance(result, MemValue):
result = result.read()
result.link(arg)
arg.update(result)
return condition(result)
if not isinstance(pre_condition, (bool,int)) or pre_condition:
if_statement(pre_condition, lambda: do_while(loop_fn, g=g))

View File

@@ -372,6 +372,7 @@ class Output(NoVariableLayer):
n = self.X.sizes[0]
if Y is None:
Y = self.Y
assert isinstance(Y, Array)
n_correct = MemValue(0)
n_printed = MemValue(0)
@for_range_opt(n)
@@ -1109,14 +1110,7 @@ class Square(ElementWiseLayer):
f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x)
prime_type = sfix
class MaxPool(NoVariableLayer):
""" Fixed-point MaxPool layer.
:param shape: input shape (tuple/list of four int)
:param strides: strides (tuple/list of four int, first and last must be 1)
:param ksize: kernel size (tuple/list of four int, first and last must be 1)
:param padding: :py:obj:`'VALID'` (default) or :py:obj:`'SAME'`
"""
class PoolBase(NoVariableLayer):
def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1),
padding='VALID'):
assert len(shape) == 4
@@ -1152,38 +1146,6 @@ class MaxPool(NoVariableLayer):
(type(self).__name__, self.X.sizes, self.strides,
self.ksize, self.padding)
def forward(self, batch=None, training=False):
if batch is None:
batch = Array.create_from(regint(0))
def process(pool, bi, k, i, j):
def m(a, b):
c = a[0] > b[0]
l = [c * x for x in a[1]]
l += [(1 - c) * x for x in b[1]]
return c.if_else(a[0], b[0]), l
red = util.tree_reduce(m, [(x[0], [1] if training else [])
for x in pool])
self.Y[bi][i][j][k] = red[0]
for ii, x in enumerate(red[1]):
self.comparisons[bi][k][i][j][ii] = x
self.traverse(batch, process)
def backward(self, compute_nabla_X=True, batch=None):
if compute_nabla_X:
self.nabla_X.alloc()
self.nabla_X.assign_all(0)
break_point()
def process(pool, bi, k, i, j):
for (x, h_in, w_in, h, w), c \
in zip(pool, self.comparisons[bi][k][i][j]):
hh = h * h_in
ww = w * w_in
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
get_program().protect_memory(True)
self.nabla_X[bi][hh][ww][k] += res
get_program().protect_memory(False)
self.traverse(batch, process)
def traverse(self, batch, process):
need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] >
self.X.sizes[i] for i in range(4)]
@@ -1221,6 +1183,47 @@ class MaxPool(NoVariableLayer):
h_in, w_in, h, w])
process(pool, bi, k, i, j)
class MaxPool(PoolBase):
""" Fixed-point MaxPool layer.
:param shape: input shape (tuple/list of four int)
:param strides: strides (tuple/list of four int, first and last must be 1)
:param ksize: kernel size (tuple/list of four int, first and last must be 1)
:param padding: :py:obj:`'VALID'` (default), :py:obj:`'SAME'`, integer, or
list/tuple of integers
"""
def forward(self, batch=None, training=False):
if batch is None:
batch = Array.create_from(regint(0))
def process(pool, bi, k, i, j):
def m(a, b):
c = a[0] > b[0]
l = [c * x for x in a[1]]
l += [(1 - c) * x for x in b[1]]
return c.if_else(a[0], b[0]), l
red = util.tree_reduce(m, [(x[0], [1] if training else [])
for x in pool])
self.Y[bi][i][j][k] = red[0]
for ii, x in enumerate(red[1]):
self.comparisons[bi][k][i][j][ii] = x
self.traverse(batch, process)
def backward(self, compute_nabla_X=True, batch=None):
if compute_nabla_X:
self.nabla_X.alloc()
self.nabla_X.assign_all(0)
break_point()
def process(pool, bi, k, i, j):
for (x, h_in, w_in, h, w), c \
in zip(pool, self.comparisons[bi][k][i][j]):
hh = h * h_in
ww = w * w_in
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
get_program().protect_memory(True)
self.nabla_X[bi][hh][ww][k] += res
get_program().protect_memory(False)
self.traverse(batch, process)
class Argmax(NoVariableLayer):
""" Fixed-point Argmax layer.
@@ -2058,6 +2061,12 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
or tuple/list of two int
"""
kernel_size, stride, padding = \
_standardize_pool_options(kernel_size, stride, padding)
return MaxPool(input_shape, [1] + list(stride) + [1],
[1] + list(kernel_size) + [1], padding)
def _standardize_pool_options(kernel_size, stride, padding):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(stride, int):
@@ -2066,8 +2075,7 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
stride = kernel_size
padding = padding.upper() if isinstance(padding, str) \
else padding
return MaxPool(input_shape, [1] + list(stride) + [1],
[1] + list(kernel_size) + [1], padding)
return kernel_size, stride, padding
class QuantAveragePool2d(QuantBase, AveragePool2d):
def input_params_from(self, player):
@@ -2075,14 +2083,47 @@ class QuantAveragePool2d(QuantBase, AveragePool2d):
for s in self.input_squant, self.output_squant:
s.get_params_from(player)
class FixAveragePool2d(FixBase, AveragePool2d):
class FixAveragePool2d(PoolBase, FixBase):
""" Fixed-point 2D AvgPool layer.
:param input_shape: input shape (tuple/list of four int)
:param output_shape: output shape (tuple/list of four int)
:param filter_size: filter size (tuple/list of two int)
:param strides: strides (tuple/list of two int)
"""
:param filter_size: filter size (int or tuple/list of two int)
:param strides: strides (int or tuple/list of two int)
:param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int,
or tuple/list of two int
"""
def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1),
padding=0):
filter_size, strides, padding = \
_standardize_pool_options(filter_size, strides, padding)
PoolBase.__init__(self, input_shape, [1] + list(strides) + [1],
[1] + list(filter_size) + [1], padding)
self.pool_size = reduce(operator.mul, filter_size)
if output_shape:
assert self.Y.shape == list(output_shape)
def _forward(self, batch):
def process(pool, bi, k, i, j):
self.Y[bi][i][j][k] = sum(x[0] for x in pool) * (1 / self.pool_size)
self.traverse(batch, process)
def backward(self, compute_nabla_X=True, batch=None):
if compute_nabla_X:
self.nabla_X.alloc()
self.nabla_X.assign_all(0)
break_point()
def process(pool, bi, k, i, j):
part = self.nabla_Y[bi][i][j][k] * (1 / self.pool_size)
for x, h_in, w_in, h, w in pool:
hh = h * h_in
ww = w * w_in
res = h_in * w_in * part
get_program().protect_memory(True)
self.nabla_X[bi][hh][ww][k] += res
get_program().protect_memory(False)
self.traverse(batch, process)
class QuantReshape(QuantBase, BaseLayer):
def __init__(self, input_shape, _, output_shape):
@@ -2265,6 +2306,8 @@ class Optimizer:
:param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample)
:param top: return top prediction instead of probability distribution
:returns: sfix/sint Array (depening on :py:obj:`top`)
"""
if isinstance(self.layers[-1].Y, Array) or top:
if top:
@@ -2540,6 +2583,8 @@ class Optimizer:
@_no_mem_warnings
def run_by_args(self, program, n_runs, batch_size, test_X, test_Y,
acc_batch_size=None, reset=True):
MultiArray.disable_index_checks()
Array.check_indices = False
if acc_batch_size is None:
acc_batch_size = batch_size
depreciation = None
@@ -2943,6 +2988,10 @@ class keras:
return 'maxpool', {'pool_size': pool_size, 'strides': strides,
'padding': padding}
def AveragePooling2D(pool_size=2, strides=None, padding='valid'):
return 'avgpool', {'filter_size': pool_size, 'strides': strides,
'padding': padding}
def Dropout(rate):
l = math.log(rate, 2)
if int(l) != l:
@@ -3014,9 +3063,12 @@ class keras:
n_units = reduce(operator.mul,
layers[-1].Y.sizes[1:])
if i == len(self.layers) - 1:
if layer[2].get('activation', 'softmax') in \
('softmax', 'sigmoid'):
activation = layer[2].get('activation', None)
if activation in ('softmax', 'sigmoid'):
layer[2].pop('activation', None)
if activation == 'softmax' and layer[1][0] == 1:
raise CompilerError(
'softmax requires more than one output neuron')
layers.append(Dense(N, n_units, layer[1][0],
**layer[2]))
input_shape = layers[-1].Y.sizes
@@ -3041,6 +3093,9 @@ class keras:
layers.append(easyMaxPool(input_shape, pool_size,
strides, padding))
input_shape = layers[-1].Y.sizes
elif name == 'avgpool':
layers.append(FixAveragePool2d(input_shape, None, **layer[1]))
input_shape = layers[-1].Y.sizes
elif name == 'dropout':
layers.append(Dropout(batch_size, reduce(
operator.mul, layers[-1].Y.sizes[1:]),
@@ -3192,6 +3247,10 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None):
layers.append(easyMaxPool(input_shape, item.kernel_size,
item.stride, item.padding))
input_shape = layers[-1].Y.shape
elif name == 'AvgPool2d':
layers.append(FixAveragePool2d(input_shape, None, item.kernel_size,
item.stride, item.padding))
input_shape = layers[-1].Y.shape
elif name == 'ReLU':
layers.append(Relu(input_shape))
elif name == 'Flatten':
@@ -3295,7 +3354,7 @@ class SGDLogistic(OneLayerSGD):
return super(SGDLogistic, self).predict(X)
class SGDLinear(OneLayerSGD):
""" Logistic regression using SGD.
""" Linear regression using SGD.
:param n_epochs: number of epochs
:param batch_size: batch size
@@ -3415,11 +3474,16 @@ def var(x):
return res.read()
def cholesky(A, reveal_diagonal=False):
""" Cholesky decomposition. """
""" Cholesky decomposition.
:returns: lower triangular matrix
"""
assert len(A.shape) == 2
assert A.shape[0] == A.shape[1]
L = A.same_shape()
L.assign_all(0)
diag_inv = A.value_type.Array(A.shape[0])
@for_range(A.shape[0])
def _(i):
@for_range(i + 1)
@@ -3429,10 +3493,47 @@ def cholesky(A, reveal_diagonal=False):
@if_e(i == j)
def _():
L[i][j] = mpc_math.sqrt(A[i][i] - sum)
diag_inv[i] = 1 / L[i][j]
if reveal_diagonal:
print_ln('L[%s][%s] = %s = sqrt(%s - %s)', i, j,
L[i][j].reveal(), A[i][j].reveal(), sum.reveal())
@else_
def _():
L[i][j] = (1.0 / L[j][j] * (A[i][j] - sum))
L[i][j] = (diag_inv[j] * (A[i][j] - sum))
return L
def solve_lower(A, b):
""" Linear solver where :py:obj:`A` is lower triangular quadratic. """
assert len(A.shape) == 2
assert A.shape[0] == A.shape[1]
assert len(b) == A.shape[0]
b = Array.create_from(b)
res = sfix.Array(len(b))
@for_range(len(b))
def _(i):
res[i] = b[i] / A[i][i]
b[:] -= res[i] * A.get_column(i)
return res
def solve_upper(A, b):
""" Linear solver where :py:obj:`A` is upper triangular quadratic. """
assert len(A.shape) == 2
assert A.shape[0] == A.shape[1]
assert len(b) == A.shape[0]
b = Array.create_from(b)
res = sfix.Array(len(b))
@for_range(len(b) - 1, -1, -1)
def _(i):
res[i] = b[i] / A[i][i]
b[:] -= res[i] * A.get_column(i)
return res
def solve_cholesky(A, b, debug=False):
""" Linear solver using Cholesky decomposition. """
L = cholesky(A, reveal_diagonal=debug)
if debug:
Optimizer.stat('L', L)
x = solve_lower(L, b)
if debug:
Optimizer.stat('intermediate', x)
return solve_upper(L.transpose(), x)

View File

@@ -661,7 +661,7 @@ def sqrt_simplified_fx(x):
h = h * r
H = 4 * (h * h)
if not x.round_nearest or (2 * f < k - 1):
if not x.round_nearest or (2 * x.f < x.k - 1):
H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H)
H = H * x
@@ -806,9 +806,7 @@ def sqrt_fx(x_l, k, f):
@instructions_base.sfix_cisc
def sqrt(x, k=None, f=None):
"""
Returns the square root (sfix) of any given fractional
value as long as it can be rounded to a integral value
with :py:obj:`f` bits of decimal precision.
Square root.
:param x: fractional input (sfix).

View File

@@ -186,6 +186,8 @@ class Program(object):
self.input_files = {}
self.base_addresses = {}
self._protect_memory = False
self._always_active = True
self.active = True
if not self.options.cisc:
self.options.cisc = not self.options.optimize_hard
@@ -207,16 +209,14 @@ class Program(object):
return self.n_threads
def init_names(self, args):
# ignore path to file - source must be in Programs/Source
if "Programs" in os.listdir(os.getcwd()):
# compile prog in ./Programs/Source directory
self.programs_dir = "Programs"
else:
# assume source is in main SPDZ directory
self.programs_dir = sys.path[0] + "/Programs"
self.programs_dir = "Programs"
if self.verbose:
print("Compiling program in", self.programs_dir)
for dirname in (self.programs_dir, "Player-Data"):
if not os.path.exists(dirname):
os.mkdir(dirname)
# create extra directories if needed
for dirname in ["Public-Input", "Bytecode", "Schedules"]:
if not os.path.exists(self.programs_dir + "/" + dirname):
@@ -224,13 +224,29 @@ class Program(object):
if self.name is None:
self.name = args[0].split("/")[-1]
if self.name.endswith(".mpc"):
self.name = self.name[:-4]
exts = ".mpc", ".py"
for ext in exts:
if self.name.endswith(ext):
self.name = self.name[:-len(ext)]
if os.path.exists(args[0]):
self.infile = args[0]
else:
self.infile = self.programs_dir + "/Source/" + self.name + ".mpc"
infiles = []
for x in (self.programs_dir, sys.path[0] + "/Programs"):
for ext in exts:
filename = args[0]
if not filename.endswith(ext):
filename += ext
infiles += [x + "/Source/" + filename]
for f in infiles:
if os.path.exists(f):
self.infile = f
break
else:
raise CompilerError(
"found none of the potential input files: " +
", ".join("'%s'" % x for x in [args[0]] + infiles))
"""
self.name is input file name (minus extension) + any optional arguments.
Used to generate output filenames
@@ -479,6 +495,9 @@ class Program(object):
# finalize the memory
self.finalize_memory()
# communicate protocol compability
Compiler.instructions.active(self._always_active)
self.write_bytes()
if self.options.asmoutfile:
@@ -672,6 +691,19 @@ class Program(object):
logp = int(round(math.log(p, 2)))
return abs(p - 2 ** logp) / p < 2 ** -self.security
@property
def active(self):
""" Whether to use actively secure protocols. """
return self._active
@active.setter
def active(self, change):
self._always_active &= change
self._active = change
def semi_honest(self):
self._always_active = False
@staticmethod
def read_tapes(schedule):
m = re.search(r"([^/]*)\.mpc", schedule)
@@ -1454,6 +1486,9 @@ class Tape:
return Tape.Register(self.reg_type, Program.prog.curr_tape)
def link(self, other):
if Program.prog.options.noreallocate:
raise CompilerError("reallocation necessary for linking, "
"remove option -u")
self.duplicates |= other.duplicates
for dup in self.duplicates:
dup.duplicates = self.duplicates
@@ -1466,12 +1501,15 @@ class Tape:
:param other: any convertible type
"""
other = type(self)(other)
if isinstance(other, Tape.Register) and other.block != Program.prog.curr_block:
other = type(self)(other)
else:
other = self.conv(other)
if Program.prog.curr_block in [x.block for x in self.duplicates]:
self.program.start_new_basicblock()
if self.program != other.program:
raise CompilerError(
'cannot update register with one from another thread')
if other.block in [x.block for x in self.duplicates]:
self.program.start_new_basicblock()
self.link(other)
@property

View File

@@ -659,6 +659,7 @@ class _secret_structure(_structure):
traverse(x, level + 1)
traverse(content, 0)
f.write('\n')
f.flush()
if requested_shape is not None and \
list(shape) != list(requested_shape):
raise CompilerError('content contradicts shape')
@@ -2415,26 +2416,34 @@ class sint(_secret, _int):
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of values input by a client.
This uses the triple-based input protocol introduced by
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_ unless
:py:obj:`program.active` is set to false, in which case
it uses random values to mask the clients' input.
:param n: number of inputs (int)
:param client_id: regint
:param size: vector size (default 1)
:returns: list of sint
"""
# send shares of a triple to client
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
if program.active:
# send shares of a triple to client
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
else:
triples = [sint.get_random() for i in range(n)]
sint.write_shares_to_socket(client_id, triples, message_type)
received = util.tuplify(cint.read_from_socket(client_id, n))
y = [0] * n
for i in range(n):
y[i] = received[i] - triples[i * 3]
y[i] = received[i] - triples[i * 3 if program.active else i]
return y
@classmethod
def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients.
Uses :py:obj:`program.active` to determine whether to use
triples for active security.
:param clients: client ids (list or array)
:param values: list of sint to reveal
@@ -2445,8 +2454,11 @@ class sint(_secret, _int):
for value in values:
assert(value.size == values[0].size)
r = sint.get_random()
to_send += [value, r, value * r]
if program.active:
r = sint.get_random()
to_send += [value, r, value * r]
else:
to_send += [value]
if isinstance(clients, Array):
n_clients = clients.length
@@ -2844,7 +2856,7 @@ class sint(_secret, _int):
privateoutput(self.size, player, res._v, self)
return res
def private_division(self, divisor, active=True, dividend_length=None,
def private_division(self, divisor, active=None, dividend_length=None,
divisor_length=None):
""" Private integer division as per `Veugen and Abspoel
<https://doi.org/10.2478/popets-2021-0073>`_
@@ -2878,6 +2890,9 @@ class sint(_secret, _int):
z_shared = ((self << (l + sigma)) + h + r_pprime)
z = z_shared.reveal_to(0)
if active is None:
active = program.active
if active:
z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)]
check = [(x * (1 - x)).reveal() == 0 for x in z_prime]
@@ -2893,6 +2908,7 @@ class sint(_secret, _int):
y_prime = sint.bit_compose(z_prime[:l + sigma])
y = sint.bit_compose(z_prime[l + sigma:])
else:
program.semi_honest()
y = sint(z // (d << (l + sigma)))
y_prime = sint((z // d) % (2 ** (l + sigma)))
@@ -3147,7 +3163,9 @@ class sgf2n(_secret, _gf2n):
for i in range(0, bit_length, step)]
one = cgf2n(1)
masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal()
masked = sum([b * (one << (i * step))
for i,b in enumerate(random_bits)], self).reveal(
check=False)
masked_bits = masked.bit_decompose(bit_length,step=step)
return [m + r for m,r in zip(masked_bits, random_bits)]
@@ -3157,7 +3175,9 @@ class sgf2n(_secret, _gf2n):
for i in range(8)]
one = cgf2n(1)
wanted_positions = [0, 5, 10, 15, 20, 25, 30, 35]
masked = sum([b * (one << wanted_positions[i]) for i,b in enumerate(random_bits)], self).reveal()
masked = sum([b * (one << wanted_positions[i])
for i,b in enumerate(random_bits)], self).reveal(
check=False)
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
for t in (sint, sgf2n):
@@ -4080,7 +4100,8 @@ class _single(_number, _secret_structure):
@vectorized_classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
"""
Securely obtain shares of values input by a client. Assumes client
Securely obtain shares of values input by a client via
:py:func:`sint.receive_from_client`. Assumes client
has already converted values to integer representation.
:param n: number of inputs (int)
@@ -4095,7 +4116,7 @@ class _single(_number, _secret_structure):
@classmethod
def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients.
""" Reveal securely to clients via :py:func:`sint.reveal_to_clients`.
:param clients: client ids (list or array)
:param values: list of values of this class
@@ -4556,7 +4577,7 @@ class sfix(_fix):
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
returning :py:class:`sbitint`. The other operand can be any of
sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()``
and ``**``, the latter for integer exponents.
and ``**``.
Note that the default precision (16 bits after the dot, 31 bits in
total) only allows numbers up to :math:`2^{31-16-1} \\approx
@@ -4669,6 +4690,8 @@ class sfix(_fix):
return self.v
def mul_no_reduce(self, other, res_params=None):
if not isinstance(other, type(self)):
return self * other
assert self.f == other.f
assert self.k == other.k
return self.unreduced(self.v * other.v)
@@ -4734,6 +4757,11 @@ class unreduced_sfix(_single):
nearest=sfix.round_nearest, signed=True)
return sfix._new(v, k=self.k - self.m, f=self.m)
def update(self, other):
assert self.k == other.k
assert self.m == other.m
self.v.update(other.v)
sfix.unreduced_type = unreduced_sfix
sfix.set_precision(16, 31)
@@ -4953,6 +4981,8 @@ class sfloat(_number, _secret_structure):
This uses integer operations internally, see :py:class:`sint` for security
considerations.
See `Aliasgari et al. <https://eprint.iacr.org/2012/405.pdf>`_ for
details.
The type supports basic arithmetic (``+, -, *, /``), returning
:py:class:`sfloat`, and comparisons (``==, !=, <, <=, >, >=``),
@@ -5459,6 +5489,9 @@ class Array(_vectorizable):
b.input_from(1)
a[:] += b[:]
Arrays aren't initialized on creation, you need to call
:py:func:`assign_all` to initialize them to a constant value.
"""
check_indices = True
@@ -5708,7 +5741,7 @@ class Array(_vectorizable):
mem_value = MemValue(value)
self.address = MemValue.if_necessary(self.address)
n_threads = 8 if use_threads and util.is_constant(self.length) and \
len(self) > 2**20 else None
len(self) > 2**20 and not program.options.garbled else None
@library.multithread(n_threads, self.length)
def _(base, size):
if use_vector:
@@ -5896,7 +5929,7 @@ class Array(_vectorizable):
self.assign_vector(self.get_vector().secure_shuffle())
def secure_permute(self, *args, **kwargs):
""" Secure permutate in place according to the security model.
""" Secure permute in place according to the security model.
See :py:func:`MultiArray.secure_shuffle` for references.
:param permutation: output of :py:func:`sint.get_secure_shuffle()`
@@ -6227,7 +6260,10 @@ class SubMultiArray(_vectorizable):
def same_shape(self):
""" :return: new multidimensional array with same shape and basic type """
return MultiArray(self.sizes, self.value_type)
if len(self.sizes) == 2:
return Matrix(*self.sizes, self.value_type)
else:
return MultiArray(self.sizes, self.value_type)
def get_part(self, start, size):
""" Part multi-array.
@@ -6400,7 +6436,7 @@ class SubMultiArray(_vectorizable):
pass
t.params = res_params
else:
if issubclass(self.value_type, _secret_structure):
if self.value_type == other.value_type:
t = self.value_type
else:
t = type(self.value_type(0) * other.value_type(0))
@@ -6435,10 +6471,12 @@ class SubMultiArray(_vectorizable):
# fallback for binary circuits
@library.for_range_opt(other.sizes[1])
def _(j):
res_matrix[i][j] = 0
@library.for_range_opt(self.sizes[1])
tmp = self[i][0].mul_no_reduce(other[0][j])
@library.for_range_opt(1, self.sizes[1])
def _(k):
res_matrix[i][j] += self[i][k] * other[k][j]
prod = self[i][k].mul_no_reduce(other[k][j])
tmp.iadd(prod)
res_matrix[i][j] = tmp.reduce_after_mul()
return res_matrix
elif isinstance(other, self.value_type):
return self * Array.create_from(other)
@@ -6780,6 +6818,9 @@ class MultiArray(SubMultiArray):
a[1].input_from(1)
a[2][:] = a[0][:] * a[1][:]
Arrays aren't initialized on creation, you need to call
:py:func:`assign_all` to initialize them to a constant value.
"""
@staticmethod
def disable_index_checks():
@@ -6817,6 +6858,9 @@ class Matrix(MultiArray):
:param columns: compile-time (int)
:param value_type: basic type of entries
Matrices aren't initialized on creation, you need to call
:py:func:`assign_all` to initialize them to a constant value.
"""
def __init__(self, rows, columns, value_type, debug=None, address=None):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \

View File

@@ -47,23 +47,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libboost-dev \
libboost-thread-dev \
libclang-dev \
libgmp-dev \
libntl-dev \
libsodium-dev \
libssl-dev \
libtool \
m4 \
texinfo \
yasm \
vim \
gdb \
valgrind \
&& rm -rf /var/lib/apt/lists/*
# mpir
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/include/* /usr/local/include/
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/lib/* /usr/local/lib/
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/share/info/* /usr/local/share/info/
ENV MP_SPDZ_HOME /usr/src/MP-SPDZ
WORKDIR $MP_SPDZ_HOME

View File

@@ -46,6 +46,7 @@ void Client::send_private_inputs(const vector<T>& values)
octetStream os;
vector< vector<T> > triples(num_inputs, vector<T>(3));
vector<T> triple_shares(3);
bool active = true;
// Receive num_inputs triples from SPDZ
for (size_t j = 0; j < sockets.size(); j++)
@@ -61,9 +62,21 @@ void Client::send_private_inputs(const vector<T>& values)
cerr << "received " << os.get_length() << " from " << j << endl << flush;
#endif
if (j == 0)
{
if (os.get_length() == 3 * values.size() * T::size())
active = true;
else
active = false;
}
int n_expected = active ? 3 : 1;
if (os.get_length() != n_expected * T::size() * values.size())
throw runtime_error("unexpected data length in sending");
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
for (int k = 0; k < n_expected; k++)
{
triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k];
@@ -71,16 +84,18 @@ void Client::send_private_inputs(const vector<T>& values)
}
}
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
if (active)
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
cerr << "Incorrect triple at " << i << ", aborting\n";
throw mac_fail();
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
{
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
cerr << "Incorrect triple at " << i << ", aborting\n";
throw mac_fail();
}
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
@@ -100,6 +115,7 @@ vector<U> Client::receive_outputs(int n)
{
vector<T> triples(3 * n);
octetStream os;
bool active = true;
for (auto& socket : sockets)
{
os.reset_write_head();
@@ -107,7 +123,20 @@ vector<U> Client::receive_outputs(int n)
#ifdef VERBOSE_COMM
cout << "received " << os.get_length() << endl << flush;
#endif
for (int j = 0; j < 3 * n; j++)
if (socket == sockets[0])
{
if (os.get_length() == (size_t) 3 * n * T::size())
active = true;
else
active = false;
}
int n_expected = n * (active ? 3 : 1);
if (os.get_length() != (size_t) n_expected * T::size())
throw runtime_error("unexpected data length in receiving");
for (int j = 0; j < n_expected; j++)
{
T value;
value.unpack(os);
@@ -115,16 +144,24 @@ vector<U> Client::receive_outputs(int n)
}
}
vector<U> output_values;
for (int i = 0; i < 3 * n; i += 3)
if (active)
{
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
vector<U> output_values;
for (int i = 0; i < 3 * n; i += 3)
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
throw mac_fail();
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
throw mac_fail();
}
output_values.push_back(triples[i]);
}
output_values.push_back(triples[i]);
}
return output_values;
return output_values;
}
else
{
triples.resize(n);
return triples;
}
}

View File

@@ -34,17 +34,25 @@ class Client:
os = octetStream()
for socket in self.sockets:
os.Receive(socket)
if socket == self.sockets[0]:
active = os.get_length() == 3 * n * T.size()
n_expected = 3 if active else 1
if os.get_length() != n_expected * T.size() * n:
import sys
print (os.get_length(), n_expected, T.size(), n, active, file=sys.stderr)
raise Exception('unexpected data length')
for triple in triples:
for i in range(3):
for i in range(n_expected):
t = T()
t.unpack(os)
triple[i] += t
res = []
for triple in triples:
prod = triple[0] * triple[1]
if prod != triple[2]:
raise Exception(
'invalid triple, diff %s' % hex(prod.v - triple[2].v))
if active:
for triple in triples:
prod = triple[0] * triple[1]
if prod != triple[2]:
raise Exception(
'invalid triple, diff %s' % hex(prod.v - triple[2].v))
return triples
def send_private_inputs(self, values):
@@ -68,6 +76,9 @@ class octetStream:
if value is not None:
self.buf += value
def get_length(self):
return len(self.buf)
def reset_write_head(self):
self.buf = b''
self.ptr = 0

View File

@@ -27,6 +27,10 @@ class Domain:
def __neq__(self, other):
return self.v != other.v
@classmethod
def size(cls):
return cls.n_bytes
def unpack(self, os):
self.v = 0
buf = os.consume(self.n_bytes)

View File

@@ -1,5 +1,4 @@
#include "Ciphertext.h"
#include "PPData.h"
#include "P2Data.h"
#include "Tools/Exceptions.h"
@@ -143,6 +142,5 @@ void Ciphertext::rerandomize(const FHE_PK& pk)
template void mul(Ciphertext& ans,const Plaintext<gfp,FFT_Data,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans,const Plaintext<gfp,PPData,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans, const Plaintext<gf2n_short, P2Data, int>& a,
const Ciphertext& c);

View File

@@ -259,6 +259,3 @@ void BFFT(vector<modp>& ans,const vector<modp>& a,const FFT_Data& FFTD,bool forw
else
{ throw crash_requested(); }
}

View File

@@ -83,6 +83,8 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
for (int r=0; r<2; r++)
{ FFT_Iter(b[r],twop,two_root[0],PrD); }
}
else
throw bad_value();
}
}

View File

@@ -2,7 +2,6 @@
#include "FHE_Keys.h"
#include "Ciphertext.h"
#include "P2Data.h"
#include "PPData.h"
#include "FFT_Data.h"
#include "Math/modp.hpp"
@@ -406,29 +405,17 @@ bigint FHE_SK::get_noise(const Ciphertext& c)
}
#define X(FD) \
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FD>& mess, \
const Random_Coins& rc) const; \
template Ciphertext FHE_PK::encrypt(const Plaintext_<FD>& mess) const; \
template Plaintext_<FD> FHE_SK::decrypt(const Ciphertext& c, \
const FD& FieldD); \
template void FHE_SK::decrypt(Plaintext_<FD>& res, \
const Ciphertext& c) const; \
template void FHE_SK::decrypt_any(Plaintext_<FD>& res, \
const Ciphertext& c); \
template void FHE_SK::check(const FHE_PK& pk, const FD&);
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess,
const Random_Coins& rc) const;
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<P2Data>& mess,
const Random_Coins& rc) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess,
const Random_Coins& rc) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
template void FHE_SK::decrypt(Plaintext_<FFT_Data>&, const Ciphertext& c) const;
template void FHE_SK::decrypt(Plaintext_<P2Data>&, const Ciphertext& c) const;
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
const FFT_Data& FieldD);
template Plaintext_<P2Data> FHE_SK::decrypt(const Ciphertext& c,
const P2Data& FieldD);
template void FHE_SK::decrypt_any(Plaintext_<FFT_Data>& res,
const Ciphertext& c);
template void FHE_SK::decrypt_any(Plaintext_<P2Data>& res,
const Ciphertext& c);
template void FHE_SK::check(const FHE_PK& pk, const FFT_Data&);
template void FHE_SK::check(const FHE_PK& pk, const P2Data&);
X(FFT_Data)
X(P2Data)

View File

@@ -119,12 +119,6 @@ const P2Data& FHE_Params::get_plaintext_field_data() const
throw not_implemented();
}
template<>
const PPData& FHE_Params::get_plaintext_field_data() const
{
throw not_implemented();
}
bigint FHE_Params::get_plaintext_modulus() const
{
return fd.get_prime();

View File

@@ -248,49 +248,6 @@ matrix inv(const matrix& A)
}
vector<modp> solve(modp_matrix& A,const Zp_Data& PrD)
{
unsigned int n=A.size();
if ((n+1)!=A[0].size()) { throw invalid_params(); }
modp t,ti;
for (unsigned int r=0; r<n; r++)
{ // Find pivot
unsigned int p=r;
while (isZero(A[p][r],PrD)) { p++; }
// Do pivoting
if (p!=r)
{ for (unsigned int c=0; c<n+1; c++)
{ t=A[p][c]; A[p][c]=A[r][c]; A[r][c]=t; }
}
// Make Lcoeff=1
Inv(ti,A[r][r],PrD);
for (unsigned int c=0; c<n+1; c++)
{ Mul(A[r][c],A[r][c],ti,PrD); }
// Now kill off other entries in this column
for (unsigned int rr=0; rr<n; rr++)
{ if (r!=rr)
{ for (unsigned int c=0; c<n+1; c++)
{ Mul(t,A[rr][c],A[r][r],PrD);
Sub(A[rr][c],A[rr][c],t,PrD);
}
}
}
}
// Sanity check and extract answer
vector<modp> ans;
ans.resize(n);
for (unsigned int i=0; i<n; i++)
{ for (unsigned int j=0; j<n; j++)
{ if (i!=j && !isZero(A[i][j],PrD)) { throw bad_value(); }
else if (!isOne(A[i][j],PrD)) { throw bad_value(); }
}
ans[i]=A[i][n];
}
return ans;
}
/* Input matrix is assumed to have more rows than columns */
void pinv(imatrix& Ai,const imatrix& B)

View File

@@ -10,7 +10,6 @@ using namespace std;
#include "Tools/BitVector.h"
typedef vector< vector<bigint> > matrix;
typedef vector< vector<modp> > modp_matrix;
class imatrix : public vector< BitVector >
{
@@ -39,13 +38,6 @@ void print(const imatrix& S);
// requires column operations to create the inverse
matrix inv(const matrix& A);
// Another special routine for modp matrices.
// Solves
// Ax=b
// Assumes A is unimodular, square and only requires row operations to
// create the inverse. In put is C=(A||b) and the routines alters A
vector<modp> solve(modp_matrix& C,const Zp_Data& PrD);
// Finds a pseudo-inverse of a matrix A modulo 2
// - Input matrix is assumed to have more rows than columns
void pinv(imatrix& Ai,const imatrix& A);

View File

@@ -742,135 +742,3 @@ void load_or_generate(P2Data& P2D, const Ring& R)
P2D.store(R);
}
}
#ifdef USE_NTL
/*
* Create FHE parameters for a general plaintext modulus p
* Basically this is for general large primes only
*/
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params)
{
cout << "Setting up parameters" << endl;
int lgp=numBits(p);
int mm,idx;
// mm is the minimum value of m we will accept
if (lgp<48)
{ mm=100; // Test case
idx=0;
}
else if (lgp <96)
{ mm=8192;
idx=1;
}
else if (lgp<192)
{ mm=16384;
idx=2;
}
else if (lgp<384)
{ mm=16384;
idx=3;
}
else if (lgp<768)
{ mm=32768;
idx=4;
}
else
{ throw invalid_params(); }
// Now find the small factors of p-1 and their exponents
bigint t=p-1;
vector<long> primes(100),exp(100);
PrimeSeq s;
long pr;
pr=s.next();
int len=0;
while (pr<2*mm)
{ int e=0;
while ((t%pr)==0)
{ e++;
t=t/pr;
}
if (e!=0)
{ primes[len]=pr;
exp[len]=e;
if (len!=0) { cout << " * "; }
cout << pr << "^" << e << flush;
len++;
}
pr=s.next();
}
cout << endl;
// We want to find the best m which divides pr-1, such that
// - 2*m > phi(m) > mm
// - m has the smallest number of factors
vector<int> ee;
ee.resize(len);
for (int i=0; i<len; i++) { ee[i]=0; }
int min_hwt=-1,m=-1,bphi_m=-1,bmx=-1;
bool flag=true;
while (flag)
{ int cand_m=1,hwt=0,mx=0;
for (int i=0; i<len; i++)
{ //cout << ee[i] << " ";
if (ee[i]!=0)
{ hwt++;
for (int j=0; j<ee[i]; j++)
{ cand_m*=primes[i]; }
if (ee[i]>mx) { mx=ee[i]; }
}
}
// Put "if" here to stop searching for things which will never work
if (cand_m>1 && cand_m<4*mm)
{ //cout << " : " << cand_m << " : " << hwt << flush;
int phim=phi_N(cand_m);
//cout << " : " << phim << " : " << mm << endl;
if (phim>mm && phim<3*mm)
{ if (m==-1 || hwt<min_hwt || (hwt==min_hwt && mx<bmx) || (hwt==min_hwt && mx==bmx && phim<bphi_m))
{ m=cand_m;
min_hwt=hwt;
bphi_m=phim;
bmx=mx;
//cout << "\t OK" << endl;
}
}
}
else
{ //cout << endl;
}
int i=0;
ee[i]=ee[i]+1;
while (ee[i]>exp[i] && flag)
{ ee[i]=0;
i++;
if (i==len) { flag=false; i=0; }
else { ee[i]=ee[i]+1; }
}
}
if (m==-1)
{ throw bad_value(); }
cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl;
Parameters parameters(n, lgp, sec);
parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p,params);
int mx=0;
for (int i=0; i<R.phi_m(); i++)
{ if (mx<R.Phi()[i]) { mx=R.Phi()[i]; } }
cout << "Max Coeff = " << mx << endl;
init(R, m, true);
Zp_Data Zp(p);
PPD.init(R,Zp);
gfp::init_field(p);
pr0 = parameters.pr0;
pr1 = parameters.pr1;
}
#endif

View File

@@ -4,7 +4,6 @@
#include "FHE/Ring.h"
#include "FHE/FFT_Data.h"
#include "FHE/P2Data.h"
#include "FHE/PPData.h"
#include "FHE/FHE_Params.h"
/* Routines to set up key sizes given the number of players n
@@ -68,10 +67,6 @@ class GF2X;
NTL::GF2X get_F(const Ring& Rg);
// For use when we want p to be a specific value
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params);
// generate moduli according to lengths and other parameters
void generate_moduli(bigint& pr0, bigint& pr1, const int m,
const bigint p, const int lg2p0, const int lg2p1);

View File

@@ -1,99 +0,0 @@
#include "FHE/Subroutines.h"
#include "FHE/PPData.h"
#include "FHE/FFT.h"
#include "FHE/Matrix.h"
#include "Math/modp.hpp"
void PPData::init(const Ring& Rg,const Zp_Data& PrD)
{
R=Rg;
prData=PrD;
root=Find_Primitive_Root_m(Rg.m(),Rg.Phi(),PrD);
}
void PPData::to_eval(vector<modp>& elem) const
{
if (elem.size()!= (unsigned) R.phi_m())
{ throw params_mismatch(); }
throw not_implemented();
/*
vector<modp> ans;
ans.resize(R.phi_m());
modp x=root;
for (int i=0; i<R.phi_m(); i++)
{ ans[i]=elem[R.phi_m()-1];
for (int j=1; j<R.phi_m(); j++)
{ Mul(ans[i],ans[i],x,prData);
Add(ans[i],ans[i],elem[R.phi_m()-j-1],prData);
}
Mul(x,x,root,prData);
}
elem=ans;
*/
}
void PPData::from_eval(vector<modp>&) const
{
// avoid warning
throw not_implemented();
/*
modp_matrix A;
int n=phi_m();
A.resize(n, vector<modp>(n+1) );
modp x=root;
for (int i=0; i<n; i++)
{ assignOne(A[0][i],prData);
for (int j=1; j<n; j++)
{ Mul(A[j][i],A[j-1][i],x,prData); }
Mul(x,x,root,prData);
A[i][n]=elem[i];
}
elem=solve(A,prData);
*/
}
void PPData::reset_iteration()
{
pow = 1;
theta = {root, prData};
thetaPow = theta;
}
void PPData::next_iteration()
{
do
{ thetaPow *= (theta);
pow++;
}
while (gcd(pow,m())!=1);
}
gfp PPData::get_evaluation(const vector<bigint>& mess) const
{
// Uses Horner's rule
gfp ans;
ans = mess[mess.size()-1];
gfp coeff;
for (int j=mess.size()-2; j>=0; j--)
{ ans *= (thetaPow);
coeff = mess[j];
ans += (coeff);
}
return ans;
}

View File

@@ -1,61 +0,0 @@
#ifndef _PPData
#define _PPData
#include "Math/modp.h"
#include "Math/Zp_Data.h"
#include "Math/gfpvar.h"
#include "Math/fixint.h"
#include "FHE/Ring.h"
#include "FHE/FFT_Data.h"
/* Class for holding modular arithmetic data wrt the ring
*
* It also holds the ring
*/
class PPData
{
public:
typedef gfp T;
typedef bigint S;
typedef typename FFT_Data::poly_type poly_type;
Ring R;
Zp_Data prData;
modp root; // m'th Root of Unity mod pr
void init(const Ring& Rg,const Zp_Data& PrD);
PPData() { ; }
PPData(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); }
const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; }
int phi_m() const { return R.phi_m(); }
int m() const { return R.m(); }
int num_slots() const { return R.phi_m(); }
int p(int i) const { return R.p(i); }
int p_inv(int i) const { return R.p_inv(i); }
const vector<int>& Phi() const { return R.Phi(); }
// Convert input vector from poly to evaluation representation
// - Uses naive method and not FFT, we only use this rarely in any case
void to_eval(vector<modp>& elem) const;
void from_eval(vector<modp>& elem) const;
// Following are used to iteratively get slots, as we use PPData when
// we do not have an efficient FFT algorithm
gfp thetaPow,theta;
int pow;
void reset_iteration();
void next_iteration();
gfp get_evaluation(const vector<bigint>& mess) const;
};
#endif

View File

@@ -1,7 +1,6 @@
#include "FHE/Plaintext.h"
#include "FHE/Ring_Element.h"
#include "FHE/PPData.h"
#include "FHE/P2Data.h"
#include "FHE/Rq_Element.h"
#include "FHE_Keys.h"
@@ -85,39 +84,6 @@ void Plaintext<gfp,FFT_Data,bigint>::to_poly() const
}
template<>
void Plaintext<gfp,PPData,bigint>::from_poly() const
{
if (type!=Polynomial) { return; }
vector<modp> aa((*Field_Data).phi_m());
for (unsigned int i=0; i<aa.size(); i++)
{ to_modp(aa[i], bigint::tmp = b[i], (*Field_Data).prData); }
(*Field_Data).to_eval(aa);
a.resize(num_slots());
for (unsigned int i=0; i<aa.size(); i++)
a[i] = {aa[i], Field_Data->get_prD()};
type=Both;
}
template<>
void Plaintext<gfp,PPData,bigint>::to_poly() const
{
if (type!=Evaluation) { return; }
cout << "This is VERY inefficient to convert a plaintext to poly representation" << endl;
vector<modp> bb((*Field_Data).phi_m());
for (unsigned int i=0; i<bb.size(); i++)
{ bb[i]=a[i].get(); }
(*Field_Data).from_eval(bb);
for (unsigned int i=0; i<bb.size(); i++)
{
to_bigint(bigint::tmp,bb[i],(*Field_Data).prData);
b[i] = bigint::tmp;
}
type=Both;
}
template<>
void Plaintext<gf2n_short,P2Data,int>::from_poly() const
@@ -385,34 +351,6 @@ void add(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
}
template<>
void add(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
const Plaintext<gfp,PPData,bigint>& y)
{
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
if (x.type==Both && y.type!=Both) { z.type=y.type; }
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
else if (x.type!=y.type) { throw rep_mismatch(); }
else { z.type=x.type; }
if (z.type!=Polynomial)
{
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] + y.a[i]); }
}
if (z.type!=Evaluation)
{ for (unsigned int i=0; i<z.b.size(); i++)
{ z.b[i]=x.b[i]+y.b[i];
if (z.b[i]>(*z.Field_Data).get_prime())
{ z.b[i]-=(*z.Field_Data).get_prime(); }
}
}
}
template<>
@@ -475,36 +413,6 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
template<>
void sub(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
const Plaintext<gfp,PPData,bigint>& y)
{
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
if (x.type==Both && y.type!=Both) { z.type=y.type; }
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
else if (x.type!=y.type) { throw rep_mismatch(); }
else { z.type=x.type; }
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] - y.a[i]); }
}
if (z.type!=Evaluation)
{ for (unsigned int i=0; i<z.b.size(); i++)
{ z.b[i]=x.b[i]-y.b[i];
if (z.b[i]<0)
{ z.b[i]+=(*z.Field_Data).get_prime(); }
}
}
}
template<>
@@ -572,23 +480,6 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
}
}
template<>
void Plaintext<gfp,PPData,bigint>::negate()
{
if (type!=Polynomial)
{
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].negate(); }
}
if (type!=Evaluation)
{ for (unsigned int i=0; i<b.size(); i++)
{ if (b[i]!=0)
{ b[i]=(*Field_Data).get_prime()-b[i]; }
}
}
}
template<>
@@ -731,12 +622,6 @@ template void mul(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data
template class Plaintext<gfp,PPData,bigint>;
template void mul(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,const Plaintext<gfp,PPData,bigint>& y);
template class Plaintext<gf2n_short,P2Data,int>;
template void mul(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,int>& x,const Plaintext<gf2n_short,P2Data,int>& y);

View File

@@ -274,7 +274,7 @@ void EncCommit<T,FD,S>::Create_More() const
(*P).Broadcast_Receive(ctx_Delta);
// Output the ctx_Delta to a file
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
ofstream outf(filename);
for (int j=0; j<(*P).num_players(); j++)
{
@@ -308,7 +308,7 @@ void EncCommit<T,FD,S>::Create_More() const
octetStream occ,ctx_D;
for (int i=0; i<2*TT; i++)
{ if (open[i]==1)
{ sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
{ snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
ifstream inpf(filename);
for (int j=0; j<(*P).num_players(); j++)
{
@@ -386,7 +386,7 @@ void EncCommit<T,FD,S>::Create_More() const
Ciphertext enc1(params),enc2(params),eDelta(params);
octetStream oe1,oe2;
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
ifstream inpf(filename);
for (int k=0; k<(*P).num_players(); k++)
{

View File

@@ -26,6 +26,7 @@ public:
void set_protocol(DealerSecret::Protocol& protocol)
{
P = &protocol.P;
BufferPrep<DealerSecret>::P = P;
}
void buffer_triples()

View File

@@ -183,6 +183,8 @@ public:
NoShare operator-(const NoShare&) const { fail(); return {}; }
NoShare operator*(const NoValue&) const { fail(); return {}; }
NoShare operator^(const NoShare&) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; }
NoShare operator>>(int) const { fail(); return {}; }

View File

@@ -123,7 +123,9 @@ BreakType Program::execute(Processor<T>& Proc, U& dynamic_memory,
}
time++;
#ifdef DEBUG_COMPLEXITY
cout << "complexity at " << time << ": " << Proc.complexity << endl;
cout << T::part_type::name() << " complexity at " << time << ": " <<
Proc.complexity << " after " << hex <<
instruction.get_opcode() << dec << endl;
#endif
}
while (Proc.complexity < (size_t) OnlineOptions::singleton.batch_size);

View File

@@ -39,6 +39,7 @@ void RepPrep<T>::set_protocol(typename T::Protocol& protocol)
return;
this->protocol = new ReplicatedBase(protocol.P);
this->P = &protocol.P;
}
template<class T>

View File

@@ -89,7 +89,7 @@ void Secret<T>::random(int n_bits, int128 share)
{
(void)share;
if (n_bits > 128)
throw not_implemented();
throw runtime_error("too many bits");
resize_regs(n_bits);
for (int i = 0; i < n_bits; i++)
get_reg(i).random();

View File

@@ -37,6 +37,7 @@ void SemiPrep::set_protocol(SemiSecret::Protocol& protocol)
protocol.P.N, -1, OnlineOptions::singleton.batch_size,
1, params, {}, &protocol.P);
triple_generator->multi_threaded = false;
this->P = &protocol.P;
}
void SemiPrep::buffer_triples()

View File

@@ -103,9 +103,7 @@ void ThreadMaster<T>::run()
machine.print_timers();
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl;
machine.print_global_comm(*P, stats);
machine.print_comm(*P, stats);
delete P;
}

View File

@@ -105,6 +105,11 @@ public:
*this = a + b;
}
This operator^(const This& other) const
{
return *this + other;
}
This& operator^=(const This& other)
{
*this += other;

View File

@@ -146,6 +146,7 @@
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
X(CRASH, if (I0.get()) throw crash_requested()) \
X(ACTIVE, ) \
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS

View File

@@ -365,11 +365,11 @@ void OTMachine::run()
{
BitVector receiver_output, sender_output;
char filename[1024];
sprintf(filename, RECEIVER_INPUT, my_num);
snprintf(filename, 1024, RECEIVER_INPUT, my_num);
ofstream outf(filename);
receiverInput.output(outf, false);
outf.close();
sprintf(filename, RECEIVER_OUTPUT, my_num);
snprintf(filename, 1024, RECEIVER_OUTPUT, my_num);
outf.open(filename);
for (unsigned int i = 0; i < nOTs; i++)
{
@@ -380,7 +380,7 @@ void OTMachine::run()
for (int i = 0; i < 2; i++)
{
sprintf(filename, SENDER_OUTPUT, my_num, i);
snprintf(filename,1024, SENDER_OUTPUT, my_num, i);
outf.open(filename);
for (int j = 0; j < nOTs; j++)
{

View File

@@ -116,7 +116,7 @@ mascot: mascot-party.x spdz2k mama-party.x
ifeq ($(OS), Darwin)
setup: mac-setup
else
setup: boost mpir linux-machine-setup
setup: boost linux-machine-setup
endif
tldr: setup
@@ -297,27 +297,6 @@ deps/SimplestOT_C/ref10/Makefile:
Programs/Circuits:
git submodule update --init Programs/Circuits
.PHONY: mpir-setup mpir-global
mpir-setup: deps/mpir/Makefile
deps/mpir/Makefile:
git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir
cd deps/mpir; \
autoreconf -i; \
autoreconf -i
- $(MAKE) -C deps/mpir clean
mpir-global: mpir-setup
cd deps/mpir; \
./configure --enable-cxx;
$(MAKE) -C deps/mpir
sudo $(MAKE) -C deps/mpir install
mpir: local/lib/libmpirxx.so
local/lib/libmpirxx.so: deps/mpir/Makefile
cd deps/mpir; \
./configure --enable-cxx --prefix=$(CURDIR)/local
$(MAKE) -C deps/mpir install
deps/libOTe/libOTe:
git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe
boost: deps/libOTe/libOTe
@@ -369,26 +348,16 @@ cmake:
./bootstrap --parallel=8 --prefix=../local && make && make install
mac-setup: mac-machine-setup
brew install openssl boost libsodium mpir yasm ntl cmake
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include >> CONFIG.mine
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib >> CONFIG.mine
# -echo USE_NTL = 1 >> CONFIG.mine
brew install openssl boost libsodium gmp yasm ntl cmake
ifeq ($(ARM), 1)
mac-machine-setup:
-echo ARCH = >> CONFIG.mine
linux-machine-setup:
-echo ARCH = -march=armv8.2-a+crypto >> CONFIG.mine
else
mac-machine-setup:
linux-machine-setup:
endif
deps/simde/simde:
git submodule update --init deps/simde || git clone https://github.com/simd-everywhere/simde deps/simde
clean-deps:
-rm -rf local deps/libOTe/out
-rm -rf local/lib/liblibOTe.* deps/libOTe/out
clean: clean-deps
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so

View File

@@ -17,6 +17,10 @@ using namespace std;
#include "ValueInterface.h"
#include "gf2nlong.h"
// Fix false warning
#if __GNUC__ == 10
#pragma GCC diagnostic ignored "-Wstringop-overflow"
#endif
// Functionality shared between integers and bit vectors
template<class T>
@@ -39,6 +43,8 @@ public:
static bool allows(Dtype type) { return type <= DATA_BIT; }
static void check_setup(const string&) {}
IntBase() { a = 0; }
IntBase(T a) : a(a) {}

View File

@@ -160,13 +160,13 @@ void check_setup(string dir, bigint pr)
}
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
const string& type_short)
const string& type_short, bool create)
{
string res = prep_dir + "/" + to_string(nparties) + "-" + type_short;
if (log2mod > 1)
res += "-" + to_string(log2mod);
res += "/";
if (mkdir_p(res.c_str()) < 0)
if (create and mkdir_p(res.c_str()) < 0)
throw file_error("cannot create " + res);
return res;
}

View File

@@ -38,26 +38,28 @@ bigint generate_prime(int lgp, int m);
int default_m(int& lgp, int& idx);
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
const string& type_short);
const string& type_short, bool create = false);
template<class T>
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod)
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
bool create = false)
{
if (T::clear::length() > 1)
log2mod = T::clear::length();
return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short());
return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short(), create);
}
template<class T>
string get_prep_sub_dir(const string& prep_dir, int nparties)
string get_prep_sub_dir(const string& prep_dir, int nparties, bool create =
false)
{
return get_prep_sub_dir<T>(prep_dir, nparties, T::clear::length());
return get_prep_sub_dir<T>(prep_dir, nparties, T::clear::length(), create);
}
template<class T>
string get_prep_sub_dir(int nparties)
string get_prep_sub_dir(int nparties, bool create = false)
{
return get_prep_sub_dir<T>(PREP_DIR, nparties);
return get_prep_sub_dir<T>(PREP_DIR, nparties, create);
}
template<class T>

18
Math/ValueInterface.cpp Normal file
View File

@@ -0,0 +1,18 @@
/*
* ValueInterface.cpp
*
*/
#include "bigint.h"
#include "ValueInterface.h"
#include <sys/stat.h>
void ValueInterface::check_setup(const string& directory)
{
struct stat sb;
if (stat(directory.c_str(), &sb) != 0)
throw runtime_error(directory + " does not exist");
if (not (sb.st_mode & S_IFDIR))
throw runtime_error(directory + " is not a directory");
}

View File

@@ -7,6 +7,7 @@
#define MATH_VALUEINTERFACE_H_
#include "Tools/Exceptions.h"
#include "Math/Setup.h"
class OnlineOptions;
class bigint;
@@ -31,9 +32,10 @@ public:
template<class T>
static void generate_setup(string, int, int) {}
template<class T>
static void write_setup(int) {}
static void write_setup(int nplayers) { get_prep_sub_dir<T>(nplayers, true); }
static void write_setup(string) {}
static void check_setup(string) {}
static void check_setup(const string& directory);
static const char* fake_opts() { return ""; }
static bigint pr() { throw runtime_error("no prime modulus"); }

View File

@@ -6,7 +6,7 @@
#ifndef MATH_Z2K_H_
#define MATH_Z2K_H_
#include <mpirxx.h>
#include <gmpxx.h>
#include <string>
using namespace std;
@@ -74,6 +74,8 @@ public:
static Z2 power_of_two(bool bit, int exp) { return Z2(bit) << exp; }
static string fake_opts() { return " -lgp " + to_string(K); }
typedef Z2 next;
typedef Z2 Scalar;

View File

@@ -53,7 +53,7 @@ void Zp_Data::init(const bigint& p,bool mont)
mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t()));
if (sizeof(unsigned long)!=sizeof(mp_limb_t))
{ cout << "The underlying types of MPIR mean we cannot use our Montgomery code" << endl;
{ cout << "The underlying types of GMP mean we cannot use our Montgomery code" << endl;
throw not_implemented();
}
}
@@ -194,3 +194,37 @@ bool Zp_Data::operator==(const Zp_Data& other) const
{
return not (*this != other);
}
void Zp_Data::get_shanks_parameters(bigint& y, bigint& q_half, int& r) const
{
if (shanks_y == 0)
{
auto& p = pr;
bigint n, q, yy, xx, temp;
// Find n such that (n/p)=-1
int leg = 1;
gmp_randclass Gen(gmp_randinit_default);
Gen.seed(0);
while (leg != -1)
{
n = Gen.get_z_range(p);
leg = mpz_legendre(n.get_mpz_t(), p.get_mpz_t());
}
// Split p-1 = 2^e q
q = p - 1;
int e = 0;
while (mpz_even_p(q.get_mpz_t()))
{
e++;
q = q / 2;
}
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
shanks_r = e;
mpz_powm(shanks_y.get_mpz_t(), n.get_mpz_t(), q.get_mpz_t(), p.get_mpz_t());
shanks_q_half = (q - 1) / 2;
}
y = shanks_y;
q_half = shanks_q_half;
r = shanks_r;
}

View File

@@ -38,6 +38,8 @@ class Zp_Data
int t; // More Montgomery data
mp_limb_t overhang;
Lock lock;
mutable bigint shanks_y, shanks_q_half;
mutable int shanks_r;
template <int T>
void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
@@ -89,6 +91,8 @@ class Zp_Data
bool operator!=(const Zp_Data& other) const;
bool operator==(const Zp_Data& other) const;
void get_shanks_parameters(bigint& y, bigint& q_half, int& r) const;
template<int L> friend void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD);
template<int L> friend void to_modp(modp_<L>& ans,const mpz_class& x,const Zp_Data& ZpD);

View File

@@ -10,76 +10,10 @@
#include "bigint.hpp"
class gmp_random
{
public:
gmp_randclass Gen;
gmp_random() : Gen(gmp_randinit_default)
{
Gen.seed(0);
}
};
thread_local bigint bigint::tmp = 0;
thread_local bigint bigint::tmp2 = 0;
thread_local gmp_random bigint::random;
bigint sqrRootMod(const bigint& a,const bigint& p)
{
bigint ans;
if (a==0) { ans=0; return ans; }
if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1)
throw runtime_error("cannot compute square root of non-square");
if (mpz_tstbit(p.get_mpz_t(),1)==1)
{ // First do case with p=3 mod 4
bigint exp=(p+1)/4;
mpz_powm(ans.get_mpz_t(),a.get_mpz_t(),exp.get_mpz_t(),p.get_mpz_t());
}
else
{ // Shanks algorithm
bigint x,y,n,q,t,b,temp;
// Find n such that (n/p)=-1
int leg=1;
while (leg!=-1)
{ n=bigint::random.Gen.get_z_range(p);
leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t());
}
// Split p-1 = 2^e q
q=p-1;
int e=0;
while (mpz_even_p(q.get_mpz_t()))
{ e++; q=q/2; }
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
int r=e;
mpz_powm(y.get_mpz_t(),n.get_mpz_t(),q.get_mpz_t(),p.get_mpz_t());
temp=(q-1)/2;
mpz_powm(x.get_mpz_t(),a.get_mpz_t(),temp.get_mpz_t(),p.get_mpz_t());
// b=a*x^2 mod p, x=a*x mod p
b=(a*x*x)%p;
x=(a*x)%p;
// While b!=1 do
while (b!=1)
{ // Find smallest m such that b^(2^m)=1 mod p
int m=1;
temp=(b*b)%p;
while (temp!=1)
{ temp=(temp*temp)%p; m++; }
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
t=y;
for (int i=0; i<r-m-1; i++)
{ t=(t*t)%p; }
y=(t*t)%p;
r=m;
// x=x*t mod p, b=b*y mod p
x=(x*t)%p;
b=(b*y)%p;
}
ans=x;
}
return ans;
}
bigint powerMod(const bigint& x,const bigint& e,const bigint& p)
{

View File

@@ -5,7 +5,7 @@
using namespace std;
#include <stddef.h>
#include <mpirxx.h>
#include <gmpxx.h>
#include "Tools/Exceptions.h"
#include "Tools/int.h"
@@ -39,7 +39,7 @@ namespace GC
/**
* Type for arbitrarily large integers.
* This is a sub-class of ``mpz_class`` from MPIR. As such, it implements
* This is a sub-class of ``mpz_class`` from GMP. As such, it implements
* all integers operations and input/output via C++ streams. In addition,
* the ``get_ui()`` member function allows retrieving the least significant
* 64 bits.
@@ -139,8 +139,6 @@ public:
void inline_mpn_zero(mp_limb_t* x, mp_size_t size);
void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size);
#include "Z2k.h"
inline bigint& bigint::operator=(int n)
{
@@ -281,11 +279,7 @@ inline int numBytes(const bigint& m)
inline int probPrime(const bigint& x)
{
gmp_randstate_t rand_state;
gmp_randinit_default(rand_state);
int ans = mpz_probable_prime_p(x.get_mpz_t(), rand_state,
max(40, DEFAULT_SECURITY), 0);
gmp_randclear(rand_state);
int ans = mpz_probab_prime_p(x.get_mpz_t(), max(40, DEFAULT_SECURITY) / 2);
return ans;
}
@@ -318,7 +312,8 @@ inline int isOdd(const bigint& x)
}
bigint sqrRootMod(const bigint& x,const bigint& p);
template<class T>
bigint sqrRootMod(const T& x);
bigint powerMod(const bigint& x,const bigint& e,const bigint& p);

View File

@@ -26,7 +26,7 @@ bigint& bigint::from_signed(const T& other)
template<class T>
mpf_class bigint::get_float(T v, T p, T z, T s)
{
// MPIR can't handle more precision in exponent
// GMP can't handle more precision in exponent
Integer exp = Integer(p, 31).get();
bigint tmp;
tmp.from_signed(v);
@@ -59,4 +59,76 @@ void bigint::output_float(U& o, const mpf_class& x, T nan)
o << "NaN";
}
class gmp_random
{
public:
gmp_randclass Gen;
gmp_random() : Gen(gmp_randinit_default)
{
Gen.seed(0);
}
};
template<class T>
bigint sqrRootMod(const T& aa)
{
bigint a = aa;
bigint p = T::pr();
bigint ans;
if (a == 0)
{
ans = 0;
return ans;
}
if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1)
throw runtime_error("cannot compute square root of non-square");
if (mpz_tstbit(p.get_mpz_t(), 1) == 1)
{
// First do case with p=3 mod 4
bigint exp = (p + 1) / 4;
mpz_powm(ans.get_mpz_t(), a.get_mpz_t(), exp.get_mpz_t(),
p.get_mpz_t());
}
else
{
// Shanks algorithm
bigint n, q, yy, xx, temp;
int r;
T::get_ZpD().get_shanks_parameters(yy, temp, r);
mpz_powm(xx.get_mpz_t(), a.get_mpz_t(), temp.get_mpz_t(), p.get_mpz_t());
// b=a*x^2 mod p, x=a*x mod p
T x = xx;
T b = (aa * x * x);
x = (aa * x);
T y = yy;
// While b!=1 do
while (b != 1)
{
// Find smallest m such that b^(2^m)=1 mod p
int m = 1;
T temp = (b * b);
while (temp != 1)
{
temp = (temp * temp);
m++;
}
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
T t = y;
for (int i = 0; i < r - m - 1; i++)
{
t = (t * t);
}
y = (t * t);
r = m;
// x=x*t mod p, b=b*y mod p
x = (x * t);
b = (b * y);
}
ans = x;
}
return ans;
}
#endif /* MATH_BIGINT_HPP_ */

View File

@@ -17,6 +17,7 @@ class gf2n_short;
class P2Data;
class Bit;
class int128;
template<class T> class IntBase;
template<class T> class Square;
typedef Square<gf2n_short> gf2n_short_square;
@@ -88,6 +89,8 @@ protected:
static string options();
static string fake_opts() { return " -lg2 " + to_string(length()); }
static const true_type invertible;
static const true_type characteristic_two;

View File

@@ -154,6 +154,8 @@ class gf2n_long : public gf2n_<int128>
gf2n_long(int g) : gf2n_long(int128(unsigned(g))) {}
template<class T>
gf2n_long(IntBase<T> g) : super(g.get()) {}
template<class T>
gf2n_long(const gf2n_<T>& a) : super(int128(a.get())) {}
};
#if defined(__aarch64__) && defined(__clang__)

View File

@@ -105,6 +105,7 @@ class gfp_ : public ValueInterface
static void write_setup(string dir)
{ write_online_setup(dir, pr()); }
static void check_setup(string dir);
static string fake_opts() { return " -lgp " + to_string(length()); }
/**
* Get the prime modulus
@@ -314,6 +315,8 @@ gfp_<X, L>::gfp_(long x)
{
if (x == 0)
assign_zero();
else if (x == 1)
assign_one();
else
*this = bigint::tmp = x;
}

View File

@@ -146,8 +146,7 @@ gfp_<X, L> gfp_<X, L>::sqrRoot()
{
// Temp move to bigint so as to call sqrRootMod
bigint ti;
to_bigint(ti, *this);
ti = sqrRootMod(ti, ZpD.pr);
ti = sqrRootMod(*this);
if (!isOdd(ti))
ti = ZpD.pr - ti;
gfp_<X, L> temp;

View File

@@ -312,8 +312,8 @@ gfpvar_<X, L> gfpvar_<X, L>::invert() const
template<int X, int L>
gfpvar_<X, L> gfpvar_<X, L>::sqrRoot() const
{
bigint ti = *this;
ti = sqrRootMod(ti, ZpD.pr);
bigint ti;
ti = sqrRootMod(*this);
if (!isOdd(ti))
ti = ZpD.pr - ti;
return ti;

View File

@@ -81,6 +81,7 @@ public:
{
write_setup(get_prep_sub_dir<T>(nplayers));
}
static string fake_opts() { return " -lgp " + to_string(length()); }
gfpvar_();
gfpvar_(int other);

View File

@@ -2,7 +2,7 @@
#define _Modp
/*
* Currently we only support an MPIR based implementation.
* Currently we only support an GMP based implementation.
*
* What ever is type-def'd to bigint is assumed to have
* operator overloading for all standard operators, has

View File

@@ -6,7 +6,7 @@
#ifndef MATH_MPN_FIXED_H_
#define MATH_MPN_FIXED_H_
#include <mpir.h>
#include <gmp.h>
#include <string.h>
#include <assert.h>

View File

@@ -3,7 +3,7 @@
*
*/
#include <mpirxx.h>
#include <gmpxx.h>
#include "OT/BitMatrix.h"
#include "Tools/random.h"

View File

@@ -78,7 +78,7 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
}
}
if (nplayers_wanted > 0 and nplayers_wanted != nplayers)
throw runtime_error("not enought hosts in HOSTS");
throw runtime_error("not enough hosts in " + filename);
#ifdef DEBUG_NETWORKING
cerr << "Got list of " << nplayers << " players from file: " << endl;
for (unsigned int i = 0; i < names.size(); i++)
@@ -324,7 +324,9 @@ void PlainPlayer::setup_sockets(const vector<string>& names,
template<class T>
void MultiPlayer<T>::send_long(int i, long a) const
{
TimeScope ts(comm_stats["Sending by number"].add(sizeof(long)));
send(sockets[i], (octet*)&a, sizeof(long));
sent += sizeof(long);
}
template<class T>
@@ -716,7 +718,7 @@ size_t VirtualTwoPartyPlayer::send(const PlayerBuffer& buffer, bool block) const
{
auto sent = P.send_no_stats(other_player, buffer, block);
lock.lock();
comm_stats["Sending one-to-one"].add(sent);
comm_stats.add_to_last_round("Sending one-to-one", sent);
comm_stats.sent += sent;
lock.unlock();
return sent;
@@ -726,7 +728,7 @@ size_t VirtualTwoPartyPlayer::recv(const PlayerBuffer& buffer, bool block) const
{
auto received = P.recv_no_stats(other_player, buffer, block);
lock.lock();
comm_stats["Receiving one-to-one"].add(received);
comm_stats.add_to_last_round("Receiving one-to-one", received);
lock.unlock();
return received;
}
@@ -805,6 +807,17 @@ void NamedCommStats::reset()
sent = 0;
}
Timer& NamedCommStats::add_to_last_round(const string& name, size_t length)
{
if (name == last)
return (*this)[name].add_length_only(length);
else
{
last = name;
return (*this)[name].add(length);
}
}
void PlayerBase::reset_stats()
{
comm_stats.reset();

View File

@@ -136,11 +136,15 @@ struct CommStats
CommStats() : data(0), rounds(0) {}
Timer& add(size_t length)
{
rounds++;
return add_length_only(length);
}
Timer& add_length_only(size_t length)
{
#ifdef VERBOSE_COMM
cout << "add " << length << endl;
#endif
data += length;
rounds++;
return timer;
}
Timer& add(const octetStream& os) { return add(os.get_length()); }
@@ -153,6 +157,7 @@ class NamedCommStats : public map<string, CommStats>
{
public:
size_t sent;
string last;
NamedCommStats();
@@ -161,6 +166,7 @@ public:
NamedCommStats operator-(const NamedCommStats& other) const;
void print(bool newline = false);
void reset();
Timer& add_to_last_round(const string& name, size_t length);
#ifdef VERBOSE_COMM
CommStats& operator[](const string& name)
{

View File

@@ -134,7 +134,7 @@ void close_client_socket(int socket)
if (close(socket))
{
char tmp[1000];
sprintf(tmp, "close(%d)", socket);
snprintf(tmp, 1000, "close(%d)", socket);
error(tmp);
}
}

View File

@@ -126,7 +126,7 @@ void BaseMachine::time()
void BaseMachine::start(int n)
{
cout << "Starting timer " << n << " at " << timer[n].elapsed()
<< " (" << timer[n].mb_sent() << " MB)"
<< " (" << timer[n] << ")"
<< " after " << timer[n].idle() << endl;
timer[n].start(total_comm());
}
@@ -135,7 +135,7 @@ void BaseMachine::stop(int n)
{
timer[n].stop(total_comm());
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " ("
<< timer[n].mb_sent() << " MB)" << endl;
<< timer[n] << ")" << endl;
}
void BaseMachine::print_timers()
@@ -150,7 +150,7 @@ void BaseMachine::print_timers()
timer.erase(0);
for (auto it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds ("
<< it->second.mb_sent() << " MB)" << endl;
<< it->second << ")" << endl;
}
string BaseMachine::memory_filename(const string& type_short, int my_number)
@@ -227,3 +227,19 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
global += os.get_int(8);
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
}
void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
{
size_t rounds = 0;
for (auto& x : comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << P.my_num() << " only";
if (nthreads > 1)
cerr << "; rounds counted double due to multi-threading";
if (not OnlineOptions::singleton.verbose)
cerr << "; use '-v' for more details";
cerr << ")" << endl;
print_global_comm(P, comm_stats);
}

View File

@@ -67,6 +67,7 @@ public:
void print_timers();
virtual void reqbl(int) {}
virtual void active(int) {}
static OTTripleSetup fresh_ot_setup(Player& P);
@@ -74,6 +75,7 @@ public:
void set_thread_comm(const NamedCommStats& stats);
void print_global_comm(Player& P, const NamedCommStats& stats);
void print_comm(Player& P, const NamedCommStats& stats);
};
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)

41
Processor/Conv2dTuple.h Normal file
View File

@@ -0,0 +1,41 @@
/*
* Conv2dTuple.h
*
*/
#ifndef PROCESSOR_CONV2DTUPLE_H_
#define PROCESSOR_CONV2DTUPLE_H_
#include <vector>
using namespace std;
class Conv2dTuple
{
public:
int output_h, output_w;
int inputs_h, inputs_w;
int weights_h, weights_w;
int stride_h, stride_w;
int n_channels_in;
int padding_h;
int padding_w;
int batch_size;
size_t r0;
size_t r1;
int r2;
vector<vector<vector<int>>> lengths;
int filter_stride_h = 1;
int filter_stride_w = 1;
Conv2dTuple(const vector<int>& args, int start);
template<class T>
void pre(vector<T>& S, typename T::Protocol& protocol);
template<class T>
void post(vector<T>& S, typename T::Protocol& protocol);
template<class T>
void run_matrix(SubProcessor<T>& processor);
};
#endif /* PROCESSOR_CONV2DTUPLE_H_ */

View File

@@ -222,7 +222,8 @@ bool DataPositions::any_more(const DataPositions& other) const
for (auto it = edabits.begin(); it != edabits.end(); it++)
{
auto x = other.edabits.find(it->first);
if (x == other.edabits.end() or it->second > x->second)
if ((x == other.edabits.end() or it->second > x->second)
and it->second > 0)
return true;
}

View File

@@ -12,6 +12,8 @@
#include "Networking/Player.h"
#include "Protocols/edabit.h"
#include "PrepBase.h"
#include "EdabitBuffer.h"
#include "Tools/TimerWithComm.h"
#include <fstream>
#include <map>
@@ -102,9 +104,6 @@ protected:
DataPositions& usage;
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
map<pair<bool, int>, edabitvec<T>> my_edabits;
bool do_count;
void count(Dtype dtype, int n = 1)
@@ -120,6 +119,8 @@ protected:
const vector<int>&, true_type)
{ throw not_implemented(); }
void fill(edabitvec<T>& res, bool strict, int n_bits);
T get_random_from_inputs(int nplayers);
public:
@@ -173,12 +174,11 @@ public:
virtual void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs)
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
template<int>
void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
template<int>
virtual void get_edabit_no_count(bool, int, edabit<T>&)
{ throw runtime_error("no edaBits"); }
/// Get fresh edaBit chunk
edabitvec<T> get_edabitvec(bool strict, int n_bits);
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
virtual edabitvec<T> get_edabitvec(bool, int)
{ throw runtime_error("no edabitvec"); }
virtual void push_triples(const vector<array<T, 3>>&)
{ throw runtime_error("no pushing"); }
@@ -204,7 +204,8 @@ class Sub_Data_Files : public Preprocessing<T>
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
map<DataTag, BufferOwner<T, T> > extended;
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
map<int, ifstream*> edabit_buffers;
map<int, EdabitBuffer<T>> edabit_buffers;
map<int, edabitvec<T>> my_edabits;
int my_num,num_players;
@@ -213,13 +214,11 @@ class Sub_Data_Files : public Preprocessing<T>
part_type* part;
void buffer_edabits_with_queues(bool strict, int n_bits)
{ buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
template<int>
void buffer_edabits_with_queues(bool strict, int n_bits, false_type);
template<int>
void buffer_edabits_with_queues(bool, int, true_type)
{ throw not_implemented(); }
EdabitBuffer<T>& get_edabit_buffer(int n_bits);
/// Get fresh edaBit chunk
edabitvec<T> get_edabitvec(bool strict, int n_bits);
void get_edabit_no_count(bool strict, int n_bits, edabit<T>& eb);
public:
static string get_filename(const Names& N, Dtype type, int thread_num = -1);
@@ -317,6 +316,8 @@ class Data_Files
void reset_usage() { usage.reset(); skipped.reset(); }
void set_usage(const DataPositions& pos) { usage = pos; }
TimerWithComm total_time();
};
template<class T> inline

View File

@@ -108,7 +108,21 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
#ifdef DEBUG_FILES
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
#endif
T::clear::check_setup(prep_data_dir);
try
{
T::clear::check_setup(prep_data_dir);
}
catch (...)
{
cerr << "Something is wrong with the preprocessing data on disk." << endl;
cerr
<< "Have you run the right program for generating it, such as './Fake-Offline.x "
<< num_players
<< T::clear::fake_opts() << "'?" << endl;
throw;
}
string type_short = T::type_short();
string type_string = T::type_string();
@@ -135,7 +149,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
type_short, i, my_num, thread_num);
if (i == my_num)
my_input_buffers.setup(filename,
T::size() + T::clear::size(), type_string);
InputTuple<T>::size(), type_string);
else
input_buffers[i].setup(filename,
T::size(), type_string);
@@ -179,10 +193,6 @@ Data_Files<sint, sgf2n>::~Data_Files()
template<class T>
Sub_Data_Files<T>::~Sub_Data_Files()
{
for (auto& x: edabit_buffers)
{
delete x.second;
}
if (part != 0)
delete part;
}
@@ -229,6 +239,26 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
extended[it->first].seekg(it->second);
}
dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]);
if (field_type == DATA_INT)
{
for (auto& x : pos.edabits)
{
// open files
get_edabit_buffer(x.first.second);
}
int block_size = edabitvec<T>::MAX_SIZE;
for (auto& x : edabit_buffers)
{
int n = pos.edabits[{true, x.first}] + pos.edabits[{false, x.first}];
x.second.seekg(n / block_size);
edabit<T> eb;
for (int i = 0; i < n % block_size; i++)
get_edabit_no_count(false, x.first, eb);
}
}
}
template<class sint, class sgf2n>
@@ -262,6 +292,8 @@ void Sub_Data_Files<T>::prune()
dabit_buffer.prune();
if (part != 0)
part->prune();
for (auto& x : edabit_buffers)
x.second.prune();
}
template<class sint, class sgf2n>
@@ -285,6 +317,8 @@ void Sub_Data_Files<T>::purge()
dabit_buffer.purge();
if (part != 0)
part->purge();
for (auto& x : edabit_buffers)
x.second.prune();
}
template<class T>
@@ -322,34 +356,43 @@ void Sub_Data_Files<T>::get_dabit_no_count(T& a, typename T::bit_type& b)
}
template<class T>
template<int>
void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
false_type)
EdabitBuffer<T>& Sub_Data_Files<T>::get_edabit_buffer(int n_bits)
{
if (edabit_buffers.empty())
insecure("reading edaBits from files");
if (edabit_buffers.find(n_bits) == edabit_buffers.end())
{
string filename = PrepBase::get_edabit_filename(prep_data_dir,
n_bits, my_num, thread_num);
ifstream* f = new ifstream(filename);
if (f->fail())
throw runtime_error("cannot open " + filename);
check_file_signature<T>(*f, filename);
edabit_buffers[n_bits] = f;
edabit_buffers[n_bits] = n_bits;
edabit_buffers[n_bits].setup(filename,
T::size() * edabitvec<T>::MAX_SIZE
+ n_bits * T::bit_type::part_type::size());
}
auto& buffer = *edabit_buffers[n_bits];
if (buffer.peek() == EOF)
return edabit_buffers[n_bits];
}
template<class T>
edabitvec<T> Sub_Data_Files<T>::get_edabitvec(bool strict, int n_bits)
{
if (my_edabits[n_bits].empty())
return get_edabit_buffer(n_bits).read();
else
{
buffer.seekg(0);
check_file_signature<T>(buffer, "");
auto res = my_edabits[n_bits];
my_edabits[n_bits] = {};
this->fill(res, strict, n_bits);
return res;
}
}
template<class T>
void Preprocessing<T>::fill(edabitvec<T>& res, bool strict, int n_bits)
{
edabit<T> eb;
while (res.size() < res.MAX_SIZE)
{
get_edabit_no_count(strict, n_bits, eb);
res.push_back(eb);
}
edabitvec<T> eb;
eb.input(n_bits, buffer);
this->edabits[{strict, n_bits}].push_back(eb);
if (buffer.fail())
throw runtime_error("error reading edaBits");
}
template<class T>
@@ -362,4 +405,10 @@ typename Sub_Data_Files<T>::part_type& Sub_Data_Files<T>::get_part()
return *part;
}
template<class sint, class sgf2n>
TimerWithComm Data_Files<sint, sgf2n>::total_time()
{
return DataFp.prep_timer + DataF2.prep_timer + DataFb.prep_timer;
}
#endif

50
Processor/EdabitBuffer.h Normal file
View File

@@ -0,0 +1,50 @@
/*
* EdabitBuffer.h
*
*/
#ifndef PROCESSOR_EDABITBUFFER_H_
#define PROCESSOR_EDABITBUFFER_H_
#include "Tools/Buffer.h"
template<class T>
class EdabitBuffer : public BufferOwner<T, T>
{
int n_bits;
int element_length()
{
return -1;
}
public:
EdabitBuffer(int n_bits = 0) :
n_bits(n_bits)
{
}
edabitvec<T> read()
{
if (not BufferBase::file)
{
if (this->open()->fail())
throw runtime_error("error opening " + this->filename);
}
assert(BufferBase::file);
auto& buffer = *BufferBase::file;
if (buffer.peek() == EOF)
{
this->try_rewind();
}
edabitvec<T> eb;
eb.input(n_bits, buffer);
if (buffer.fail())
throw runtime_error("error reading edaBits");
return eb;
}
};
#endif /* PROCESSOR_EDABITBUFFER_H_ */

View File

@@ -70,6 +70,7 @@ enum
PLAYERID = 0xE4,
USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
// Addition
ADDC = 0x20,
ADDS = 0x21,

View File

@@ -311,6 +311,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case PRIVATEOUTPUT:
case TRUNC_PR:
case RUN_TAPE:
case CONV2DS:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;
@@ -322,10 +323,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
get_ints(r, s, 3);
get_vector(9, start, s);
break;
case CONV2DS:
get_ints(r, s, 3);
get_vector(12, start, s);
break;
// read from file, input is opcode num_args,
// start_file_posn (read), end_file_posn(write) var1, var2, ...
@@ -425,6 +422,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
throw Processor_Error(ss.str());
}
break;
case ACTIVE:
n = get_int(s);
BaseMachine::s().active(n);
break;
case XORM:
case ANDM:
case XORCB:
@@ -720,7 +721,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
case MATMULSM:
return r[0] + start[0] * start[2];
case CONV2DS:
return r[0] + start[0] * start[1] * start[11];
{
unsigned res = 0;
for (size_t i = 0; i < start.size(); i += 15)
{
unsigned tmp = start[i]
+ start[i + 3] * start[i + 4] * start.at(i + 14);
res = max(res, tmp);
}
return res;
}
case OPEN:
skip = 2;
break;
@@ -1164,6 +1174,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break;
case REQBL:
case GREQBL:
case ACTIVE:
case USE:
case USE_INP:
case USE_EDABIT:

View File

@@ -109,6 +109,7 @@ class Machine : public BaseMachine
string prep_dir_prefix();
void reqbl(int n);
void active(int n);
typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; }
typename sint::mac_key_type get_sint_mac_key() { return alphapi; }

View File

@@ -415,6 +415,9 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
auto comm_stats = total_comm();
if (OnlineOptions::singleton.verbose)
queues.print_breakdown();
for (auto& queue : queues)
delete queue;
@@ -477,20 +480,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
print_timers();
if (sint::is_real)
{
size_t rounds = 0;
for (auto& x : comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << my_number;
if (threads.size() > 1)
cerr << "; rounds counted double due to multi-threading";
cerr << "; use '-v' for more details";
cerr << ")" << endl;
auto& P = *this->P;
this->print_global_comm(P, comm_stats);
}
this->print_comm(*this->P, comm_stats);
#ifdef VERBOSE_OPTIONS
if (opening_sum < N.num_players() && !direct)
@@ -521,23 +511,6 @@ void Machine<sint, sgf2n>::run(const string& progname)
bit_memories.write_memory(N.my_num());
#ifdef OLD_USAGE
for (int dtype = 0; dtype < N_DTYPE; dtype++)
{
cerr << "Num " << DataPositions::dtype_names[dtype] << "\t=";
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
cerr << " " << pos.files[field_type][dtype];
cerr << endl;
}
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
cerr << "Num " << DataPositions::field_names[field_type] << " Inputs\t=";
for (int i = 0; i < N.num_players(); i++)
cerr << " " << pos.inputs[i][field_type];
cerr << endl;
}
#endif
if (opts.verbose)
{
cerr << "Actual cost of program:" << endl;
@@ -586,6 +559,17 @@ void Machine<sint, sgf2n>::reqbl(int n)
sint::clear::reqbl(n);
}
template<class sint, class sgf2n>
void Machine<sint, sgf2n>::active(int n)
{
if (sint::malicious and n == 0)
{
cerr << "Program requires a semi-honest protocol" << endl;
exit(1);
}
}
template<class sint, class sgf2n>
void Machine<sint, sgf2n>::suggest_optimizations()
{
@@ -599,8 +583,8 @@ void Machine<sint, sgf2n>::suggest_optimizations()
optimizations.append("\tprogram.use_edabit(True)\n");
if (not optimizations.empty())
cerr << "This program might benefit from some protocol options." << endl
<< "Consider adding the following at the beginning of '" << progname
<< ".mpc':" << endl << optimizations;
<< "Consider adding the following at the beginning of your code:"
<< endl << optimizations;
#ifndef __clang__
cerr << "This virtual machine was compiled with GCC. Recompile with "
"'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl;

View File

@@ -172,7 +172,7 @@ void OfflineMachine<W>::generate()
auto& opts = OnlineOptions::singleton;
opts.batch_size = DIV_CEIL(opts.batch_size, batch) * batch;
for (int i = 0; i < buffered_total(total, batch) / batch; i++)
preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits,
preprocessing.get_edabitvec(true, n_bits).output(n_bits,
out);
}
else

View File

@@ -44,6 +44,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
auto& queues = machine.queues[num];
queues->next();
ThreadQueue::thread_queue = queues;
#ifdef DEBUG_THREADS
fprintf(stderr, "\tI am in thread %d\n",num);
@@ -118,6 +119,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
DataPositions actual_usage(P.num_players());
Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer;
thread_timer.start();
TimerWithComm timer, online_timer, online_prep_timer;
timer.start();
while (flag)
{ // Wait until I have a program to run
@@ -262,6 +265,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
#ifdef DEBUG_THREADS
printf("\tClient %d about to run %d\n",num,program);
#endif
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
Proc.reset(progs[program], job.arg);
// Bits, Triples, Squares, and Inverses skipping
@@ -290,6 +295,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
printf("\tSignalling I have finished with program %d"
"in thread %d\n", program, num);
#endif
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
wait_timer.start();
queues->finished(job, P.total_comm());
wait_timer.stop();
@@ -297,7 +304,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
}
// final check
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
Proc.check();
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
if (machine.opts.file_prep_per_thread)
Proc.DataF.prune();
@@ -330,6 +341,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
// wind down thread by thread
machine.stats += Proc.stats;
queues->timers["wait"] = wait_timer + queues->wait_timer;
timer.stop(P.total_comm());
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
// prevent faulty usage message
Proc.DataF.set_usage(actual_usage);
delete processor;

View File

@@ -69,7 +69,7 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
cerr << " edaBits of size " << n_bits << " left" << endl;
}
if (n > used / 10)
if (n * n_batch > used / 10)
cerr << "Significant amount of unused edaBits of size " << n_bits
<< ". For more accurate benchmarks, "
<< "consider reducing the batch size with --batch-size "

View File

@@ -10,6 +10,7 @@
using namespace std;
#include "Math/field_types.h"
#include "Tools/TimerWithComm.h"
class PrepBase
{
@@ -28,6 +29,8 @@ public:
const string& type_string, size_t used);
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
int n_bits, size_t used);
TimerWithComm prep_timer;
};
#endif /* PROCESSOR_PREPBASE_H_ */

View File

@@ -5,6 +5,7 @@
#include "Processor/Program.h"
#include "GC/square64.h"
#include "SpecificPrivateOutput.h"
#include "Conv2dTuple.h"
#include "Processor/ProcessorBase.hpp"
#include "GC/Processor.hpp"
@@ -31,6 +32,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
DataF.set_proc(this);
protocol.init(DataF, MC);
DataF.set_protocol(protocol);
MC.set_prep(DataF);
bit_usage.set_num_players(P.num_players());
personal_bit_preps.resize(P.num_players());
for (int i = 0; i < P.num_players(); i++)
@@ -40,6 +42,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
template<class T>
SubProcessor<T>::~SubProcessor()
{
DataF.set_proc(0);
for (size_t i = 0; i < personal_bit_preps.size(); i++)
{
auto& x = personal_bit_preps[i];
@@ -391,7 +394,7 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
return;
string filename;
filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data";
filename = binary_file_io.filename(P.my_num());
unsigned int size = data_registers.size();
@@ -652,21 +655,35 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
{
protocol.init_dotprod();
auto& args = instruction.get_start();
int output_h = args[0], output_w = args[1];
int inputs_h = args[2], inputs_w = args[3];
int weights_h = args[4], weights_w = args[5];
int stride_h = args[6], stride_w = args[7];
int n_channels_in = args[8];
int padding_h = args[9];
int padding_w = args[10];
int batch_size = args[11];
size_t r0 = instruction.get_r(0);
size_t r1 = instruction.get_r(1);
int r2 = instruction.get_r(2);
int lengths[batch_size][output_h][output_w];
memset(lengths, 0, sizeof(lengths));
int filter_stride_h = 1;
int filter_stride_w = 1;
vector<Conv2dTuple> tuples;
for (size_t i = 0; i < args.size(); i += 15)
tuples.push_back(Conv2dTuple(args, i));
for (auto& tuple : tuples)
tuple.pre(S, protocol);
protocol.exchange();
for (auto& tuple : tuples)
tuple.post(S, protocol);
}
inline
Conv2dTuple::Conv2dTuple(const vector<int>& arguments, int start)
{
assert(arguments.size() >= start + 15ul);
auto args = arguments.data() + start + 3;
output_h = args[0], output_w = args[1];
inputs_h = args[2], inputs_w = args[3];
weights_h = args[4], weights_w = args[5];
stride_h = args[6], stride_w = args[7];
n_channels_in = args[8];
padding_h = args[9];
padding_w = args[10];
batch_size = args[11];
r0 = arguments[start];
r1 = arguments[start + 1];
r2 = arguments[start + 2];
lengths.resize(batch_size, vector<vector<int>>(output_h, vector<int>(output_w)));
filter_stride_h = 1;
filter_stride_w = 1;
if (stride_h < 0)
{
filter_stride_h = -stride_h;
@@ -677,7 +694,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
filter_stride_w = -stride_w;
stride_w = 1;
}
}
template<class T>
void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{
size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in;
@@ -714,9 +735,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
protocol.next_dotprod();
}
}
}
protocol.exchange();
template<class T>
void Conv2dTuple::post(vector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{
size_t base = r0 + i_batch * output_h * output_w;

View File

@@ -6,6 +6,8 @@
#include "ThreadQueue.h"
thread_local ThreadQueue* ThreadQueue::thread_queue = 0;
void ThreadQueue::schedule(const ThreadJob& job)
{
lock.lock();
@@ -14,7 +16,11 @@ void ThreadQueue::schedule(const ThreadJob& job)
cerr << this << ": " << left << " left" << endl;
#endif
lock.unlock();
if (thread_queue)
thread_queue->wait_timer.start();
in.push(job);
if (thread_queue)
thread_queue->wait_timer.stop();
}
ThreadJob ThreadQueue::next()
@@ -42,7 +48,11 @@ void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats)
ThreadJob ThreadQueue::result()
{
if (thread_queue)
thread_queue->wait_timer.start();
auto res = out.pop();
if (thread_queue)
thread_queue->wait_timer.stop();
lock.lock();
left--;
#ifdef DEBUG_THREAD_QUEUE

View File

@@ -16,6 +16,11 @@ class ThreadQueue
NamedCommStats comm_stats;
public:
static thread_local ThreadQueue* thread_queue;
map<string, TimerWithComm> timers;
Timer wait_timer;
ThreadQueue() :
left(0)
{

View File

@@ -85,3 +85,32 @@ void ThreadQueues::wrap_up(ThreadJob job)
}
available.clear();
}
TimerWithComm ThreadQueues::sum(const string& phase)
{
TimerWithComm res;
for (auto& x : *this)
res += x->timers[phase];
return res;
}
void ThreadQueues::print_breakdown()
{
if (size() > 0)
{
if (size() == 1)
{
cerr << "Spent " << (*this)[0]->timers["online"].full()
<< " on the online phase and "
<< (*this)[0]->timers["prep"].full()
<< " on the preprocessing/offline phase." << endl;
}
else
{
cerr << size() << " threads spent a total of " << sum("online").full()
<< " on the online phase, " << sum("prep").full()
<< " on the preprocessing/offline phase, and "
<< sum("wait").full() << " idling." << endl;
}
}
}

View File

@@ -24,6 +24,10 @@ public:
int distribute_no_setup(ThreadJob job, int n_items, int base = 0,
int granularity = 1, const vector<void*>* supplies = 0);
void wrap_up(ThreadJob job);
TimerWithComm sum(const string& phase);
void print_breakdown();
};
#endif /* PROCESSOR_THREADQUEUES_H_ */

View File

@@ -387,6 +387,7 @@
X(GENSECSHUFFLE, throw not_implemented(),) \
X(APPLYSHUFFLE, throw not_implemented(),) \
X(DELSHUFFLE, throw not_implemented(),) \
X(ACTIVE, throw not_implemented(),) \
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS

114
Programs/Source/alex.mpc Normal file
View File

@@ -0,0 +1,114 @@
from Compiler.ml import keras
import Compiler.ml as tf
try:
n_epochs = int(program.args[1])
except (ValueError, IndexError):
n_epochs = 20
try:
batch_size = int(program.args[2])
except (ValueError, IndexError):
batch_size = 128
try:
n_threads = int(program.args[3])
except (ValueError, IndexError):
n_threads = 36
#Instantiation
AlexNet = []
padding = 1
batchnorm = 'batchnorm' in program.args
bn1 = 'bn1' in program.args
bn2 = 'bn2' in program.args
MultiArray.disable_index_checks()
#1st Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=64, input_shape=(32,32,3), kernel_size=3, strides=1, padding=2))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding=0))
#2nd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=3, strides=1, padding=2))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm or bn2:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='same'))
#3rd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=(3,3), strides=(1,1), padding=padding))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
#4th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm or bn1:
AlexNet.append(keras.layers.BatchNormalization())
#5th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm or bn2:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(3,3), strides=(2,2), padding=0))
#Passing it to a Fully Connected layer
# 1st Fully Connected Layer
AlexNet.append(keras.layers.Dense(128))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#2nd Fully Connected Layer
AlexNet.append(keras.layers.Dense(256))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#Output Layer
AlexNet.append(keras.layers.Dense(10))
tf.set_n_threads(n_threads)
program.options_from_args()
sfix.set_precision_from_args(program, adapt_ring=True)
training_samples = MultiArray([50000, 32, 32, 3], sfix)
training_labels = MultiArray([50000, 10], sint)
test_samples = MultiArray([10000, 32, 32, 3], sfix)
test_labels = MultiArray([10000, 10], sint)
training_labels.input_from(0)
training_samples.input_from(0, binary='binary_samples' in program.args)
test_labels.input_from(0)
test_samples.input_from(0, binary='binary_samples' in program.args)
model = tf.keras.models.Sequential(AlexNet)
model.compile_by_args(program)
model.build(training_samples.sizes)
model.summary()
model.opt.output_diff = 'output_diff' in program.args
model.opt.output_grad = 'output_grad' in program.args
model.opt.output_stats = 100 if 'output_stats' in program.args else 0
model.opt.shuffle = not 'noshuffle' in program.args
opt = model.fit(
training_samples,
training_labels,
epochs=n_epochs,
batch_size=batch_size,
validation_data=(test_samples, test_labels)
)

View File

@@ -25,6 +25,9 @@ n_threads = 2
if len(program.args) > 1:
n_rounds = int(program.args[1])
if len(program.args) > 2:
program.active = bool(int(program.args[2]))
def accept_client():
client_socket_id = accept_client_connection(PORTNUM)
last = regint.read_from_socket(client_socket_id)

View File

@@ -4,7 +4,7 @@ import Compiler.ml as tf
try:
n_epochs = int(program.args[1])
except (ValueError, IndexError):
n_epochs = 10
n_epochs = 20
try:
batch_size = int(program.args[2])

View File

@@ -0,0 +1,72 @@
# this trains LeNet on MNIST with a dropout layer
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
program.options_from_args()
if 'torch' in program.args:
import torchvision
data = []
for train in True, False:
ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
# normalize to [0,1] before input
samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
data += [(labels, samples)]
(training_labels, training_samples), (test_labels, test_samples) = data
else:
training_samples = sfix.Tensor([60000, 28, 28])
training_labels = sint.Tensor([60000, 10])
test_samples = sfix.Tensor([10000, 28, 28])
test_labels = sint.Tensor([10000, 10])
training_labels.input_from(0)
training_samples.input_from(0)
test_labels.input_from(0)
test_samples.input_from(0)
from Compiler import ml
tf = ml
layers = [
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.AveragePooling2D(2),
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.AveragePooling2D(2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(500, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
]
model = tf.keras.models.Sequential(layers)
optim = tf.keras.optimizers.Adam(amsgrad=True)
model.compile(optimizer=optim)
opt = model.fit(
training_samples,
training_labels,
epochs=10,
batch_size=128,
validation_data=(test_samples, test_labels)
)
for var in model.trainable_variables:
var.write_to_file()

View File

@@ -0,0 +1,49 @@
# this trains a dense neural network on MNIST
program.options_from_args()
import torchvision
data = []
for train in True, False:
ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
# normalize to [0,1] before input
samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
data += [(labels, samples)]
import torch
import torch.nn as nn
net = nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.ReLU(),
nn.AvgPool2d(2),
nn.Conv2d(20, 50, 5),
nn.ReLU(),
nn.AvgPool2d(2),
nn.Flatten(),
nn.ReLU(),
nn.Linear(800, 500),
nn.ReLU(),
nn.Linear(500, 10)
)
# test network
ds = torchvision.datasets.MNIST(
root='/tmp', transform=torchvision.transforms.ToTensor())
inputs = next(iter(torch.utils.data.DataLoader(ds)))[0]
print(inputs.shape)
outputs = net(inputs)
from Compiler import ml
ml.set_n_threads(int(program.args[2]))
layers = ml.layers_from_torch(net, data[0][1].shape, 128)
layers[0].X = data[0][1]
layers[-1].Y = data[0][0]
optimizer = ml.SGD(layers)
optimizer.run_by_args(program, int(program.args[1]), 128,
data[1][1], data[1][0])

View File

@@ -11,6 +11,8 @@ DealerMatrixPrep<T>::DealerMatrixPrep(int n_rows, int n_inner, int n_cols,
super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols),
prep(&prep)
{
assert(prep.proc);
this->P = &prep.proc->P;
}
template<class T>

View File

@@ -20,9 +20,6 @@ class Hemi : public T::BasicProtocol
MatrixMC<T> mc;
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
SubProcessor<T>& processor);
public:
Hemi(Player& P) :
T::BasicProtocol(P)
@@ -33,6 +30,9 @@ public:
typename T::MatrixPrep& get_matrix_prep(const array<int, 3>& dimensions,
SubProcessor<T>& processor);
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
SubProcessor<T>& processor);
void matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
const Instruction& instruction, int a, int b);
void conv2ds(SubProcessor<T>& processor, const Instruction& instruction);

View File

@@ -130,37 +130,23 @@ void Hemi<T>::conv2ds(SubProcessor<T>& processor,
}
auto& args = instruction.get_start();
int output_h = args[0], output_w = args[1];
int inputs_h = args[2], inputs_w = args[3];
int weights_h = args[4], weights_w = args[5];
int stride_h = args[6], stride_w = args[7];
int n_channels_in = args[8];
int padding_h = args[9];
int padding_w = args[10];
int batch_size = args[11];
size_t r0 = instruction.get_r(0);
size_t r1 = instruction.get_r(1);
int r2 = instruction.get_r(2);
int filter_stride_h = 1;
int filter_stride_w = 1;
if (stride_h < 0)
{
filter_stride_h = -stride_h;
stride_h = 1;
}
if (stride_w < 0)
{
filter_stride_w = -stride_w;
stride_w = 1;
}
vector<Conv2dTuple> tuples;
for (size_t i = 0; i < args.size(); i += 15)
tuples.push_back(Conv2dTuple(args, i));
for (auto& tuple : tuples)
tuple.run_matrix(processor);
}
template<class T>
void Conv2dTuple::run_matrix(SubProcessor<T>& processor)
{
auto& S = processor.get_S();
array<int, 3> dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}});
ShareMatrix<T> A(dim[0], dim[1]), B(dim[1], dim[2]);
if (not T::real_shares(processor.P))
{
matrix_multiply(A, B, processor);
processor.protocol.matrix_multiply(A, B, processor);
return;
}
@@ -208,7 +194,7 @@ void Hemi<T>::conv2ds(SubProcessor<T>& processor,
}
}
auto C = matrix_multiply(A, B, processor);
auto C = processor.protocol.matrix_multiply(A, B, processor);
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{

View File

@@ -37,6 +37,8 @@ public:
if (swapped)
std::swap(this->n_rows, this->n_cols);
assert(this->n_cols >= this->n_rows);
assert(prep.proc);
this->P = &prep.proc->P;
}
void set_protocol(typename ShareMatrix<T>::Protocol&)

View File

@@ -21,11 +21,7 @@ void MaliciousShamirMC<T>::init_open(const Player& P, int n)
reconstructions.resize(2 * threshold + 2);
for (int i = threshold + 1; i <= 2 * threshold + 1; i++)
{
reconstructions[i].resize(i);
for (int j = 0; j < i; j++)
reconstructions[i][j] =
Shamir<T>::get_rec_factor(P.get_player(j),
P.num_players(), P.my_num(), i);
reconstructions[i] = ShamirMC<T>::get_reconstruction(P, i);
}
}

Some files were not shown because too many files have changed in this diff Show More