mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-07 20:53:55 -05:00
Maintenance.
This commit is contained in:
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -1,9 +1,6 @@
|
||||
[submodule "SimpleOT"]
|
||||
path = deps/SimpleOT
|
||||
url = https://github.com/mkskeller/SimpleOT
|
||||
[submodule "mpir"]
|
||||
path = deps/mpir
|
||||
url = https://github.com/wbhart/mpir
|
||||
[submodule "Programs/Circuits"]
|
||||
path = Programs/Circuits
|
||||
url = https://github.com/mkskeller/bristol-fashion
|
||||
|
||||
@@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec<T>
|
||||
|
||||
bool one_shot;
|
||||
|
||||
size_t data_sent;
|
||||
|
||||
public:
|
||||
static RealProgramParty& s();
|
||||
|
||||
|
||||
@@ -154,7 +154,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
while (next != GC::DONE_BREAK);
|
||||
|
||||
MC->Check(*P);
|
||||
data_sent = P->total_comm().sent;
|
||||
|
||||
if (online_opts.verbose)
|
||||
P->total_comm().print();
|
||||
@@ -216,7 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
|
||||
delete prep;
|
||||
delete garble_inputter;
|
||||
delete garble_protocol;
|
||||
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
|
||||
garble_machine.print_comm(*this->P, this->P->total_comm());
|
||||
T::MAC_Check::teardown();
|
||||
}
|
||||
|
||||
|
||||
@@ -62,11 +62,13 @@ private:
|
||||
#endif
|
||||
};
|
||||
#else
|
||||
class BaseKeyVector : public vector<Key>
|
||||
class BaseKeyVector : public CheckVector<Key>
|
||||
{
|
||||
typedef CheckVector<Key> super;
|
||||
|
||||
public:
|
||||
BaseKeyVector(int size = 0) : vector<Key>(size, Key(0)) {}
|
||||
void resize(int size) { vector<Key>::resize(size, Key(0)); }
|
||||
BaseKeyVector(int size = 0) : super(size, Key(0)) {}
|
||||
void resize(int size) { super::resize(size, Key(0)); }
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -296,7 +298,8 @@ public:
|
||||
static void andm(GC::Processor<U>&, const BaseInstruction&)
|
||||
{ throw runtime_error("andm not implemented"); }
|
||||
|
||||
static void run_tapes(const vector<int>&) { throw not_implemented(); }
|
||||
static void run_tapes(const vector<int>&)
|
||||
{ throw runtime_error("multi-threading not implemented"); }
|
||||
|
||||
// most BMR phases don't need actual input
|
||||
template<class T>
|
||||
|
||||
15
CHANGELOG.md
15
CHANGELOG.md
@@ -1,5 +1,20 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.3.6 (May 9, 2023)
|
||||
|
||||
- More extensive benchmarking outputs
|
||||
- Replace MPIR by GMP
|
||||
- Secure reading of edaBits from files
|
||||
- Semi-honest client communication
|
||||
- Back-propagation for average pooling
|
||||
- Parallelized convolution
|
||||
- Probabilistic truncation as in ABY3
|
||||
- More balanced communication in Shamir secret sharing
|
||||
- Avoid unnecessary communication in Dealer protocol
|
||||
- Linear solver using Cholesky decomposition
|
||||
- Accept .py files for compilation
|
||||
- Fixed security bug: proper accounting for random elements
|
||||
|
||||
## 0.3.5 (Feb 16, 2023)
|
||||
|
||||
- Easier-to-use machine learning interface
|
||||
|
||||
21
CONFIG
21
CONFIG
@@ -35,15 +35,32 @@ ARM := $(shell uname -m | grep x86; echo $$?)
|
||||
OS := $(shell uname -s)
|
||||
ifeq ($(MACHINE), x86_64)
|
||||
ifeq ($(OS), Linux)
|
||||
ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0)
|
||||
AVX_OT = 1
|
||||
else
|
||||
AVX_OT = 0
|
||||
endif
|
||||
else
|
||||
AVX_OT = 0
|
||||
endif
|
||||
else
|
||||
ARCH =
|
||||
AVX_OT = 0
|
||||
endif
|
||||
|
||||
ifeq ($(OS), Darwin)
|
||||
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
|
||||
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
|
||||
endif
|
||||
|
||||
ifeq ($(OS), Linux)
|
||||
ifeq ($(ARM), 1)
|
||||
ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0)
|
||||
ARCH = -march=armv8.2-a+crypto
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
|
||||
USE_KOS = 0
|
||||
|
||||
# allow to set compiler in CONFIG.mine
|
||||
@@ -66,7 +83,8 @@ endif
|
||||
# Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols
|
||||
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
|
||||
|
||||
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
|
||||
LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
|
||||
LDLIBS += $(BREW_LDLIBS)
|
||||
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
|
||||
LDLIBS += -lboost_system -lssl -lcrypto
|
||||
|
||||
@@ -88,6 +106,7 @@ BOOST = -lboost_thread $(MY_BOOST)
|
||||
endif
|
||||
|
||||
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
|
||||
CFLAGS += $(BREW_CFLAGS)
|
||||
CPPFLAGS = $(CFLAGS)
|
||||
LD = $(CXX)
|
||||
|
||||
|
||||
@@ -17,8 +17,10 @@ import math
|
||||
|
||||
class SecretBitsAF(base.RegisterArgFormat):
|
||||
reg_type = 'sb'
|
||||
name = 'sbit'
|
||||
class ClearBitsAF(base.RegisterArgFormat):
|
||||
reg_type = 'cb'
|
||||
name = 'cbit'
|
||||
|
||||
base.ArgFormats['sb'] = SecretBitsAF
|
||||
base.ArgFormats['sbw'] = SecretBitsAF
|
||||
|
||||
@@ -338,16 +338,19 @@ class Merger:
|
||||
d[j] = d[i]
|
||||
|
||||
def read(reg, n):
|
||||
last_read[reg] = n
|
||||
for dup in reg.duplicates:
|
||||
if last_def[dup] != -1:
|
||||
if last_def[dup] not in (-1, n):
|
||||
add_edge(last_def[dup], n)
|
||||
last_read[reg] = n
|
||||
|
||||
def write(reg, n):
|
||||
last_def[reg] = n
|
||||
for dup in reg.duplicates:
|
||||
if last_read[dup] not in (-1, n):
|
||||
add_edge(last_read[dup], n)
|
||||
if id(dup) in [id(x) for x in block.instructions[n].get_used()] and \
|
||||
last_read[dup] not in (-1, n):
|
||||
add_edge(last_read[dup], n)
|
||||
last_def[reg] = n
|
||||
|
||||
def handle_mem_access(addr, reg_type, last_access_this_kind,
|
||||
last_access_other_kind):
|
||||
@@ -434,13 +437,6 @@ class Merger:
|
||||
# if options.debug:
|
||||
# col = colordict[instr.__class__.__name__]
|
||||
# G.add_node(n, color=col, label=str(instr))
|
||||
for reg in inputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
read(i, n)
|
||||
else:
|
||||
read(reg, n)
|
||||
|
||||
for reg in outputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
@@ -448,6 +444,13 @@ class Merger:
|
||||
else:
|
||||
write(reg, n)
|
||||
|
||||
for reg in inputs:
|
||||
if reg.vector and instr.is_vec():
|
||||
for i in reg.vector:
|
||||
read(i, n)
|
||||
else:
|
||||
read(reg, n)
|
||||
|
||||
# will be merged
|
||||
if isinstance(instr, TextInputInstruction):
|
||||
keep_text_order(instr, n)
|
||||
@@ -556,18 +559,6 @@ class Merger:
|
||||
if unused_result:
|
||||
eliminate(i)
|
||||
count += 1
|
||||
# remove unnecessary stack instructions
|
||||
# left by optimization with budget
|
||||
if isinstance(inst, popint_class) and \
|
||||
(not G.degree(i) or (G.degree(i) == 1 and
|
||||
isinstance(instructions[list(G[i])[0]], StackInstruction))) \
|
||||
and \
|
||||
inst.args[0].can_eliminate and \
|
||||
len(G.pred[i]) == 1 and \
|
||||
isinstance(instructions[list(G.pred[i])[0]], pushint_class):
|
||||
eliminate(list(G.pred[i])[0])
|
||||
eliminate(i)
|
||||
count += 2
|
||||
if count > 0 and self.block.parent.program.verbose:
|
||||
print('Eliminated %d dead instructions, among which %d opens: %s' \
|
||||
% (count, open_count, dict(stats)))
|
||||
|
||||
@@ -50,6 +50,9 @@ def set_variant(options):
|
||||
do_precomp = False
|
||||
elif variant is not None:
|
||||
raise CompilerError('Unknown comparison variant: %s' % variant)
|
||||
if const_rounds and instructions_base.program.options.binary:
|
||||
raise CompilerError(
|
||||
'Comparison variant choice incompatible with binary circuits')
|
||||
|
||||
def ld2i(c, n):
|
||||
""" Load immediate 2^n into clear GF(p) register c """
|
||||
|
||||
@@ -22,6 +22,7 @@ class Compiler:
|
||||
self.custom_args = custom_args
|
||||
self.build_option_parser()
|
||||
self.VARS = {}
|
||||
self.root = os.path.dirname(__file__) + '/..'
|
||||
|
||||
def build_option_parser(self):
|
||||
parser = OptionParser(usage=self.usage)
|
||||
@@ -269,7 +270,7 @@ class Compiler:
|
||||
self.prog = Program(self.args, self.options, name=name)
|
||||
if self.execute:
|
||||
if self.options.execute in \
|
||||
("emulate", "ring", "rep-field", "semi2k"):
|
||||
("emulate", "ring", "rep-field"):
|
||||
self.prog.use_trunc_pr = True
|
||||
if self.options.execute in ("ring",):
|
||||
self.prog.use_split(3)
|
||||
@@ -405,7 +406,7 @@ class Compiler:
|
||||
infile = open(self.prog.infile)
|
||||
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, "Compiler")
|
||||
sys.path.insert(0, "%s/Compiler" % self.root)
|
||||
# create the tapes
|
||||
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
|
||||
|
||||
@@ -477,15 +478,15 @@ class Compiler:
|
||||
|
||||
def local_execution(self, args=[]):
|
||||
executable = self.executable_from_protocol(self.options.execute)
|
||||
if not os.path.exists(executable):
|
||||
if not os.path.exists("%s/%s" % (self.root, executable)):
|
||||
print("Creating binary for virtual machine...")
|
||||
try:
|
||||
subprocess.run(["make", executable], check=True)
|
||||
subprocess.run(["make", executable], check=True, cwd=self.root)
|
||||
except:
|
||||
raise CompilerError(
|
||||
"Cannot produce %s. " % executable + \
|
||||
"Note that compilation requires a few GB of RAM.")
|
||||
vm = 'Scripts/%s.sh' % self.options.execute
|
||||
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
|
||||
os.execl(vm, vm, self.prog.name, *args)
|
||||
|
||||
def remote_execution(self, args=[]):
|
||||
@@ -496,7 +497,7 @@ class Compiler:
|
||||
from fabric import Connection
|
||||
import subprocess
|
||||
print("Creating static binary for virtual machine...")
|
||||
subprocess.run(["make", "static/%s" % vm], check=True)
|
||||
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
|
||||
|
||||
# transfer files
|
||||
import glob
|
||||
@@ -519,7 +520,7 @@ class Compiler:
|
||||
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
|
||||
dest)
|
||||
# executable
|
||||
connection.put("static/%s" % vm, dest)
|
||||
connection.put("%s/static/%s" % (self.root, vm), dest)
|
||||
# program
|
||||
dest += "/"
|
||||
connection.put("Programs/Schedules/%s.sch" % self.prog.name,
|
||||
|
||||
@@ -289,7 +289,7 @@ def BitDecRingRaw(a, k, m):
|
||||
def BitDecRing(a, k, m):
|
||||
bits = BitDecRingRaw(a, k, m)
|
||||
# reversing to reduce number of rounds
|
||||
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
|
||||
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
|
||||
|
||||
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
instructions_base.set_global_vector_size(a.size)
|
||||
@@ -306,7 +306,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
|
||||
|
||||
def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
|
||||
return [types.sint.conv(bit) for bit in res]
|
||||
return [types.sintbit.conv(bit) for bit in res]
|
||||
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
|
||||
@@ -356,7 +356,17 @@ class reqbl(base.Instruction):
|
||||
code = base.opcodes['REQBL']
|
||||
arg_format = ['int']
|
||||
|
||||
class active(base.Instruction):
|
||||
""" Indicate whether program is compatible with malicious-security
|
||||
protocols.
|
||||
|
||||
:param: 0 for no, 1 for yes
|
||||
"""
|
||||
code = base.opcodes['ACTIVE']
|
||||
arg_format = ['int']
|
||||
|
||||
class time(base.IOInstruction):
|
||||
|
||||
""" Output time since start of computation. """
|
||||
code = base.opcodes['TIME']
|
||||
arg_format = []
|
||||
@@ -2418,9 +2428,10 @@ class matmulsm(matmul_base):
|
||||
super(matmulsm, self).add_usage(req_node)
|
||||
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
|
||||
|
||||
class conv2ds(base.DataInstruction):
|
||||
class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
|
||||
""" Secret 2D convolution.
|
||||
|
||||
:param: number of arguments to follow (int)
|
||||
:param: result (sint vector in row-first order)
|
||||
:param: inputs (sint vector in row-first order)
|
||||
:param: weights (sint vector in row-first order)
|
||||
@@ -2436,10 +2447,12 @@ class conv2ds(base.DataInstruction):
|
||||
:param: padding height (int)
|
||||
:param: padding width (int)
|
||||
:param: batch size (int)
|
||||
:param: repeat from result...
|
||||
|
||||
"""
|
||||
code = base.opcodes['CONV2DS']
|
||||
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
|
||||
'int','int','int','int']
|
||||
arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
|
||||
'int','int','int','int','int','int','int'])
|
||||
data_type = 'triple'
|
||||
is_vec = lambda self: True
|
||||
|
||||
@@ -2450,14 +2463,16 @@ class conv2ds(base.DataInstruction):
|
||||
assert args[2].size == args[7] * args[8] * args[11]
|
||||
|
||||
def get_repeat(self):
|
||||
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
|
||||
self.args[11] * self.args[14]
|
||||
args = self.args
|
||||
return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
|
||||
args[i+11] * args[i+14] for i in range(0, len(args), 15))
|
||||
|
||||
def add_usage(self, req_node):
|
||||
super(conv2ds, self).add_usage(req_node)
|
||||
args = self.args
|
||||
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
|
||||
args[14] * args[3] * args[4])), 1)
|
||||
for i in range(0, len(self.args), 15):
|
||||
args = self.args[i:i + 15]
|
||||
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
|
||||
args[14] * args[3] * args[4])), 1)
|
||||
|
||||
@base.vectorize
|
||||
class trunc_pr(base.VarArgsInstruction):
|
||||
|
||||
@@ -66,6 +66,7 @@ opcodes = dict(
|
||||
PLAYERID = 0xE4,
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -700,18 +701,23 @@ class RegisterArgFormat(ArgFormat):
|
||||
|
||||
class ClearModpAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearModp
|
||||
name = 'cint'
|
||||
|
||||
class SecretModpAF(RegisterArgFormat):
|
||||
reg_type = RegType.SecretModp
|
||||
name = 'sint'
|
||||
|
||||
class ClearGF2NAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearGF2N
|
||||
name = 'cgf2n'
|
||||
|
||||
class SecretGF2NAF(RegisterArgFormat):
|
||||
reg_type = RegType.SecretGF2N
|
||||
name = 'sgf2n'
|
||||
|
||||
class ClearIntAF(RegisterArgFormat):
|
||||
reg_type = RegType.ClearInt
|
||||
name = 'regint'
|
||||
|
||||
class IntArgFormat(ArgFormat):
|
||||
n_bits = 32
|
||||
|
||||
@@ -1226,7 +1226,7 @@ def while_loop(loop_body, condition, arg=None, g=None):
|
||||
result = loop_body(arg)
|
||||
if isinstance(result, MemValue):
|
||||
result = result.read()
|
||||
result.link(arg)
|
||||
arg.update(result)
|
||||
return condition(result)
|
||||
if not isinstance(pre_condition, (bool,int)) or pre_condition:
|
||||
if_statement(pre_condition, lambda: do_while(loop_fn, g=g))
|
||||
|
||||
203
Compiler/ml.py
203
Compiler/ml.py
@@ -372,6 +372,7 @@ class Output(NoVariableLayer):
|
||||
n = self.X.sizes[0]
|
||||
if Y is None:
|
||||
Y = self.Y
|
||||
assert isinstance(Y, Array)
|
||||
n_correct = MemValue(0)
|
||||
n_printed = MemValue(0)
|
||||
@for_range_opt(n)
|
||||
@@ -1109,14 +1110,7 @@ class Square(ElementWiseLayer):
|
||||
f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x)
|
||||
prime_type = sfix
|
||||
|
||||
class MaxPool(NoVariableLayer):
|
||||
""" Fixed-point MaxPool layer.
|
||||
|
||||
:param shape: input shape (tuple/list of four int)
|
||||
:param strides: strides (tuple/list of four int, first and last must be 1)
|
||||
:param ksize: kernel size (tuple/list of four int, first and last must be 1)
|
||||
:param padding: :py:obj:`'VALID'` (default) or :py:obj:`'SAME'`
|
||||
"""
|
||||
class PoolBase(NoVariableLayer):
|
||||
def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1),
|
||||
padding='VALID'):
|
||||
assert len(shape) == 4
|
||||
@@ -1152,38 +1146,6 @@ class MaxPool(NoVariableLayer):
|
||||
(type(self).__name__, self.X.sizes, self.strides,
|
||||
self.ksize, self.padding)
|
||||
|
||||
def forward(self, batch=None, training=False):
|
||||
if batch is None:
|
||||
batch = Array.create_from(regint(0))
|
||||
def process(pool, bi, k, i, j):
|
||||
def m(a, b):
|
||||
c = a[0] > b[0]
|
||||
l = [c * x for x in a[1]]
|
||||
l += [(1 - c) * x for x in b[1]]
|
||||
return c.if_else(a[0], b[0]), l
|
||||
red = util.tree_reduce(m, [(x[0], [1] if training else [])
|
||||
for x in pool])
|
||||
self.Y[bi][i][j][k] = red[0]
|
||||
for ii, x in enumerate(red[1]):
|
||||
self.comparisons[bi][k][i][j][ii] = x
|
||||
self.traverse(batch, process)
|
||||
|
||||
def backward(self, compute_nabla_X=True, batch=None):
|
||||
if compute_nabla_X:
|
||||
self.nabla_X.alloc()
|
||||
self.nabla_X.assign_all(0)
|
||||
break_point()
|
||||
def process(pool, bi, k, i, j):
|
||||
for (x, h_in, w_in, h, w), c \
|
||||
in zip(pool, self.comparisons[bi][k][i][j]):
|
||||
hh = h * h_in
|
||||
ww = w * w_in
|
||||
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
|
||||
get_program().protect_memory(True)
|
||||
self.nabla_X[bi][hh][ww][k] += res
|
||||
get_program().protect_memory(False)
|
||||
self.traverse(batch, process)
|
||||
|
||||
def traverse(self, batch, process):
|
||||
need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] >
|
||||
self.X.sizes[i] for i in range(4)]
|
||||
@@ -1221,6 +1183,47 @@ class MaxPool(NoVariableLayer):
|
||||
h_in, w_in, h, w])
|
||||
process(pool, bi, k, i, j)
|
||||
|
||||
class MaxPool(PoolBase):
|
||||
""" Fixed-point MaxPool layer.
|
||||
|
||||
:param shape: input shape (tuple/list of four int)
|
||||
:param strides: strides (tuple/list of four int, first and last must be 1)
|
||||
:param ksize: kernel size (tuple/list of four int, first and last must be 1)
|
||||
:param padding: :py:obj:`'VALID'` (default), :py:obj:`'SAME'`, integer, or
|
||||
list/tuple of integers
|
||||
|
||||
"""
|
||||
def forward(self, batch=None, training=False):
|
||||
if batch is None:
|
||||
batch = Array.create_from(regint(0))
|
||||
def process(pool, bi, k, i, j):
|
||||
def m(a, b):
|
||||
c = a[0] > b[0]
|
||||
l = [c * x for x in a[1]]
|
||||
l += [(1 - c) * x for x in b[1]]
|
||||
return c.if_else(a[0], b[0]), l
|
||||
red = util.tree_reduce(m, [(x[0], [1] if training else [])
|
||||
for x in pool])
|
||||
self.Y[bi][i][j][k] = red[0]
|
||||
for ii, x in enumerate(red[1]):
|
||||
self.comparisons[bi][k][i][j][ii] = x
|
||||
self.traverse(batch, process)
|
||||
|
||||
def backward(self, compute_nabla_X=True, batch=None):
|
||||
if compute_nabla_X:
|
||||
self.nabla_X.alloc()
|
||||
self.nabla_X.assign_all(0)
|
||||
break_point()
|
||||
def process(pool, bi, k, i, j):
|
||||
for (x, h_in, w_in, h, w), c \
|
||||
in zip(pool, self.comparisons[bi][k][i][j]):
|
||||
hh = h * h_in
|
||||
ww = w * w_in
|
||||
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
|
||||
get_program().protect_memory(True)
|
||||
self.nabla_X[bi][hh][ww][k] += res
|
||||
get_program().protect_memory(False)
|
||||
self.traverse(batch, process)
|
||||
|
||||
class Argmax(NoVariableLayer):
|
||||
""" Fixed-point Argmax layer.
|
||||
@@ -2058,6 +2061,12 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
|
||||
or tuple/list of two int
|
||||
|
||||
"""
|
||||
kernel_size, stride, padding = \
|
||||
_standardize_pool_options(kernel_size, stride, padding)
|
||||
return MaxPool(input_shape, [1] + list(stride) + [1],
|
||||
[1] + list(kernel_size) + [1], padding)
|
||||
|
||||
def _standardize_pool_options(kernel_size, stride, padding):
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
@@ -2066,8 +2075,7 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
|
||||
stride = kernel_size
|
||||
padding = padding.upper() if isinstance(padding, str) \
|
||||
else padding
|
||||
return MaxPool(input_shape, [1] + list(stride) + [1],
|
||||
[1] + list(kernel_size) + [1], padding)
|
||||
return kernel_size, stride, padding
|
||||
|
||||
class QuantAveragePool2d(QuantBase, AveragePool2d):
|
||||
def input_params_from(self, player):
|
||||
@@ -2075,14 +2083,47 @@ class QuantAveragePool2d(QuantBase, AveragePool2d):
|
||||
for s in self.input_squant, self.output_squant:
|
||||
s.get_params_from(player)
|
||||
|
||||
class FixAveragePool2d(FixBase, AveragePool2d):
|
||||
class FixAveragePool2d(PoolBase, FixBase):
|
||||
""" Fixed-point 2D AvgPool layer.
|
||||
|
||||
:param input_shape: input shape (tuple/list of four int)
|
||||
:param output_shape: output shape (tuple/list of four int)
|
||||
:param filter_size: filter size (tuple/list of two int)
|
||||
:param strides: strides (tuple/list of two int)
|
||||
"""
|
||||
:param filter_size: filter size (int or tuple/list of two int)
|
||||
:param strides: strides (int or tuple/list of two int)
|
||||
:param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int,
|
||||
or tuple/list of two int
|
||||
|
||||
"""
|
||||
def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1),
|
||||
padding=0):
|
||||
filter_size, strides, padding = \
|
||||
_standardize_pool_options(filter_size, strides, padding)
|
||||
PoolBase.__init__(self, input_shape, [1] + list(strides) + [1],
|
||||
[1] + list(filter_size) + [1], padding)
|
||||
self.pool_size = reduce(operator.mul, filter_size)
|
||||
if output_shape:
|
||||
assert self.Y.shape == list(output_shape)
|
||||
|
||||
def _forward(self, batch):
|
||||
def process(pool, bi, k, i, j):
|
||||
self.Y[bi][i][j][k] = sum(x[0] for x in pool) * (1 / self.pool_size)
|
||||
self.traverse(batch, process)
|
||||
|
||||
def backward(self, compute_nabla_X=True, batch=None):
|
||||
if compute_nabla_X:
|
||||
self.nabla_X.alloc()
|
||||
self.nabla_X.assign_all(0)
|
||||
break_point()
|
||||
def process(pool, bi, k, i, j):
|
||||
part = self.nabla_Y[bi][i][j][k] * (1 / self.pool_size)
|
||||
for x, h_in, w_in, h, w in pool:
|
||||
hh = h * h_in
|
||||
ww = w * w_in
|
||||
res = h_in * w_in * part
|
||||
get_program().protect_memory(True)
|
||||
self.nabla_X[bi][hh][ww][k] += res
|
||||
get_program().protect_memory(False)
|
||||
self.traverse(batch, process)
|
||||
|
||||
class QuantReshape(QuantBase, BaseLayer):
|
||||
def __init__(self, input_shape, _, output_shape):
|
||||
@@ -2265,6 +2306,8 @@ class Optimizer:
|
||||
|
||||
:param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample)
|
||||
:param top: return top prediction instead of probability distribution
|
||||
:returns: sfix/sint Array (depening on :py:obj:`top`)
|
||||
|
||||
"""
|
||||
if isinstance(self.layers[-1].Y, Array) or top:
|
||||
if top:
|
||||
@@ -2540,6 +2583,8 @@ class Optimizer:
|
||||
@_no_mem_warnings
|
||||
def run_by_args(self, program, n_runs, batch_size, test_X, test_Y,
|
||||
acc_batch_size=None, reset=True):
|
||||
MultiArray.disable_index_checks()
|
||||
Array.check_indices = False
|
||||
if acc_batch_size is None:
|
||||
acc_batch_size = batch_size
|
||||
depreciation = None
|
||||
@@ -2943,6 +2988,10 @@ class keras:
|
||||
return 'maxpool', {'pool_size': pool_size, 'strides': strides,
|
||||
'padding': padding}
|
||||
|
||||
def AveragePooling2D(pool_size=2, strides=None, padding='valid'):
|
||||
return 'avgpool', {'filter_size': pool_size, 'strides': strides,
|
||||
'padding': padding}
|
||||
|
||||
def Dropout(rate):
|
||||
l = math.log(rate, 2)
|
||||
if int(l) != l:
|
||||
@@ -3014,9 +3063,12 @@ class keras:
|
||||
n_units = reduce(operator.mul,
|
||||
layers[-1].Y.sizes[1:])
|
||||
if i == len(self.layers) - 1:
|
||||
if layer[2].get('activation', 'softmax') in \
|
||||
('softmax', 'sigmoid'):
|
||||
activation = layer[2].get('activation', None)
|
||||
if activation in ('softmax', 'sigmoid'):
|
||||
layer[2].pop('activation', None)
|
||||
if activation == 'softmax' and layer[1][0] == 1:
|
||||
raise CompilerError(
|
||||
'softmax requires more than one output neuron')
|
||||
layers.append(Dense(N, n_units, layer[1][0],
|
||||
**layer[2]))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
@@ -3041,6 +3093,9 @@ class keras:
|
||||
layers.append(easyMaxPool(input_shape, pool_size,
|
||||
strides, padding))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
elif name == 'avgpool':
|
||||
layers.append(FixAveragePool2d(input_shape, None, **layer[1]))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
elif name == 'dropout':
|
||||
layers.append(Dropout(batch_size, reduce(
|
||||
operator.mul, layers[-1].Y.sizes[1:]),
|
||||
@@ -3192,6 +3247,10 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None):
|
||||
layers.append(easyMaxPool(input_shape, item.kernel_size,
|
||||
item.stride, item.padding))
|
||||
input_shape = layers[-1].Y.shape
|
||||
elif name == 'AvgPool2d':
|
||||
layers.append(FixAveragePool2d(input_shape, None, item.kernel_size,
|
||||
item.stride, item.padding))
|
||||
input_shape = layers[-1].Y.shape
|
||||
elif name == 'ReLU':
|
||||
layers.append(Relu(input_shape))
|
||||
elif name == 'Flatten':
|
||||
@@ -3295,7 +3354,7 @@ class SGDLogistic(OneLayerSGD):
|
||||
return super(SGDLogistic, self).predict(X)
|
||||
|
||||
class SGDLinear(OneLayerSGD):
|
||||
""" Logistic regression using SGD.
|
||||
""" Linear regression using SGD.
|
||||
|
||||
:param n_epochs: number of epochs
|
||||
:param batch_size: batch size
|
||||
@@ -3415,11 +3474,16 @@ def var(x):
|
||||
return res.read()
|
||||
|
||||
def cholesky(A, reveal_diagonal=False):
|
||||
""" Cholesky decomposition. """
|
||||
""" Cholesky decomposition.
|
||||
|
||||
:returns: lower triangular matrix
|
||||
|
||||
"""
|
||||
assert len(A.shape) == 2
|
||||
assert A.shape[0] == A.shape[1]
|
||||
L = A.same_shape()
|
||||
L.assign_all(0)
|
||||
diag_inv = A.value_type.Array(A.shape[0])
|
||||
@for_range(A.shape[0])
|
||||
def _(i):
|
||||
@for_range(i + 1)
|
||||
@@ -3429,10 +3493,47 @@ def cholesky(A, reveal_diagonal=False):
|
||||
@if_e(i == j)
|
||||
def _():
|
||||
L[i][j] = mpc_math.sqrt(A[i][i] - sum)
|
||||
diag_inv[i] = 1 / L[i][j]
|
||||
if reveal_diagonal:
|
||||
print_ln('L[%s][%s] = %s = sqrt(%s - %s)', i, j,
|
||||
L[i][j].reveal(), A[i][j].reveal(), sum.reveal())
|
||||
@else_
|
||||
def _():
|
||||
L[i][j] = (1.0 / L[j][j] * (A[i][j] - sum))
|
||||
L[i][j] = (diag_inv[j] * (A[i][j] - sum))
|
||||
return L
|
||||
|
||||
def solve_lower(A, b):
|
||||
""" Linear solver where :py:obj:`A` is lower triangular quadratic. """
|
||||
assert len(A.shape) == 2
|
||||
assert A.shape[0] == A.shape[1]
|
||||
assert len(b) == A.shape[0]
|
||||
b = Array.create_from(b)
|
||||
res = sfix.Array(len(b))
|
||||
@for_range(len(b))
|
||||
def _(i):
|
||||
res[i] = b[i] / A[i][i]
|
||||
b[:] -= res[i] * A.get_column(i)
|
||||
return res
|
||||
|
||||
def solve_upper(A, b):
|
||||
""" Linear solver where :py:obj:`A` is upper triangular quadratic. """
|
||||
assert len(A.shape) == 2
|
||||
assert A.shape[0] == A.shape[1]
|
||||
assert len(b) == A.shape[0]
|
||||
b = Array.create_from(b)
|
||||
res = sfix.Array(len(b))
|
||||
@for_range(len(b) - 1, -1, -1)
|
||||
def _(i):
|
||||
res[i] = b[i] / A[i][i]
|
||||
b[:] -= res[i] * A.get_column(i)
|
||||
return res
|
||||
|
||||
def solve_cholesky(A, b, debug=False):
|
||||
""" Linear solver using Cholesky decomposition. """
|
||||
L = cholesky(A, reveal_diagonal=debug)
|
||||
if debug:
|
||||
Optimizer.stat('L', L)
|
||||
x = solve_lower(L, b)
|
||||
if debug:
|
||||
Optimizer.stat('intermediate', x)
|
||||
return solve_upper(L.transpose(), x)
|
||||
|
||||
@@ -661,7 +661,7 @@ def sqrt_simplified_fx(x):
|
||||
h = h * r
|
||||
H = 4 * (h * h)
|
||||
|
||||
if not x.round_nearest or (2 * f < k - 1):
|
||||
if not x.round_nearest or (2 * x.f < x.k - 1):
|
||||
H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H)
|
||||
|
||||
H = H * x
|
||||
@@ -806,9 +806,7 @@ def sqrt_fx(x_l, k, f):
|
||||
@instructions_base.sfix_cisc
|
||||
def sqrt(x, k=None, f=None):
|
||||
"""
|
||||
Returns the square root (sfix) of any given fractional
|
||||
value as long as it can be rounded to a integral value
|
||||
with :py:obj:`f` bits of decimal precision.
|
||||
Square root.
|
||||
|
||||
:param x: fractional input (sfix).
|
||||
|
||||
|
||||
@@ -186,6 +186,8 @@ class Program(object):
|
||||
self.input_files = {}
|
||||
self.base_addresses = {}
|
||||
self._protect_memory = False
|
||||
self._always_active = True
|
||||
self.active = True
|
||||
if not self.options.cisc:
|
||||
self.options.cisc = not self.options.optimize_hard
|
||||
|
||||
@@ -207,16 +209,14 @@ class Program(object):
|
||||
return self.n_threads
|
||||
|
||||
def init_names(self, args):
|
||||
# ignore path to file - source must be in Programs/Source
|
||||
if "Programs" in os.listdir(os.getcwd()):
|
||||
# compile prog in ./Programs/Source directory
|
||||
self.programs_dir = "Programs"
|
||||
else:
|
||||
# assume source is in main SPDZ directory
|
||||
self.programs_dir = sys.path[0] + "/Programs"
|
||||
self.programs_dir = "Programs"
|
||||
if self.verbose:
|
||||
print("Compiling program in", self.programs_dir)
|
||||
|
||||
for dirname in (self.programs_dir, "Player-Data"):
|
||||
if not os.path.exists(dirname):
|
||||
os.mkdir(dirname)
|
||||
|
||||
# create extra directories if needed
|
||||
for dirname in ["Public-Input", "Bytecode", "Schedules"]:
|
||||
if not os.path.exists(self.programs_dir + "/" + dirname):
|
||||
@@ -224,13 +224,29 @@ class Program(object):
|
||||
|
||||
if self.name is None:
|
||||
self.name = args[0].split("/")[-1]
|
||||
if self.name.endswith(".mpc"):
|
||||
self.name = self.name[:-4]
|
||||
exts = ".mpc", ".py"
|
||||
for ext in exts:
|
||||
if self.name.endswith(ext):
|
||||
self.name = self.name[:-len(ext)]
|
||||
|
||||
if os.path.exists(args[0]):
|
||||
self.infile = args[0]
|
||||
else:
|
||||
self.infile = self.programs_dir + "/Source/" + self.name + ".mpc"
|
||||
infiles = []
|
||||
for x in (self.programs_dir, sys.path[0] + "/Programs"):
|
||||
for ext in exts:
|
||||
filename = args[0]
|
||||
if not filename.endswith(ext):
|
||||
filename += ext
|
||||
infiles += [x + "/Source/" + filename]
|
||||
for f in infiles:
|
||||
if os.path.exists(f):
|
||||
self.infile = f
|
||||
break
|
||||
else:
|
||||
raise CompilerError(
|
||||
"found none of the potential input files: " +
|
||||
", ".join("'%s'" % x for x in [args[0]] + infiles))
|
||||
"""
|
||||
self.name is input file name (minus extension) + any optional arguments.
|
||||
Used to generate output filenames
|
||||
@@ -479,6 +495,9 @@ class Program(object):
|
||||
# finalize the memory
|
||||
self.finalize_memory()
|
||||
|
||||
# communicate protocol compability
|
||||
Compiler.instructions.active(self._always_active)
|
||||
|
||||
self.write_bytes()
|
||||
|
||||
if self.options.asmoutfile:
|
||||
@@ -672,6 +691,19 @@ class Program(object):
|
||||
logp = int(round(math.log(p, 2)))
|
||||
return abs(p - 2 ** logp) / p < 2 ** -self.security
|
||||
|
||||
@property
|
||||
def active(self):
|
||||
""" Whether to use actively secure protocols. """
|
||||
return self._active
|
||||
|
||||
@active.setter
|
||||
def active(self, change):
|
||||
self._always_active &= change
|
||||
self._active = change
|
||||
|
||||
def semi_honest(self):
|
||||
self._always_active = False
|
||||
|
||||
@staticmethod
|
||||
def read_tapes(schedule):
|
||||
m = re.search(r"([^/]*)\.mpc", schedule)
|
||||
@@ -1454,6 +1486,9 @@ class Tape:
|
||||
return Tape.Register(self.reg_type, Program.prog.curr_tape)
|
||||
|
||||
def link(self, other):
|
||||
if Program.prog.options.noreallocate:
|
||||
raise CompilerError("reallocation necessary for linking, "
|
||||
"remove option -u")
|
||||
self.duplicates |= other.duplicates
|
||||
for dup in self.duplicates:
|
||||
dup.duplicates = self.duplicates
|
||||
@@ -1466,12 +1501,15 @@ class Tape:
|
||||
:param other: any convertible type
|
||||
|
||||
"""
|
||||
other = type(self)(other)
|
||||
if isinstance(other, Tape.Register) and other.block != Program.prog.curr_block:
|
||||
other = type(self)(other)
|
||||
else:
|
||||
other = self.conv(other)
|
||||
if Program.prog.curr_block in [x.block for x in self.duplicates]:
|
||||
self.program.start_new_basicblock()
|
||||
if self.program != other.program:
|
||||
raise CompilerError(
|
||||
'cannot update register with one from another thread')
|
||||
if other.block in [x.block for x in self.duplicates]:
|
||||
self.program.start_new_basicblock()
|
||||
self.link(other)
|
||||
|
||||
@property
|
||||
|
||||
@@ -659,6 +659,7 @@ class _secret_structure(_structure):
|
||||
traverse(x, level + 1)
|
||||
traverse(content, 0)
|
||||
f.write('\n')
|
||||
f.flush()
|
||||
if requested_shape is not None and \
|
||||
list(shape) != list(requested_shape):
|
||||
raise CompilerError('content contradicts shape')
|
||||
@@ -2415,26 +2416,34 @@ class sint(_secret, _int):
|
||||
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
|
||||
""" Securely obtain shares of values input by a client.
|
||||
This uses the triple-based input protocol introduced by
|
||||
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_
|
||||
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_ unless
|
||||
:py:obj:`program.active` is set to false, in which case
|
||||
it uses random values to mask the clients' input.
|
||||
|
||||
:param n: number of inputs (int)
|
||||
:param client_id: regint
|
||||
:param size: vector size (default 1)
|
||||
:returns: list of sint
|
||||
"""
|
||||
# send shares of a triple to client
|
||||
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
|
||||
if program.active:
|
||||
# send shares of a triple to client
|
||||
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
|
||||
else:
|
||||
triples = [sint.get_random() for i in range(n)]
|
||||
|
||||
sint.write_shares_to_socket(client_id, triples, message_type)
|
||||
|
||||
received = util.tuplify(cint.read_from_socket(client_id, n))
|
||||
y = [0] * n
|
||||
for i in range(n):
|
||||
y[i] = received[i] - triples[i * 3]
|
||||
y[i] = received[i] - triples[i * 3 if program.active else i]
|
||||
return y
|
||||
|
||||
@classmethod
|
||||
def reveal_to_clients(cls, clients, values):
|
||||
""" Reveal securely to clients.
|
||||
Uses :py:obj:`program.active` to determine whether to use
|
||||
triples for active security.
|
||||
|
||||
:param clients: client ids (list or array)
|
||||
:param values: list of sint to reveal
|
||||
@@ -2445,8 +2454,11 @@ class sint(_secret, _int):
|
||||
|
||||
for value in values:
|
||||
assert(value.size == values[0].size)
|
||||
r = sint.get_random()
|
||||
to_send += [value, r, value * r]
|
||||
if program.active:
|
||||
r = sint.get_random()
|
||||
to_send += [value, r, value * r]
|
||||
else:
|
||||
to_send += [value]
|
||||
|
||||
if isinstance(clients, Array):
|
||||
n_clients = clients.length
|
||||
@@ -2844,7 +2856,7 @@ class sint(_secret, _int):
|
||||
privateoutput(self.size, player, res._v, self)
|
||||
return res
|
||||
|
||||
def private_division(self, divisor, active=True, dividend_length=None,
|
||||
def private_division(self, divisor, active=None, dividend_length=None,
|
||||
divisor_length=None):
|
||||
""" Private integer division as per `Veugen and Abspoel
|
||||
<https://doi.org/10.2478/popets-2021-0073>`_
|
||||
@@ -2878,6 +2890,9 @@ class sint(_secret, _int):
|
||||
z_shared = ((self << (l + sigma)) + h + r_pprime)
|
||||
z = z_shared.reveal_to(0)
|
||||
|
||||
if active is None:
|
||||
active = program.active
|
||||
|
||||
if active:
|
||||
z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)]
|
||||
check = [(x * (1 - x)).reveal() == 0 for x in z_prime]
|
||||
@@ -2893,6 +2908,7 @@ class sint(_secret, _int):
|
||||
y_prime = sint.bit_compose(z_prime[:l + sigma])
|
||||
y = sint.bit_compose(z_prime[l + sigma:])
|
||||
else:
|
||||
program.semi_honest()
|
||||
y = sint(z // (d << (l + sigma)))
|
||||
y_prime = sint((z // d) % (2 ** (l + sigma)))
|
||||
|
||||
@@ -3147,7 +3163,9 @@ class sgf2n(_secret, _gf2n):
|
||||
for i in range(0, bit_length, step)]
|
||||
|
||||
one = cgf2n(1)
|
||||
masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal()
|
||||
masked = sum([b * (one << (i * step))
|
||||
for i,b in enumerate(random_bits)], self).reveal(
|
||||
check=False)
|
||||
masked_bits = masked.bit_decompose(bit_length,step=step)
|
||||
return [m + r for m,r in zip(masked_bits, random_bits)]
|
||||
|
||||
@@ -3157,7 +3175,9 @@ class sgf2n(_secret, _gf2n):
|
||||
for i in range(8)]
|
||||
one = cgf2n(1)
|
||||
wanted_positions = [0, 5, 10, 15, 20, 25, 30, 35]
|
||||
masked = sum([b * (one << wanted_positions[i]) for i,b in enumerate(random_bits)], self).reveal()
|
||||
masked = sum([b * (one << wanted_positions[i])
|
||||
for i,b in enumerate(random_bits)], self).reveal(
|
||||
check=False)
|
||||
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
|
||||
|
||||
for t in (sint, sgf2n):
|
||||
@@ -4080,7 +4100,8 @@ class _single(_number, _secret_structure):
|
||||
@vectorized_classmethod
|
||||
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
|
||||
"""
|
||||
Securely obtain shares of values input by a client. Assumes client
|
||||
Securely obtain shares of values input by a client via
|
||||
:py:func:`sint.receive_from_client`. Assumes client
|
||||
has already converted values to integer representation.
|
||||
|
||||
:param n: number of inputs (int)
|
||||
@@ -4095,7 +4116,7 @@ class _single(_number, _secret_structure):
|
||||
|
||||
@classmethod
|
||||
def reveal_to_clients(cls, clients, values):
|
||||
""" Reveal securely to clients.
|
||||
""" Reveal securely to clients via :py:func:`sint.reveal_to_clients`.
|
||||
|
||||
:param clients: client ids (list or array)
|
||||
:param values: list of values of this class
|
||||
@@ -4556,7 +4577,7 @@ class sfix(_fix):
|
||||
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
|
||||
returning :py:class:`sbitint`. The other operand can be any of
|
||||
sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()``
|
||||
and ``**``, the latter for integer exponents.
|
||||
and ``**``.
|
||||
|
||||
Note that the default precision (16 bits after the dot, 31 bits in
|
||||
total) only allows numbers up to :math:`2^{31-16-1} \\approx
|
||||
@@ -4669,6 +4690,8 @@ class sfix(_fix):
|
||||
return self.v
|
||||
|
||||
def mul_no_reduce(self, other, res_params=None):
|
||||
if not isinstance(other, type(self)):
|
||||
return self * other
|
||||
assert self.f == other.f
|
||||
assert self.k == other.k
|
||||
return self.unreduced(self.v * other.v)
|
||||
@@ -4734,6 +4757,11 @@ class unreduced_sfix(_single):
|
||||
nearest=sfix.round_nearest, signed=True)
|
||||
return sfix._new(v, k=self.k - self.m, f=self.m)
|
||||
|
||||
def update(self, other):
|
||||
assert self.k == other.k
|
||||
assert self.m == other.m
|
||||
self.v.update(other.v)
|
||||
|
||||
sfix.unreduced_type = unreduced_sfix
|
||||
|
||||
sfix.set_precision(16, 31)
|
||||
@@ -4953,6 +4981,8 @@ class sfloat(_number, _secret_structure):
|
||||
|
||||
This uses integer operations internally, see :py:class:`sint` for security
|
||||
considerations.
|
||||
See `Aliasgari et al. <https://eprint.iacr.org/2012/405.pdf>`_ for
|
||||
details.
|
||||
|
||||
The type supports basic arithmetic (``+, -, *, /``), returning
|
||||
:py:class:`sfloat`, and comparisons (``==, !=, <, <=, >, >=``),
|
||||
@@ -5459,6 +5489,9 @@ class Array(_vectorizable):
|
||||
b.input_from(1)
|
||||
a[:] += b[:]
|
||||
|
||||
Arrays aren't initialized on creation, you need to call
|
||||
:py:func:`assign_all` to initialize them to a constant value.
|
||||
|
||||
"""
|
||||
check_indices = True
|
||||
|
||||
@@ -5708,7 +5741,7 @@ class Array(_vectorizable):
|
||||
mem_value = MemValue(value)
|
||||
self.address = MemValue.if_necessary(self.address)
|
||||
n_threads = 8 if use_threads and util.is_constant(self.length) and \
|
||||
len(self) > 2**20 else None
|
||||
len(self) > 2**20 and not program.options.garbled else None
|
||||
@library.multithread(n_threads, self.length)
|
||||
def _(base, size):
|
||||
if use_vector:
|
||||
@@ -5896,7 +5929,7 @@ class Array(_vectorizable):
|
||||
self.assign_vector(self.get_vector().secure_shuffle())
|
||||
|
||||
def secure_permute(self, *args, **kwargs):
|
||||
""" Secure permutate in place according to the security model.
|
||||
""" Secure permute in place according to the security model.
|
||||
See :py:func:`MultiArray.secure_shuffle` for references.
|
||||
|
||||
:param permutation: output of :py:func:`sint.get_secure_shuffle()`
|
||||
@@ -6227,7 +6260,10 @@ class SubMultiArray(_vectorizable):
|
||||
|
||||
def same_shape(self):
|
||||
""" :return: new multidimensional array with same shape and basic type """
|
||||
return MultiArray(self.sizes, self.value_type)
|
||||
if len(self.sizes) == 2:
|
||||
return Matrix(*self.sizes, self.value_type)
|
||||
else:
|
||||
return MultiArray(self.sizes, self.value_type)
|
||||
|
||||
def get_part(self, start, size):
|
||||
""" Part multi-array.
|
||||
@@ -6400,7 +6436,7 @@ class SubMultiArray(_vectorizable):
|
||||
pass
|
||||
t.params = res_params
|
||||
else:
|
||||
if issubclass(self.value_type, _secret_structure):
|
||||
if self.value_type == other.value_type:
|
||||
t = self.value_type
|
||||
else:
|
||||
t = type(self.value_type(0) * other.value_type(0))
|
||||
@@ -6435,10 +6471,12 @@ class SubMultiArray(_vectorizable):
|
||||
# fallback for binary circuits
|
||||
@library.for_range_opt(other.sizes[1])
|
||||
def _(j):
|
||||
res_matrix[i][j] = 0
|
||||
@library.for_range_opt(self.sizes[1])
|
||||
tmp = self[i][0].mul_no_reduce(other[0][j])
|
||||
@library.for_range_opt(1, self.sizes[1])
|
||||
def _(k):
|
||||
res_matrix[i][j] += self[i][k] * other[k][j]
|
||||
prod = self[i][k].mul_no_reduce(other[k][j])
|
||||
tmp.iadd(prod)
|
||||
res_matrix[i][j] = tmp.reduce_after_mul()
|
||||
return res_matrix
|
||||
elif isinstance(other, self.value_type):
|
||||
return self * Array.create_from(other)
|
||||
@@ -6780,6 +6818,9 @@ class MultiArray(SubMultiArray):
|
||||
a[1].input_from(1)
|
||||
a[2][:] = a[0][:] * a[1][:]
|
||||
|
||||
Arrays aren't initialized on creation, you need to call
|
||||
:py:func:`assign_all` to initialize them to a constant value.
|
||||
|
||||
"""
|
||||
@staticmethod
|
||||
def disable_index_checks():
|
||||
@@ -6817,6 +6858,9 @@ class Matrix(MultiArray):
|
||||
:param columns: compile-time (int)
|
||||
:param value_type: basic type of entries
|
||||
|
||||
Matrices aren't initialized on creation, you need to call
|
||||
:py:func:`assign_all` to initialize them to a constant value.
|
||||
|
||||
"""
|
||||
def __init__(self, rows, columns, value_type, debug=None, address=None):
|
||||
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
|
||||
|
||||
@@ -47,23 +47,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
libboost-dev \
|
||||
libboost-thread-dev \
|
||||
libclang-dev \
|
||||
libgmp-dev \
|
||||
libntl-dev \
|
||||
libsodium-dev \
|
||||
libssl-dev \
|
||||
libtool \
|
||||
m4 \
|
||||
texinfo \
|
||||
yasm \
|
||||
vim \
|
||||
gdb \
|
||||
valgrind \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# mpir
|
||||
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/include/* /usr/local/include/
|
||||
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/lib/* /usr/local/lib/
|
||||
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/share/info/* /usr/local/share/info/
|
||||
|
||||
ENV MP_SPDZ_HOME /usr/src/MP-SPDZ
|
||||
WORKDIR $MP_SPDZ_HOME
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ void Client::send_private_inputs(const vector<T>& values)
|
||||
octetStream os;
|
||||
vector< vector<T> > triples(num_inputs, vector<T>(3));
|
||||
vector<T> triple_shares(3);
|
||||
bool active = true;
|
||||
|
||||
// Receive num_inputs triples from SPDZ
|
||||
for (size_t j = 0; j < sockets.size(); j++)
|
||||
@@ -61,9 +62,21 @@ void Client::send_private_inputs(const vector<T>& values)
|
||||
cerr << "received " << os.get_length() << " from " << j << endl << flush;
|
||||
#endif
|
||||
|
||||
if (j == 0)
|
||||
{
|
||||
if (os.get_length() == 3 * values.size() * T::size())
|
||||
active = true;
|
||||
else
|
||||
active = false;
|
||||
}
|
||||
|
||||
int n_expected = active ? 3 : 1;
|
||||
if (os.get_length() != n_expected * T::size() * values.size())
|
||||
throw runtime_error("unexpected data length in sending");
|
||||
|
||||
for (int j = 0; j < num_inputs; j++)
|
||||
{
|
||||
for (int k = 0; k < 3; k++)
|
||||
for (int k = 0; k < n_expected; k++)
|
||||
{
|
||||
triple_shares[k].unpack(os);
|
||||
triples[j][k] += triple_shares[k];
|
||||
@@ -71,16 +84,18 @@ void Client::send_private_inputs(const vector<T>& values)
|
||||
}
|
||||
}
|
||||
|
||||
// Check triple relations (is a party cheating?)
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
|
||||
if (active)
|
||||
// Check triple relations (is a party cheating?)
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
throw mac_fail();
|
||||
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
|
||||
{
|
||||
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
throw mac_fail();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send inputs + triple[0], so SPDZ can compute shares of each value
|
||||
os.reset_write_head();
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
@@ -100,6 +115,7 @@ vector<U> Client::receive_outputs(int n)
|
||||
{
|
||||
vector<T> triples(3 * n);
|
||||
octetStream os;
|
||||
bool active = true;
|
||||
for (auto& socket : sockets)
|
||||
{
|
||||
os.reset_write_head();
|
||||
@@ -107,7 +123,20 @@ vector<U> Client::receive_outputs(int n)
|
||||
#ifdef VERBOSE_COMM
|
||||
cout << "received " << os.get_length() << endl << flush;
|
||||
#endif
|
||||
for (int j = 0; j < 3 * n; j++)
|
||||
|
||||
if (socket == sockets[0])
|
||||
{
|
||||
if (os.get_length() == (size_t) 3 * n * T::size())
|
||||
active = true;
|
||||
else
|
||||
active = false;
|
||||
}
|
||||
|
||||
int n_expected = n * (active ? 3 : 1);
|
||||
if (os.get_length() != (size_t) n_expected * T::size())
|
||||
throw runtime_error("unexpected data length in receiving");
|
||||
|
||||
for (int j = 0; j < n_expected; j++)
|
||||
{
|
||||
T value;
|
||||
value.unpack(os);
|
||||
@@ -115,16 +144,24 @@ vector<U> Client::receive_outputs(int n)
|
||||
}
|
||||
}
|
||||
|
||||
vector<U> output_values;
|
||||
for (int i = 0; i < 3 * n; i += 3)
|
||||
if (active)
|
||||
{
|
||||
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
|
||||
vector<U> output_values;
|
||||
for (int i = 0; i < 3 * n; i += 3)
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
throw mac_fail();
|
||||
if (T(triples[i] * triples[i + 1]) != triples[i + 2])
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
throw mac_fail();
|
||||
}
|
||||
output_values.push_back(triples[i]);
|
||||
}
|
||||
output_values.push_back(triples[i]);
|
||||
}
|
||||
|
||||
return output_values;
|
||||
return output_values;
|
||||
}
|
||||
else
|
||||
{
|
||||
triples.resize(n);
|
||||
return triples;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,17 +34,25 @@ class Client:
|
||||
os = octetStream()
|
||||
for socket in self.sockets:
|
||||
os.Receive(socket)
|
||||
if socket == self.sockets[0]:
|
||||
active = os.get_length() == 3 * n * T.size()
|
||||
n_expected = 3 if active else 1
|
||||
if os.get_length() != n_expected * T.size() * n:
|
||||
import sys
|
||||
print (os.get_length(), n_expected, T.size(), n, active, file=sys.stderr)
|
||||
raise Exception('unexpected data length')
|
||||
for triple in triples:
|
||||
for i in range(3):
|
||||
for i in range(n_expected):
|
||||
t = T()
|
||||
t.unpack(os)
|
||||
triple[i] += t
|
||||
res = []
|
||||
for triple in triples:
|
||||
prod = triple[0] * triple[1]
|
||||
if prod != triple[2]:
|
||||
raise Exception(
|
||||
'invalid triple, diff %s' % hex(prod.v - triple[2].v))
|
||||
if active:
|
||||
for triple in triples:
|
||||
prod = triple[0] * triple[1]
|
||||
if prod != triple[2]:
|
||||
raise Exception(
|
||||
'invalid triple, diff %s' % hex(prod.v - triple[2].v))
|
||||
return triples
|
||||
|
||||
def send_private_inputs(self, values):
|
||||
@@ -68,6 +76,9 @@ class octetStream:
|
||||
if value is not None:
|
||||
self.buf += value
|
||||
|
||||
def get_length(self):
|
||||
return len(self.buf)
|
||||
|
||||
def reset_write_head(self):
|
||||
self.buf = b''
|
||||
self.ptr = 0
|
||||
|
||||
@@ -27,6 +27,10 @@ class Domain:
|
||||
def __neq__(self, other):
|
||||
return self.v != other.v
|
||||
|
||||
@classmethod
|
||||
def size(cls):
|
||||
return cls.n_bytes
|
||||
|
||||
def unpack(self, os):
|
||||
self.v = 0
|
||||
buf = os.consume(self.n_bytes)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "Ciphertext.h"
|
||||
#include "PPData.h"
|
||||
#include "P2Data.h"
|
||||
#include "Tools/Exceptions.h"
|
||||
|
||||
@@ -143,6 +142,5 @@ void Ciphertext::rerandomize(const FHE_PK& pk)
|
||||
|
||||
|
||||
template void mul(Ciphertext& ans,const Plaintext<gfp,FFT_Data,bigint>& a,const Ciphertext& c);
|
||||
template void mul(Ciphertext& ans,const Plaintext<gfp,PPData,bigint>& a,const Ciphertext& c);
|
||||
template void mul(Ciphertext& ans, const Plaintext<gf2n_short, P2Data, int>& a,
|
||||
const Ciphertext& c);
|
||||
|
||||
@@ -259,6 +259,3 @@ void BFFT(vector<modp>& ans,const vector<modp>& a,const FFT_Data& FFTD,bool forw
|
||||
else
|
||||
{ throw crash_requested(); }
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -83,6 +83,8 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
for (int r=0; r<2; r++)
|
||||
{ FFT_Iter(b[r],twop,two_root[0],PrD); }
|
||||
}
|
||||
else
|
||||
throw bad_value();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
#include "FHE_Keys.h"
|
||||
#include "Ciphertext.h"
|
||||
#include "P2Data.h"
|
||||
#include "PPData.h"
|
||||
#include "FFT_Data.h"
|
||||
|
||||
#include "Math/modp.hpp"
|
||||
@@ -406,29 +405,17 @@ bigint FHE_SK::get_noise(const Ciphertext& c)
|
||||
}
|
||||
|
||||
|
||||
#define X(FD) \
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FD>& mess, \
|
||||
const Random_Coins& rc) const; \
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FD>& mess) const; \
|
||||
template Plaintext_<FD> FHE_SK::decrypt(const Ciphertext& c, \
|
||||
const FD& FieldD); \
|
||||
template void FHE_SK::decrypt(Plaintext_<FD>& res, \
|
||||
const Ciphertext& c) const; \
|
||||
template void FHE_SK::decrypt_any(Plaintext_<FD>& res, \
|
||||
const Ciphertext& c); \
|
||||
template void FHE_SK::check(const FHE_PK& pk, const FD&);
|
||||
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<P2Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess) const;
|
||||
template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
|
||||
|
||||
template void FHE_SK::decrypt(Plaintext_<FFT_Data>&, const Ciphertext& c) const;
|
||||
template void FHE_SK::decrypt(Plaintext_<P2Data>&, const Ciphertext& c) const;
|
||||
|
||||
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
|
||||
const FFT_Data& FieldD);
|
||||
template Plaintext_<P2Data> FHE_SK::decrypt(const Ciphertext& c,
|
||||
const P2Data& FieldD);
|
||||
|
||||
template void FHE_SK::decrypt_any(Plaintext_<FFT_Data>& res,
|
||||
const Ciphertext& c);
|
||||
template void FHE_SK::decrypt_any(Plaintext_<P2Data>& res,
|
||||
const Ciphertext& c);
|
||||
|
||||
template void FHE_SK::check(const FHE_PK& pk, const FFT_Data&);
|
||||
template void FHE_SK::check(const FHE_PK& pk, const P2Data&);
|
||||
X(FFT_Data)
|
||||
X(P2Data)
|
||||
|
||||
@@ -119,12 +119,6 @@ const P2Data& FHE_Params::get_plaintext_field_data() const
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
template<>
|
||||
const PPData& FHE_Params::get_plaintext_field_data() const
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
bigint FHE_Params::get_plaintext_modulus() const
|
||||
{
|
||||
return fd.get_prime();
|
||||
|
||||
@@ -248,49 +248,6 @@ matrix inv(const matrix& A)
|
||||
}
|
||||
|
||||
|
||||
vector<modp> solve(modp_matrix& A,const Zp_Data& PrD)
|
||||
{
|
||||
unsigned int n=A.size();
|
||||
if ((n+1)!=A[0].size()) { throw invalid_params(); }
|
||||
|
||||
modp t,ti;
|
||||
for (unsigned int r=0; r<n; r++)
|
||||
{ // Find pivot
|
||||
unsigned int p=r;
|
||||
while (isZero(A[p][r],PrD)) { p++; }
|
||||
// Do pivoting
|
||||
if (p!=r)
|
||||
{ for (unsigned int c=0; c<n+1; c++)
|
||||
{ t=A[p][c]; A[p][c]=A[r][c]; A[r][c]=t; }
|
||||
}
|
||||
// Make Lcoeff=1
|
||||
Inv(ti,A[r][r],PrD);
|
||||
for (unsigned int c=0; c<n+1; c++)
|
||||
{ Mul(A[r][c],A[r][c],ti,PrD); }
|
||||
// Now kill off other entries in this column
|
||||
for (unsigned int rr=0; rr<n; rr++)
|
||||
{ if (r!=rr)
|
||||
{ for (unsigned int c=0; c<n+1; c++)
|
||||
{ Mul(t,A[rr][c],A[r][r],PrD);
|
||||
Sub(A[rr][c],A[rr][c],t,PrD);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sanity check and extract answer
|
||||
vector<modp> ans;
|
||||
ans.resize(n);
|
||||
for (unsigned int i=0; i<n; i++)
|
||||
{ for (unsigned int j=0; j<n; j++)
|
||||
{ if (i!=j && !isZero(A[i][j],PrD)) { throw bad_value(); }
|
||||
else if (!isOne(A[i][j],PrD)) { throw bad_value(); }
|
||||
}
|
||||
ans[i]=A[i][n];
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Input matrix is assumed to have more rows than columns */
|
||||
void pinv(imatrix& Ai,const imatrix& B)
|
||||
|
||||
@@ -10,7 +10,6 @@ using namespace std;
|
||||
#include "Tools/BitVector.h"
|
||||
|
||||
typedef vector< vector<bigint> > matrix;
|
||||
typedef vector< vector<modp> > modp_matrix;
|
||||
|
||||
class imatrix : public vector< BitVector >
|
||||
{
|
||||
@@ -39,13 +38,6 @@ void print(const imatrix& S);
|
||||
// requires column operations to create the inverse
|
||||
matrix inv(const matrix& A);
|
||||
|
||||
// Another special routine for modp matrices.
|
||||
// Solves
|
||||
// Ax=b
|
||||
// Assumes A is unimodular, square and only requires row operations to
|
||||
// create the inverse. In put is C=(A||b) and the routines alters A
|
||||
vector<modp> solve(modp_matrix& C,const Zp_Data& PrD);
|
||||
|
||||
// Finds a pseudo-inverse of a matrix A modulo 2
|
||||
// - Input matrix is assumed to have more rows than columns
|
||||
void pinv(imatrix& Ai,const imatrix& A);
|
||||
|
||||
132
FHE/NTL-Subs.cpp
132
FHE/NTL-Subs.cpp
@@ -742,135 +742,3 @@ void load_or_generate(P2Data& P2D, const Ring& R)
|
||||
P2D.store(R);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#ifdef USE_NTL
|
||||
/*
|
||||
* Create FHE parameters for a general plaintext modulus p
|
||||
* Basically this is for general large primes only
|
||||
*/
|
||||
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
|
||||
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params)
|
||||
{
|
||||
cout << "Setting up parameters" << endl;
|
||||
|
||||
int lgp=numBits(p);
|
||||
int mm,idx;
|
||||
|
||||
// mm is the minimum value of m we will accept
|
||||
if (lgp<48)
|
||||
{ mm=100; // Test case
|
||||
idx=0;
|
||||
}
|
||||
else if (lgp <96)
|
||||
{ mm=8192;
|
||||
idx=1;
|
||||
}
|
||||
else if (lgp<192)
|
||||
{ mm=16384;
|
||||
idx=2;
|
||||
}
|
||||
else if (lgp<384)
|
||||
{ mm=16384;
|
||||
idx=3;
|
||||
}
|
||||
else if (lgp<768)
|
||||
{ mm=32768;
|
||||
idx=4;
|
||||
}
|
||||
else
|
||||
{ throw invalid_params(); }
|
||||
|
||||
// Now find the small factors of p-1 and their exponents
|
||||
bigint t=p-1;
|
||||
vector<long> primes(100),exp(100);
|
||||
|
||||
PrimeSeq s;
|
||||
long pr;
|
||||
pr=s.next();
|
||||
int len=0;
|
||||
while (pr<2*mm)
|
||||
{ int e=0;
|
||||
while ((t%pr)==0)
|
||||
{ e++;
|
||||
t=t/pr;
|
||||
}
|
||||
if (e!=0)
|
||||
{ primes[len]=pr;
|
||||
exp[len]=e;
|
||||
if (len!=0) { cout << " * "; }
|
||||
cout << pr << "^" << e << flush;
|
||||
len++;
|
||||
}
|
||||
pr=s.next();
|
||||
}
|
||||
cout << endl;
|
||||
|
||||
// We want to find the best m which divides pr-1, such that
|
||||
// - 2*m > phi(m) > mm
|
||||
// - m has the smallest number of factors
|
||||
vector<int> ee;
|
||||
ee.resize(len);
|
||||
for (int i=0; i<len; i++) { ee[i]=0; }
|
||||
int min_hwt=-1,m=-1,bphi_m=-1,bmx=-1;
|
||||
bool flag=true;
|
||||
while (flag)
|
||||
{ int cand_m=1,hwt=0,mx=0;
|
||||
for (int i=0; i<len; i++)
|
||||
{ //cout << ee[i] << " ";
|
||||
if (ee[i]!=0)
|
||||
{ hwt++;
|
||||
for (int j=0; j<ee[i]; j++)
|
||||
{ cand_m*=primes[i]; }
|
||||
if (ee[i]>mx) { mx=ee[i]; }
|
||||
}
|
||||
}
|
||||
// Put "if" here to stop searching for things which will never work
|
||||
if (cand_m>1 && cand_m<4*mm)
|
||||
{ //cout << " : " << cand_m << " : " << hwt << flush;
|
||||
int phim=phi_N(cand_m);
|
||||
//cout << " : " << phim << " : " << mm << endl;
|
||||
if (phim>mm && phim<3*mm)
|
||||
{ if (m==-1 || hwt<min_hwt || (hwt==min_hwt && mx<bmx) || (hwt==min_hwt && mx==bmx && phim<bphi_m))
|
||||
{ m=cand_m;
|
||||
min_hwt=hwt;
|
||||
bphi_m=phim;
|
||||
bmx=mx;
|
||||
//cout << "\t OK" << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{ //cout << endl;
|
||||
}
|
||||
int i=0;
|
||||
ee[i]=ee[i]+1;
|
||||
while (ee[i]>exp[i] && flag)
|
||||
{ ee[i]=0;
|
||||
i++;
|
||||
if (i==len) { flag=false; i=0; }
|
||||
else { ee[i]=ee[i]+1; }
|
||||
}
|
||||
}
|
||||
if (m==-1)
|
||||
{ throw bad_value(); }
|
||||
cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl;
|
||||
|
||||
Parameters parameters(n, lgp, sec);
|
||||
parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p,params);
|
||||
int mx=0;
|
||||
for (int i=0; i<R.phi_m(); i++)
|
||||
{ if (mx<R.Phi()[i]) { mx=R.Phi()[i]; } }
|
||||
cout << "Max Coeff = " << mx << endl;
|
||||
|
||||
init(R, m, true);
|
||||
|
||||
Zp_Data Zp(p);
|
||||
PPD.init(R,Zp);
|
||||
gfp::init_field(p);
|
||||
|
||||
pr0 = parameters.pr0;
|
||||
pr1 = parameters.pr1;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include "FHE/Ring.h"
|
||||
#include "FHE/FFT_Data.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/FHE_Params.h"
|
||||
|
||||
/* Routines to set up key sizes given the number of players n
|
||||
@@ -68,10 +67,6 @@ class GF2X;
|
||||
|
||||
NTL::GF2X get_F(const Ring& Rg);
|
||||
|
||||
// For use when we want p to be a specific value
|
||||
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
|
||||
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params);
|
||||
|
||||
// generate moduli according to lengths and other parameters
|
||||
void generate_moduli(bigint& pr0, bigint& pr1, const int m,
|
||||
const bigint p, const int lg2p0, const int lg2p1);
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
#include "FHE/Subroutines.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/FFT.h"
|
||||
#include "FHE/Matrix.h"
|
||||
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
|
||||
|
||||
void PPData::init(const Ring& Rg,const Zp_Data& PrD)
|
||||
{
|
||||
R=Rg;
|
||||
prData=PrD;
|
||||
|
||||
root=Find_Primitive_Root_m(Rg.m(),Rg.Phi(),PrD);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void PPData::to_eval(vector<modp>& elem) const
|
||||
{
|
||||
if (elem.size()!= (unsigned) R.phi_m())
|
||||
{ throw params_mismatch(); }
|
||||
|
||||
throw not_implemented();
|
||||
|
||||
/*
|
||||
vector<modp> ans;
|
||||
ans.resize(R.phi_m());
|
||||
modp x=root;
|
||||
for (int i=0; i<R.phi_m(); i++)
|
||||
{ ans[i]=elem[R.phi_m()-1];
|
||||
for (int j=1; j<R.phi_m(); j++)
|
||||
{ Mul(ans[i],ans[i],x,prData);
|
||||
Add(ans[i],ans[i],elem[R.phi_m()-j-1],prData);
|
||||
}
|
||||
Mul(x,x,root,prData);
|
||||
}
|
||||
elem=ans;
|
||||
*/
|
||||
}
|
||||
|
||||
void PPData::from_eval(vector<modp>&) const
|
||||
{
|
||||
// avoid warning
|
||||
throw not_implemented();
|
||||
|
||||
/*
|
||||
modp_matrix A;
|
||||
int n=phi_m();
|
||||
A.resize(n, vector<modp>(n+1) );
|
||||
modp x=root;
|
||||
for (int i=0; i<n; i++)
|
||||
{ assignOne(A[0][i],prData);
|
||||
for (int j=1; j<n; j++)
|
||||
{ Mul(A[j][i],A[j-1][i],x,prData); }
|
||||
Mul(x,x,root,prData);
|
||||
A[i][n]=elem[i];
|
||||
}
|
||||
elem=solve(A,prData);
|
||||
*/
|
||||
|
||||
}
|
||||
|
||||
|
||||
void PPData::reset_iteration()
|
||||
{
|
||||
pow = 1;
|
||||
theta = {root, prData};
|
||||
thetaPow = theta;
|
||||
}
|
||||
|
||||
void PPData::next_iteration()
|
||||
{
|
||||
do
|
||||
{ thetaPow *= (theta);
|
||||
pow++;
|
||||
}
|
||||
while (gcd(pow,m())!=1);
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
gfp PPData::get_evaluation(const vector<bigint>& mess) const
|
||||
{
|
||||
// Uses Horner's rule
|
||||
gfp ans;
|
||||
ans = mess[mess.size()-1];
|
||||
gfp coeff;
|
||||
for (int j=mess.size()-2; j>=0; j--)
|
||||
{ ans *= (thetaPow);
|
||||
coeff = mess[j];
|
||||
ans += (coeff);
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
61
FHE/PPData.h
61
FHE/PPData.h
@@ -1,61 +0,0 @@
|
||||
#ifndef _PPData
|
||||
#define _PPData
|
||||
|
||||
#include "Math/modp.h"
|
||||
#include "Math/Zp_Data.h"
|
||||
#include "Math/gfpvar.h"
|
||||
#include "Math/fixint.h"
|
||||
#include "FHE/Ring.h"
|
||||
#include "FHE/FFT_Data.h"
|
||||
|
||||
/* Class for holding modular arithmetic data wrt the ring
|
||||
*
|
||||
* It also holds the ring
|
||||
*/
|
||||
|
||||
class PPData
|
||||
{
|
||||
public:
|
||||
typedef gfp T;
|
||||
typedef bigint S;
|
||||
typedef typename FFT_Data::poly_type poly_type;
|
||||
|
||||
Ring R;
|
||||
Zp_Data prData;
|
||||
|
||||
modp root; // m'th Root of Unity mod pr
|
||||
|
||||
void init(const Ring& Rg,const Zp_Data& PrD);
|
||||
|
||||
PPData() { ; }
|
||||
PPData(const Ring& Rg,const Zp_Data& PrD)
|
||||
{ init(Rg,PrD); }
|
||||
|
||||
const Zp_Data& get_prD() const { return prData; }
|
||||
const bigint& get_prime() const { return prData.pr; }
|
||||
int phi_m() const { return R.phi_m(); }
|
||||
int m() const { return R.m(); }
|
||||
int num_slots() const { return R.phi_m(); }
|
||||
|
||||
|
||||
int p(int i) const { return R.p(i); }
|
||||
int p_inv(int i) const { return R.p_inv(i); }
|
||||
const vector<int>& Phi() const { return R.Phi(); }
|
||||
|
||||
// Convert input vector from poly to evaluation representation
|
||||
// - Uses naive method and not FFT, we only use this rarely in any case
|
||||
void to_eval(vector<modp>& elem) const;
|
||||
void from_eval(vector<modp>& elem) const;
|
||||
|
||||
// Following are used to iteratively get slots, as we use PPData when
|
||||
// we do not have an efficient FFT algorithm
|
||||
gfp thetaPow,theta;
|
||||
int pow;
|
||||
void reset_iteration();
|
||||
void next_iteration();
|
||||
gfp get_evaluation(const vector<bigint>& mess) const;
|
||||
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
|
||||
#include "FHE/Plaintext.h"
|
||||
#include "FHE/Ring_Element.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "FHE/Rq_Element.h"
|
||||
#include "FHE_Keys.h"
|
||||
@@ -85,39 +84,6 @@ void Plaintext<gfp,FFT_Data,bigint>::to_poly() const
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::from_poly() const
|
||||
{
|
||||
if (type!=Polynomial) { return; }
|
||||
vector<modp> aa((*Field_Data).phi_m());
|
||||
for (unsigned int i=0; i<aa.size(); i++)
|
||||
{ to_modp(aa[i], bigint::tmp = b[i], (*Field_Data).prData); }
|
||||
(*Field_Data).to_eval(aa);
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<aa.size(); i++)
|
||||
a[i] = {aa[i], Field_Data->get_prD()};
|
||||
type=Both;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::to_poly() const
|
||||
{
|
||||
if (type!=Evaluation) { return; }
|
||||
cout << "This is VERY inefficient to convert a plaintext to poly representation" << endl;
|
||||
vector<modp> bb((*Field_Data).phi_m());
|
||||
for (unsigned int i=0; i<bb.size(); i++)
|
||||
{ bb[i]=a[i].get(); }
|
||||
(*Field_Data).from_eval(bb);
|
||||
for (unsigned int i=0; i<bb.size(); i++)
|
||||
{
|
||||
to_bigint(bigint::tmp,bb[i],(*Field_Data).prData);
|
||||
b[i] = bigint::tmp;
|
||||
}
|
||||
type=Both;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gf2n_short,P2Data,int>::from_poly() const
|
||||
@@ -385,34 +351,6 @@ void add(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void add(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
|
||||
const Plaintext<gfp,PPData,bigint>& y)
|
||||
{
|
||||
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
|
||||
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
|
||||
|
||||
if (x.type==Both && y.type!=Both) { z.type=y.type; }
|
||||
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
|
||||
else if (x.type!=y.type) { throw rep_mismatch(); }
|
||||
else { z.type=x.type; }
|
||||
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.num_slots());
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i] = (x.a[i] + y.a[i]); }
|
||||
}
|
||||
if (z.type!=Evaluation)
|
||||
{ for (unsigned int i=0; i<z.b.size(); i++)
|
||||
{ z.b[i]=x.b[i]+y.b[i];
|
||||
if (z.b[i]>(*z.Field_Data).get_prime())
|
||||
{ z.b[i]-=(*z.Field_Data).get_prime(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
@@ -475,36 +413,6 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void sub(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
|
||||
const Plaintext<gfp,PPData,bigint>& y)
|
||||
{
|
||||
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
|
||||
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
|
||||
|
||||
if (x.type==Both && y.type!=Both) { z.type=y.type; }
|
||||
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
|
||||
else if (x.type!=y.type) { throw rep_mismatch(); }
|
||||
else { z.type=x.type; }
|
||||
|
||||
z.allocate();
|
||||
if (z.type!=Polynomial)
|
||||
{
|
||||
z.a.resize(z.num_slots());
|
||||
for (unsigned int i=0; i<z.a.size(); i++)
|
||||
{ z.a[i] = (x.a[i] - y.a[i]); }
|
||||
}
|
||||
if (z.type!=Evaluation)
|
||||
{ for (unsigned int i=0; i<z.b.size(); i++)
|
||||
{ z.b[i]=x.b[i]-y.b[i];
|
||||
if (z.b[i]<0)
|
||||
{ z.b[i]+=(*z.Field_Data).get_prime(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
@@ -572,23 +480,6 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::negate()
|
||||
{
|
||||
if (type!=Polynomial)
|
||||
{
|
||||
a.resize(num_slots());
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{ a[i].negate(); }
|
||||
}
|
||||
if (type!=Evaluation)
|
||||
{ for (unsigned int i=0; i<b.size(); i++)
|
||||
{ if (b[i]!=0)
|
||||
{ b[i]=(*Field_Data).get_prime()-b[i]; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
@@ -731,12 +622,6 @@ template void mul(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data
|
||||
|
||||
|
||||
|
||||
template class Plaintext<gfp,PPData,bigint>;
|
||||
|
||||
template void mul(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,const Plaintext<gfp,PPData,bigint>& y);
|
||||
|
||||
|
||||
|
||||
template class Plaintext<gf2n_short,P2Data,int>;
|
||||
|
||||
template void mul(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,int>& x,const Plaintext<gf2n_short,P2Data,int>& y);
|
||||
|
||||
@@ -274,7 +274,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
(*P).Broadcast_Receive(ctx_Delta);
|
||||
|
||||
// Output the ctx_Delta to a file
|
||||
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
ofstream outf(filename);
|
||||
for (int j=0; j<(*P).num_players(); j++)
|
||||
{
|
||||
@@ -308,7 +308,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
octetStream occ,ctx_D;
|
||||
for (int i=0; i<2*TT; i++)
|
||||
{ if (open[i]==1)
|
||||
{ sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
{ snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
|
||||
ifstream inpf(filename);
|
||||
for (int j=0; j<(*P).num_players(); j++)
|
||||
{
|
||||
@@ -386,7 +386,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
Ciphertext enc1(params),enc2(params),eDelta(params);
|
||||
octetStream oe1,oe2;
|
||||
|
||||
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
|
||||
snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
|
||||
ifstream inpf(filename);
|
||||
for (int k=0; k<(*P).num_players(); k++)
|
||||
{
|
||||
|
||||
@@ -26,6 +26,7 @@ public:
|
||||
void set_protocol(DealerSecret::Protocol& protocol)
|
||||
{
|
||||
P = &protocol.P;
|
||||
BufferPrep<DealerSecret>::P = P;
|
||||
}
|
||||
|
||||
void buffer_triples()
|
||||
|
||||
@@ -183,6 +183,8 @@ public:
|
||||
NoShare operator-(const NoShare&) const { fail(); return {}; }
|
||||
NoShare operator*(const NoValue&) const { fail(); return {}; }
|
||||
|
||||
NoShare operator^(const NoShare&) const { fail(); return {}; }
|
||||
|
||||
NoShare operator&(int) const { fail(); return {}; }
|
||||
NoShare operator>>(int) const { fail(); return {}; }
|
||||
|
||||
|
||||
@@ -123,7 +123,9 @@ BreakType Program::execute(Processor<T>& Proc, U& dynamic_memory,
|
||||
}
|
||||
time++;
|
||||
#ifdef DEBUG_COMPLEXITY
|
||||
cout << "complexity at " << time << ": " << Proc.complexity << endl;
|
||||
cout << T::part_type::name() << " complexity at " << time << ": " <<
|
||||
Proc.complexity << " after " << hex <<
|
||||
instruction.get_opcode() << dec << endl;
|
||||
#endif
|
||||
}
|
||||
while (Proc.complexity < (size_t) OnlineOptions::singleton.batch_size);
|
||||
|
||||
@@ -39,6 +39,7 @@ void RepPrep<T>::set_protocol(typename T::Protocol& protocol)
|
||||
return;
|
||||
|
||||
this->protocol = new ReplicatedBase(protocol.P);
|
||||
this->P = &protocol.P;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -89,7 +89,7 @@ void Secret<T>::random(int n_bits, int128 share)
|
||||
{
|
||||
(void)share;
|
||||
if (n_bits > 128)
|
||||
throw not_implemented();
|
||||
throw runtime_error("too many bits");
|
||||
resize_regs(n_bits);
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
get_reg(i).random();
|
||||
|
||||
@@ -37,6 +37,7 @@ void SemiPrep::set_protocol(SemiSecret::Protocol& protocol)
|
||||
protocol.P.N, -1, OnlineOptions::singleton.batch_size,
|
||||
1, params, {}, &protocol.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
this->P = &protocol.P;
|
||||
}
|
||||
|
||||
void SemiPrep::buffer_triples()
|
||||
|
||||
@@ -103,9 +103,7 @@ void ThreadMaster<T>::run()
|
||||
|
||||
machine.print_timers();
|
||||
|
||||
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl;
|
||||
|
||||
machine.print_global_comm(*P, stats);
|
||||
machine.print_comm(*P, stats);
|
||||
|
||||
delete P;
|
||||
}
|
||||
|
||||
@@ -105,6 +105,11 @@ public:
|
||||
*this = a + b;
|
||||
}
|
||||
|
||||
This operator^(const This& other) const
|
||||
{
|
||||
return *this + other;
|
||||
}
|
||||
|
||||
This& operator^=(const This& other)
|
||||
{
|
||||
*this += other;
|
||||
|
||||
@@ -146,6 +146,7 @@
|
||||
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
|
||||
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
|
||||
X(CRASH, if (I0.get()) throw crash_requested()) \
|
||||
X(ACTIVE, ) \
|
||||
|
||||
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS
|
||||
|
||||
|
||||
@@ -365,11 +365,11 @@ void OTMachine::run()
|
||||
{
|
||||
BitVector receiver_output, sender_output;
|
||||
char filename[1024];
|
||||
sprintf(filename, RECEIVER_INPUT, my_num);
|
||||
snprintf(filename, 1024, RECEIVER_INPUT, my_num);
|
||||
ofstream outf(filename);
|
||||
receiverInput.output(outf, false);
|
||||
outf.close();
|
||||
sprintf(filename, RECEIVER_OUTPUT, my_num);
|
||||
snprintf(filename, 1024, RECEIVER_OUTPUT, my_num);
|
||||
outf.open(filename);
|
||||
for (unsigned int i = 0; i < nOTs; i++)
|
||||
{
|
||||
@@ -380,7 +380,7 @@ void OTMachine::run()
|
||||
|
||||
for (int i = 0; i < 2; i++)
|
||||
{
|
||||
sprintf(filename, SENDER_OUTPUT, my_num, i);
|
||||
snprintf(filename,1024, SENDER_OUTPUT, my_num, i);
|
||||
outf.open(filename);
|
||||
for (int j = 0; j < nOTs; j++)
|
||||
{
|
||||
|
||||
37
Makefile
37
Makefile
@@ -116,7 +116,7 @@ mascot: mascot-party.x spdz2k mama-party.x
|
||||
ifeq ($(OS), Darwin)
|
||||
setup: mac-setup
|
||||
else
|
||||
setup: boost mpir linux-machine-setup
|
||||
setup: boost linux-machine-setup
|
||||
endif
|
||||
|
||||
tldr: setup
|
||||
@@ -297,27 +297,6 @@ deps/SimplestOT_C/ref10/Makefile:
|
||||
Programs/Circuits:
|
||||
git submodule update --init Programs/Circuits
|
||||
|
||||
.PHONY: mpir-setup mpir-global
|
||||
mpir-setup: deps/mpir/Makefile
|
||||
deps/mpir/Makefile:
|
||||
git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir
|
||||
cd deps/mpir; \
|
||||
autoreconf -i; \
|
||||
autoreconf -i
|
||||
- $(MAKE) -C deps/mpir clean
|
||||
|
||||
mpir-global: mpir-setup
|
||||
cd deps/mpir; \
|
||||
./configure --enable-cxx;
|
||||
$(MAKE) -C deps/mpir
|
||||
sudo $(MAKE) -C deps/mpir install
|
||||
|
||||
mpir: local/lib/libmpirxx.so
|
||||
local/lib/libmpirxx.so: deps/mpir/Makefile
|
||||
cd deps/mpir; \
|
||||
./configure --enable-cxx --prefix=$(CURDIR)/local
|
||||
$(MAKE) -C deps/mpir install
|
||||
|
||||
deps/libOTe/libOTe:
|
||||
git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe
|
||||
boost: deps/libOTe/libOTe
|
||||
@@ -369,26 +348,16 @@ cmake:
|
||||
./bootstrap --parallel=8 --prefix=../local && make && make install
|
||||
|
||||
mac-setup: mac-machine-setup
|
||||
brew install openssl boost libsodium mpir yasm ntl cmake
|
||||
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include >> CONFIG.mine
|
||||
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib >> CONFIG.mine
|
||||
# -echo USE_NTL = 1 >> CONFIG.mine
|
||||
brew install openssl boost libsodium gmp yasm ntl cmake
|
||||
|
||||
ifeq ($(ARM), 1)
|
||||
mac-machine-setup:
|
||||
-echo ARCH = >> CONFIG.mine
|
||||
linux-machine-setup:
|
||||
-echo ARCH = -march=armv8.2-a+crypto >> CONFIG.mine
|
||||
else
|
||||
mac-machine-setup:
|
||||
linux-machine-setup:
|
||||
endif
|
||||
|
||||
deps/simde/simde:
|
||||
git submodule update --init deps/simde || git clone https://github.com/simd-everywhere/simde deps/simde
|
||||
|
||||
clean-deps:
|
||||
-rm -rf local deps/libOTe/out
|
||||
-rm -rf local/lib/liblibOTe.* deps/libOTe/out
|
||||
|
||||
clean: clean-deps
|
||||
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so
|
||||
|
||||
@@ -17,6 +17,10 @@ using namespace std;
|
||||
#include "ValueInterface.h"
|
||||
#include "gf2nlong.h"
|
||||
|
||||
// Fix false warning
|
||||
#if __GNUC__ == 10
|
||||
#pragma GCC diagnostic ignored "-Wstringop-overflow"
|
||||
#endif
|
||||
|
||||
// Functionality shared between integers and bit vectors
|
||||
template<class T>
|
||||
@@ -39,6 +43,8 @@ public:
|
||||
|
||||
static bool allows(Dtype type) { return type <= DATA_BIT; }
|
||||
|
||||
static void check_setup(const string&) {}
|
||||
|
||||
IntBase() { a = 0; }
|
||||
IntBase(T a) : a(a) {}
|
||||
|
||||
|
||||
@@ -160,13 +160,13 @@ void check_setup(string dir, bigint pr)
|
||||
}
|
||||
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
|
||||
const string& type_short)
|
||||
const string& type_short, bool create)
|
||||
{
|
||||
string res = prep_dir + "/" + to_string(nparties) + "-" + type_short;
|
||||
if (log2mod > 1)
|
||||
res += "-" + to_string(log2mod);
|
||||
res += "/";
|
||||
if (mkdir_p(res.c_str()) < 0)
|
||||
if (create and mkdir_p(res.c_str()) < 0)
|
||||
throw file_error("cannot create " + res);
|
||||
return res;
|
||||
}
|
||||
|
||||
16
Math/Setup.h
16
Math/Setup.h
@@ -38,26 +38,28 @@ bigint generate_prime(int lgp, int m);
|
||||
int default_m(int& lgp, int& idx);
|
||||
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
|
||||
const string& type_short);
|
||||
const string& type_short, bool create = false);
|
||||
|
||||
template<class T>
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod)
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
|
||||
bool create = false)
|
||||
{
|
||||
if (T::clear::length() > 1)
|
||||
log2mod = T::clear::length();
|
||||
return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short());
|
||||
return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short(), create);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties)
|
||||
string get_prep_sub_dir(const string& prep_dir, int nparties, bool create =
|
||||
false)
|
||||
{
|
||||
return get_prep_sub_dir<T>(prep_dir, nparties, T::clear::length());
|
||||
return get_prep_sub_dir<T>(prep_dir, nparties, T::clear::length(), create);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
string get_prep_sub_dir(int nparties)
|
||||
string get_prep_sub_dir(int nparties, bool create = false)
|
||||
{
|
||||
return get_prep_sub_dir<T>(PREP_DIR, nparties);
|
||||
return get_prep_sub_dir<T>(PREP_DIR, nparties, create);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
18
Math/ValueInterface.cpp
Normal file
18
Math/ValueInterface.cpp
Normal file
@@ -0,0 +1,18 @@
|
||||
/*
|
||||
* ValueInterface.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "bigint.h"
|
||||
#include "ValueInterface.h"
|
||||
|
||||
#include <sys/stat.h>
|
||||
|
||||
void ValueInterface::check_setup(const string& directory)
|
||||
{
|
||||
struct stat sb;
|
||||
if (stat(directory.c_str(), &sb) != 0)
|
||||
throw runtime_error(directory + " does not exist");
|
||||
if (not (sb.st_mode & S_IFDIR))
|
||||
throw runtime_error(directory + " is not a directory");
|
||||
}
|
||||
@@ -7,6 +7,7 @@
|
||||
#define MATH_VALUEINTERFACE_H_
|
||||
|
||||
#include "Tools/Exceptions.h"
|
||||
#include "Math/Setup.h"
|
||||
|
||||
class OnlineOptions;
|
||||
class bigint;
|
||||
@@ -31,9 +32,10 @@ public:
|
||||
template<class T>
|
||||
static void generate_setup(string, int, int) {}
|
||||
template<class T>
|
||||
static void write_setup(int) {}
|
||||
static void write_setup(int nplayers) { get_prep_sub_dir<T>(nplayers, true); }
|
||||
static void write_setup(string) {}
|
||||
static void check_setup(string) {}
|
||||
static void check_setup(const string& directory);
|
||||
static const char* fake_opts() { return ""; }
|
||||
|
||||
static bigint pr() { throw runtime_error("no prime modulus"); }
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#ifndef MATH_Z2K_H_
|
||||
#define MATH_Z2K_H_
|
||||
|
||||
#include <mpirxx.h>
|
||||
#include <gmpxx.h>
|
||||
#include <string>
|
||||
using namespace std;
|
||||
|
||||
@@ -74,6 +74,8 @@ public:
|
||||
|
||||
static Z2 power_of_two(bool bit, int exp) { return Z2(bit) << exp; }
|
||||
|
||||
static string fake_opts() { return " -lgp " + to_string(K); }
|
||||
|
||||
typedef Z2 next;
|
||||
typedef Z2 Scalar;
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ void Zp_Data::init(const bigint& p,bool mont)
|
||||
mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t()));
|
||||
|
||||
if (sizeof(unsigned long)!=sizeof(mp_limb_t))
|
||||
{ cout << "The underlying types of MPIR mean we cannot use our Montgomery code" << endl;
|
||||
{ cout << "The underlying types of GMP mean we cannot use our Montgomery code" << endl;
|
||||
throw not_implemented();
|
||||
}
|
||||
}
|
||||
@@ -194,3 +194,37 @@ bool Zp_Data::operator==(const Zp_Data& other) const
|
||||
{
|
||||
return not (*this != other);
|
||||
}
|
||||
|
||||
void Zp_Data::get_shanks_parameters(bigint& y, bigint& q_half, int& r) const
|
||||
{
|
||||
if (shanks_y == 0)
|
||||
{
|
||||
auto& p = pr;
|
||||
bigint n, q, yy, xx, temp;
|
||||
// Find n such that (n/p)=-1
|
||||
int leg = 1;
|
||||
gmp_randclass Gen(gmp_randinit_default);
|
||||
Gen.seed(0);
|
||||
while (leg != -1)
|
||||
{
|
||||
n = Gen.get_z_range(p);
|
||||
leg = mpz_legendre(n.get_mpz_t(), p.get_mpz_t());
|
||||
}
|
||||
// Split p-1 = 2^e q
|
||||
q = p - 1;
|
||||
int e = 0;
|
||||
while (mpz_even_p(q.get_mpz_t()))
|
||||
{
|
||||
e++;
|
||||
q = q / 2;
|
||||
}
|
||||
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
|
||||
shanks_r = e;
|
||||
mpz_powm(shanks_y.get_mpz_t(), n.get_mpz_t(), q.get_mpz_t(), p.get_mpz_t());
|
||||
shanks_q_half = (q - 1) / 2;
|
||||
}
|
||||
|
||||
y = shanks_y;
|
||||
q_half = shanks_q_half;
|
||||
r = shanks_r;
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@ class Zp_Data
|
||||
int t; // More Montgomery data
|
||||
mp_limb_t overhang;
|
||||
Lock lock;
|
||||
mutable bigint shanks_y, shanks_q_half;
|
||||
mutable int shanks_r;
|
||||
|
||||
template <int T>
|
||||
void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
|
||||
@@ -89,6 +91,8 @@ class Zp_Data
|
||||
bool operator!=(const Zp_Data& other) const;
|
||||
bool operator==(const Zp_Data& other) const;
|
||||
|
||||
void get_shanks_parameters(bigint& y, bigint& q_half, int& r) const;
|
||||
|
||||
template<int L> friend void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD);
|
||||
template<int L> friend void to_modp(modp_<L>& ans,const mpz_class& x,const Zp_Data& ZpD);
|
||||
|
||||
|
||||
@@ -10,76 +10,10 @@
|
||||
|
||||
#include "bigint.hpp"
|
||||
|
||||
class gmp_random
|
||||
{
|
||||
public:
|
||||
gmp_randclass Gen;
|
||||
gmp_random() : Gen(gmp_randinit_default)
|
||||
{
|
||||
Gen.seed(0);
|
||||
}
|
||||
};
|
||||
|
||||
thread_local bigint bigint::tmp = 0;
|
||||
thread_local bigint bigint::tmp2 = 0;
|
||||
thread_local gmp_random bigint::random;
|
||||
|
||||
bigint sqrRootMod(const bigint& a,const bigint& p)
|
||||
{
|
||||
bigint ans;
|
||||
if (a==0) { ans=0; return ans; }
|
||||
if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1)
|
||||
throw runtime_error("cannot compute square root of non-square");
|
||||
if (mpz_tstbit(p.get_mpz_t(),1)==1)
|
||||
{ // First do case with p=3 mod 4
|
||||
bigint exp=(p+1)/4;
|
||||
mpz_powm(ans.get_mpz_t(),a.get_mpz_t(),exp.get_mpz_t(),p.get_mpz_t());
|
||||
}
|
||||
else
|
||||
{ // Shanks algorithm
|
||||
bigint x,y,n,q,t,b,temp;
|
||||
// Find n such that (n/p)=-1
|
||||
int leg=1;
|
||||
while (leg!=-1)
|
||||
{ n=bigint::random.Gen.get_z_range(p);
|
||||
leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t());
|
||||
}
|
||||
// Split p-1 = 2^e q
|
||||
q=p-1;
|
||||
int e=0;
|
||||
while (mpz_even_p(q.get_mpz_t()))
|
||||
{ e++; q=q/2; }
|
||||
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
|
||||
int r=e;
|
||||
mpz_powm(y.get_mpz_t(),n.get_mpz_t(),q.get_mpz_t(),p.get_mpz_t());
|
||||
temp=(q-1)/2;
|
||||
mpz_powm(x.get_mpz_t(),a.get_mpz_t(),temp.get_mpz_t(),p.get_mpz_t());
|
||||
// b=a*x^2 mod p, x=a*x mod p
|
||||
b=(a*x*x)%p;
|
||||
x=(a*x)%p;
|
||||
// While b!=1 do
|
||||
while (b!=1)
|
||||
{ // Find smallest m such that b^(2^m)=1 mod p
|
||||
int m=1;
|
||||
temp=(b*b)%p;
|
||||
while (temp!=1)
|
||||
{ temp=(temp*temp)%p; m++; }
|
||||
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
|
||||
t=y;
|
||||
for (int i=0; i<r-m-1; i++)
|
||||
{ t=(t*t)%p; }
|
||||
y=(t*t)%p;
|
||||
r=m;
|
||||
// x=x*t mod p, b=b*y mod p
|
||||
x=(x*t)%p;
|
||||
b=(b*y)%p;
|
||||
}
|
||||
ans=x;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bigint powerMod(const bigint& x,const bigint& e,const bigint& p)
|
||||
{
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
using namespace std;
|
||||
|
||||
#include <stddef.h>
|
||||
#include <mpirxx.h>
|
||||
#include <gmpxx.h>
|
||||
|
||||
#include "Tools/Exceptions.h"
|
||||
#include "Tools/int.h"
|
||||
@@ -39,7 +39,7 @@ namespace GC
|
||||
|
||||
/**
|
||||
* Type for arbitrarily large integers.
|
||||
* This is a sub-class of ``mpz_class`` from MPIR. As such, it implements
|
||||
* This is a sub-class of ``mpz_class`` from GMP. As such, it implements
|
||||
* all integers operations and input/output via C++ streams. In addition,
|
||||
* the ``get_ui()`` member function allows retrieving the least significant
|
||||
* 64 bits.
|
||||
@@ -139,8 +139,6 @@ public:
|
||||
void inline_mpn_zero(mp_limb_t* x, mp_size_t size);
|
||||
void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size);
|
||||
|
||||
#include "Z2k.h"
|
||||
|
||||
|
||||
inline bigint& bigint::operator=(int n)
|
||||
{
|
||||
@@ -281,11 +279,7 @@ inline int numBytes(const bigint& m)
|
||||
|
||||
inline int probPrime(const bigint& x)
|
||||
{
|
||||
gmp_randstate_t rand_state;
|
||||
gmp_randinit_default(rand_state);
|
||||
int ans = mpz_probable_prime_p(x.get_mpz_t(), rand_state,
|
||||
max(40, DEFAULT_SECURITY), 0);
|
||||
gmp_randclear(rand_state);
|
||||
int ans = mpz_probab_prime_p(x.get_mpz_t(), max(40, DEFAULT_SECURITY) / 2);
|
||||
return ans;
|
||||
}
|
||||
|
||||
@@ -318,7 +312,8 @@ inline int isOdd(const bigint& x)
|
||||
}
|
||||
|
||||
|
||||
bigint sqrRootMod(const bigint& x,const bigint& p);
|
||||
template<class T>
|
||||
bigint sqrRootMod(const T& x);
|
||||
|
||||
bigint powerMod(const bigint& x,const bigint& e,const bigint& p);
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ bigint& bigint::from_signed(const T& other)
|
||||
template<class T>
|
||||
mpf_class bigint::get_float(T v, T p, T z, T s)
|
||||
{
|
||||
// MPIR can't handle more precision in exponent
|
||||
// GMP can't handle more precision in exponent
|
||||
Integer exp = Integer(p, 31).get();
|
||||
bigint tmp;
|
||||
tmp.from_signed(v);
|
||||
@@ -59,4 +59,76 @@ void bigint::output_float(U& o, const mpf_class& x, T nan)
|
||||
o << "NaN";
|
||||
}
|
||||
|
||||
|
||||
class gmp_random
|
||||
{
|
||||
public:
|
||||
gmp_randclass Gen;
|
||||
gmp_random() : Gen(gmp_randinit_default)
|
||||
{
|
||||
Gen.seed(0);
|
||||
}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
bigint sqrRootMod(const T& aa)
|
||||
{
|
||||
bigint a = aa;
|
||||
bigint p = T::pr();
|
||||
|
||||
bigint ans;
|
||||
if (a == 0)
|
||||
{
|
||||
ans = 0;
|
||||
return ans;
|
||||
}
|
||||
if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1)
|
||||
throw runtime_error("cannot compute square root of non-square");
|
||||
if (mpz_tstbit(p.get_mpz_t(), 1) == 1)
|
||||
{
|
||||
// First do case with p=3 mod 4
|
||||
bigint exp = (p + 1) / 4;
|
||||
mpz_powm(ans.get_mpz_t(), a.get_mpz_t(), exp.get_mpz_t(),
|
||||
p.get_mpz_t());
|
||||
}
|
||||
else
|
||||
{
|
||||
// Shanks algorithm
|
||||
bigint n, q, yy, xx, temp;
|
||||
int r;
|
||||
T::get_ZpD().get_shanks_parameters(yy, temp, r);
|
||||
mpz_powm(xx.get_mpz_t(), a.get_mpz_t(), temp.get_mpz_t(), p.get_mpz_t());
|
||||
// b=a*x^2 mod p, x=a*x mod p
|
||||
T x = xx;
|
||||
T b = (aa * x * x);
|
||||
x = (aa * x);
|
||||
T y = yy;
|
||||
// While b!=1 do
|
||||
while (b != 1)
|
||||
{
|
||||
// Find smallest m such that b^(2^m)=1 mod p
|
||||
int m = 1;
|
||||
T temp = (b * b);
|
||||
while (temp != 1)
|
||||
{
|
||||
temp = (temp * temp);
|
||||
m++;
|
||||
}
|
||||
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
|
||||
T t = y;
|
||||
for (int i = 0; i < r - m - 1; i++)
|
||||
{
|
||||
t = (t * t);
|
||||
}
|
||||
y = (t * t);
|
||||
r = m;
|
||||
// x=x*t mod p, b=b*y mod p
|
||||
x = (x * t);
|
||||
b = (b * y);
|
||||
}
|
||||
ans = x;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
#endif /* MATH_BIGINT_HPP_ */
|
||||
|
||||
@@ -17,6 +17,7 @@ class gf2n_short;
|
||||
class P2Data;
|
||||
class Bit;
|
||||
class int128;
|
||||
template<class T> class IntBase;
|
||||
template<class T> class Square;
|
||||
typedef Square<gf2n_short> gf2n_short_square;
|
||||
|
||||
@@ -88,6 +89,8 @@ protected:
|
||||
|
||||
static string options();
|
||||
|
||||
static string fake_opts() { return " -lg2 " + to_string(length()); }
|
||||
|
||||
static const true_type invertible;
|
||||
static const true_type characteristic_two;
|
||||
|
||||
|
||||
@@ -154,6 +154,8 @@ class gf2n_long : public gf2n_<int128>
|
||||
gf2n_long(int g) : gf2n_long(int128(unsigned(g))) {}
|
||||
template<class T>
|
||||
gf2n_long(IntBase<T> g) : super(g.get()) {}
|
||||
template<class T>
|
||||
gf2n_long(const gf2n_<T>& a) : super(int128(a.get())) {}
|
||||
};
|
||||
|
||||
#if defined(__aarch64__) && defined(__clang__)
|
||||
|
||||
@@ -105,6 +105,7 @@ class gfp_ : public ValueInterface
|
||||
static void write_setup(string dir)
|
||||
{ write_online_setup(dir, pr()); }
|
||||
static void check_setup(string dir);
|
||||
static string fake_opts() { return " -lgp " + to_string(length()); }
|
||||
|
||||
/**
|
||||
* Get the prime modulus
|
||||
@@ -314,6 +315,8 @@ gfp_<X, L>::gfp_(long x)
|
||||
{
|
||||
if (x == 0)
|
||||
assign_zero();
|
||||
else if (x == 1)
|
||||
assign_one();
|
||||
else
|
||||
*this = bigint::tmp = x;
|
||||
}
|
||||
|
||||
@@ -146,8 +146,7 @@ gfp_<X, L> gfp_<X, L>::sqrRoot()
|
||||
{
|
||||
// Temp move to bigint so as to call sqrRootMod
|
||||
bigint ti;
|
||||
to_bigint(ti, *this);
|
||||
ti = sqrRootMod(ti, ZpD.pr);
|
||||
ti = sqrRootMod(*this);
|
||||
if (!isOdd(ti))
|
||||
ti = ZpD.pr - ti;
|
||||
gfp_<X, L> temp;
|
||||
|
||||
@@ -312,8 +312,8 @@ gfpvar_<X, L> gfpvar_<X, L>::invert() const
|
||||
template<int X, int L>
|
||||
gfpvar_<X, L> gfpvar_<X, L>::sqrRoot() const
|
||||
{
|
||||
bigint ti = *this;
|
||||
ti = sqrRootMod(ti, ZpD.pr);
|
||||
bigint ti;
|
||||
ti = sqrRootMod(*this);
|
||||
if (!isOdd(ti))
|
||||
ti = ZpD.pr - ti;
|
||||
return ti;
|
||||
|
||||
@@ -81,6 +81,7 @@ public:
|
||||
{
|
||||
write_setup(get_prep_sub_dir<T>(nplayers));
|
||||
}
|
||||
static string fake_opts() { return " -lgp " + to_string(length()); }
|
||||
|
||||
gfpvar_();
|
||||
gfpvar_(int other);
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#define _Modp
|
||||
|
||||
/*
|
||||
* Currently we only support an MPIR based implementation.
|
||||
* Currently we only support an GMP based implementation.
|
||||
*
|
||||
* What ever is type-def'd to bigint is assumed to have
|
||||
* operator overloading for all standard operators, has
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#ifndef MATH_MPN_FIXED_H_
|
||||
#define MATH_MPN_FIXED_H_
|
||||
|
||||
#include <mpir.h>
|
||||
#include <gmp.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include <mpirxx.h>
|
||||
#include <gmpxx.h>
|
||||
|
||||
#include "OT/BitMatrix.h"
|
||||
#include "Tools/random.h"
|
||||
|
||||
@@ -78,7 +78,7 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
|
||||
}
|
||||
}
|
||||
if (nplayers_wanted > 0 and nplayers_wanted != nplayers)
|
||||
throw runtime_error("not enought hosts in HOSTS");
|
||||
throw runtime_error("not enough hosts in " + filename);
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "Got list of " << nplayers << " players from file: " << endl;
|
||||
for (unsigned int i = 0; i < names.size(); i++)
|
||||
@@ -324,7 +324,9 @@ void PlainPlayer::setup_sockets(const vector<string>& names,
|
||||
template<class T>
|
||||
void MultiPlayer<T>::send_long(int i, long a) const
|
||||
{
|
||||
TimeScope ts(comm_stats["Sending by number"].add(sizeof(long)));
|
||||
send(sockets[i], (octet*)&a, sizeof(long));
|
||||
sent += sizeof(long);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -716,7 +718,7 @@ size_t VirtualTwoPartyPlayer::send(const PlayerBuffer& buffer, bool block) const
|
||||
{
|
||||
auto sent = P.send_no_stats(other_player, buffer, block);
|
||||
lock.lock();
|
||||
comm_stats["Sending one-to-one"].add(sent);
|
||||
comm_stats.add_to_last_round("Sending one-to-one", sent);
|
||||
comm_stats.sent += sent;
|
||||
lock.unlock();
|
||||
return sent;
|
||||
@@ -726,7 +728,7 @@ size_t VirtualTwoPartyPlayer::recv(const PlayerBuffer& buffer, bool block) const
|
||||
{
|
||||
auto received = P.recv_no_stats(other_player, buffer, block);
|
||||
lock.lock();
|
||||
comm_stats["Receiving one-to-one"].add(received);
|
||||
comm_stats.add_to_last_round("Receiving one-to-one", received);
|
||||
lock.unlock();
|
||||
return received;
|
||||
}
|
||||
@@ -805,6 +807,17 @@ void NamedCommStats::reset()
|
||||
sent = 0;
|
||||
}
|
||||
|
||||
Timer& NamedCommStats::add_to_last_round(const string& name, size_t length)
|
||||
{
|
||||
if (name == last)
|
||||
return (*this)[name].add_length_only(length);
|
||||
else
|
||||
{
|
||||
last = name;
|
||||
return (*this)[name].add(length);
|
||||
}
|
||||
}
|
||||
|
||||
void PlayerBase::reset_stats()
|
||||
{
|
||||
comm_stats.reset();
|
||||
|
||||
@@ -136,11 +136,15 @@ struct CommStats
|
||||
CommStats() : data(0), rounds(0) {}
|
||||
Timer& add(size_t length)
|
||||
{
|
||||
rounds++;
|
||||
return add_length_only(length);
|
||||
}
|
||||
Timer& add_length_only(size_t length)
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
cout << "add " << length << endl;
|
||||
#endif
|
||||
data += length;
|
||||
rounds++;
|
||||
return timer;
|
||||
}
|
||||
Timer& add(const octetStream& os) { return add(os.get_length()); }
|
||||
@@ -153,6 +157,7 @@ class NamedCommStats : public map<string, CommStats>
|
||||
{
|
||||
public:
|
||||
size_t sent;
|
||||
string last;
|
||||
|
||||
NamedCommStats();
|
||||
|
||||
@@ -161,6 +166,7 @@ public:
|
||||
NamedCommStats operator-(const NamedCommStats& other) const;
|
||||
void print(bool newline = false);
|
||||
void reset();
|
||||
Timer& add_to_last_round(const string& name, size_t length);
|
||||
#ifdef VERBOSE_COMM
|
||||
CommStats& operator[](const string& name)
|
||||
{
|
||||
|
||||
@@ -134,7 +134,7 @@ void close_client_socket(int socket)
|
||||
if (close(socket))
|
||||
{
|
||||
char tmp[1000];
|
||||
sprintf(tmp, "close(%d)", socket);
|
||||
snprintf(tmp, 1000, "close(%d)", socket);
|
||||
error(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ void BaseMachine::time()
|
||||
void BaseMachine::start(int n)
|
||||
{
|
||||
cout << "Starting timer " << n << " at " << timer[n].elapsed()
|
||||
<< " (" << timer[n].mb_sent() << " MB)"
|
||||
<< " (" << timer[n] << ")"
|
||||
<< " after " << timer[n].idle() << endl;
|
||||
timer[n].start(total_comm());
|
||||
}
|
||||
@@ -135,7 +135,7 @@ void BaseMachine::stop(int n)
|
||||
{
|
||||
timer[n].stop(total_comm());
|
||||
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " ("
|
||||
<< timer[n].mb_sent() << " MB)" << endl;
|
||||
<< timer[n] << ")" << endl;
|
||||
}
|
||||
|
||||
void BaseMachine::print_timers()
|
||||
@@ -150,7 +150,7 @@ void BaseMachine::print_timers()
|
||||
timer.erase(0);
|
||||
for (auto it = timer.begin(); it != timer.end(); it++)
|
||||
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds ("
|
||||
<< it->second.mb_sent() << " MB)" << endl;
|
||||
<< it->second << ")" << endl;
|
||||
}
|
||||
|
||||
string BaseMachine::memory_filename(const string& type_short, int my_number)
|
||||
@@ -227,3 +227,19 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
|
||||
global += os.get_int(8);
|
||||
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
|
||||
}
|
||||
|
||||
void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
|
||||
{
|
||||
size_t rounds = 0;
|
||||
for (auto& x : comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << P.my_num() << " only";
|
||||
if (nthreads > 1)
|
||||
cerr << "; rounds counted double due to multi-threading";
|
||||
if (not OnlineOptions::singleton.verbose)
|
||||
cerr << "; use '-v' for more details";
|
||||
cerr << ")" << endl;
|
||||
|
||||
print_global_comm(P, comm_stats);
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ public:
|
||||
void print_timers();
|
||||
|
||||
virtual void reqbl(int) {}
|
||||
virtual void active(int) {}
|
||||
|
||||
static OTTripleSetup fresh_ot_setup(Player& P);
|
||||
|
||||
@@ -74,6 +75,7 @@ public:
|
||||
void set_thread_comm(const NamedCommStats& stats);
|
||||
|
||||
void print_global_comm(Player& P, const NamedCommStats& stats);
|
||||
void print_comm(Player& P, const NamedCommStats& stats);
|
||||
};
|
||||
|
||||
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
|
||||
|
||||
41
Processor/Conv2dTuple.h
Normal file
41
Processor/Conv2dTuple.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Conv2dTuple.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_CONV2DTUPLE_H_
|
||||
#define PROCESSOR_CONV2DTUPLE_H_
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
class Conv2dTuple
|
||||
{
|
||||
public:
|
||||
int output_h, output_w;
|
||||
int inputs_h, inputs_w;
|
||||
int weights_h, weights_w;
|
||||
int stride_h, stride_w;
|
||||
int n_channels_in;
|
||||
int padding_h;
|
||||
int padding_w;
|
||||
int batch_size;
|
||||
size_t r0;
|
||||
size_t r1;
|
||||
int r2;
|
||||
vector<vector<vector<int>>> lengths;
|
||||
int filter_stride_h = 1;
|
||||
int filter_stride_w = 1;
|
||||
|
||||
Conv2dTuple(const vector<int>& args, int start);
|
||||
|
||||
template<class T>
|
||||
void pre(vector<T>& S, typename T::Protocol& protocol);
|
||||
template<class T>
|
||||
void post(vector<T>& S, typename T::Protocol& protocol);
|
||||
|
||||
template<class T>
|
||||
void run_matrix(SubProcessor<T>& processor);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_CONV2DTUPLE_H_ */
|
||||
@@ -222,7 +222,8 @@ bool DataPositions::any_more(const DataPositions& other) const
|
||||
for (auto it = edabits.begin(); it != edabits.end(); it++)
|
||||
{
|
||||
auto x = other.edabits.find(it->first);
|
||||
if (x == other.edabits.end() or it->second > x->second)
|
||||
if ((x == other.edabits.end() or it->second > x->second)
|
||||
and it->second > 0)
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
#include "Networking/Player.h"
|
||||
#include "Protocols/edabit.h"
|
||||
#include "PrepBase.h"
|
||||
#include "EdabitBuffer.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
@@ -102,9 +104,6 @@ protected:
|
||||
|
||||
DataPositions& usage;
|
||||
|
||||
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
|
||||
map<pair<bool, int>, edabitvec<T>> my_edabits;
|
||||
|
||||
bool do_count;
|
||||
|
||||
void count(Dtype dtype, int n = 1)
|
||||
@@ -120,6 +119,8 @@ protected:
|
||||
const vector<int>&, true_type)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
void fill(edabitvec<T>& res, bool strict, int n_bits);
|
||||
|
||||
T get_random_from_inputs(int nplayers);
|
||||
|
||||
public:
|
||||
@@ -173,12 +174,11 @@ public:
|
||||
virtual void get_edabits(bool strict, size_t size, T* a,
|
||||
vector<typename T::bit_type>& Sb, const vector<int>& regs)
|
||||
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
|
||||
template<int>
|
||||
void get_edabit_no_count(bool, int n_bits, edabit<T>& eb);
|
||||
template<int>
|
||||
virtual void get_edabit_no_count(bool, int, edabit<T>&)
|
||||
{ throw runtime_error("no edaBits"); }
|
||||
/// Get fresh edaBit chunk
|
||||
edabitvec<T> get_edabitvec(bool strict, int n_bits);
|
||||
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }
|
||||
virtual edabitvec<T> get_edabitvec(bool, int)
|
||||
{ throw runtime_error("no edabitvec"); }
|
||||
|
||||
virtual void push_triples(const vector<array<T, 3>>&)
|
||||
{ throw runtime_error("no pushing"); }
|
||||
@@ -204,7 +204,8 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
|
||||
map<DataTag, BufferOwner<T, T> > extended;
|
||||
BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
|
||||
map<int, ifstream*> edabit_buffers;
|
||||
map<int, EdabitBuffer<T>> edabit_buffers;
|
||||
map<int, edabitvec<T>> my_edabits;
|
||||
|
||||
int my_num,num_players;
|
||||
|
||||
@@ -213,13 +214,11 @@ class Sub_Data_Files : public Preprocessing<T>
|
||||
|
||||
part_type* part;
|
||||
|
||||
void buffer_edabits_with_queues(bool strict, int n_bits)
|
||||
{ buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
|
||||
template<int>
|
||||
void buffer_edabits_with_queues(bool strict, int n_bits, false_type);
|
||||
template<int>
|
||||
void buffer_edabits_with_queues(bool, int, true_type)
|
||||
{ throw not_implemented(); }
|
||||
EdabitBuffer<T>& get_edabit_buffer(int n_bits);
|
||||
|
||||
/// Get fresh edaBit chunk
|
||||
edabitvec<T> get_edabitvec(bool strict, int n_bits);
|
||||
void get_edabit_no_count(bool strict, int n_bits, edabit<T>& eb);
|
||||
|
||||
public:
|
||||
static string get_filename(const Names& N, Dtype type, int thread_num = -1);
|
||||
@@ -317,6 +316,8 @@ class Data_Files
|
||||
void reset_usage() { usage.reset(); skipped.reset(); }
|
||||
|
||||
void set_usage(const DataPositions& pos) { usage = pos; }
|
||||
|
||||
TimerWithComm total_time();
|
||||
};
|
||||
|
||||
template<class T> inline
|
||||
|
||||
@@ -108,7 +108,21 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
#ifdef DEBUG_FILES
|
||||
cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
|
||||
#endif
|
||||
T::clear::check_setup(prep_data_dir);
|
||||
|
||||
try
|
||||
{
|
||||
T::clear::check_setup(prep_data_dir);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
cerr << "Something is wrong with the preprocessing data on disk." << endl;
|
||||
cerr
|
||||
<< "Have you run the right program for generating it, such as './Fake-Offline.x "
|
||||
<< num_players
|
||||
<< T::clear::fake_opts() << "'?" << endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
string type_short = T::type_short();
|
||||
string type_string = T::type_string();
|
||||
|
||||
@@ -135,7 +149,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
|
||||
type_short, i, my_num, thread_num);
|
||||
if (i == my_num)
|
||||
my_input_buffers.setup(filename,
|
||||
T::size() + T::clear::size(), type_string);
|
||||
InputTuple<T>::size(), type_string);
|
||||
else
|
||||
input_buffers[i].setup(filename,
|
||||
T::size(), type_string);
|
||||
@@ -179,10 +193,6 @@ Data_Files<sint, sgf2n>::~Data_Files()
|
||||
template<class T>
|
||||
Sub_Data_Files<T>::~Sub_Data_Files()
|
||||
{
|
||||
for (auto& x: edabit_buffers)
|
||||
{
|
||||
delete x.second;
|
||||
}
|
||||
if (part != 0)
|
||||
delete part;
|
||||
}
|
||||
@@ -229,6 +239,26 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
|
||||
extended[it->first].seekg(it->second);
|
||||
}
|
||||
dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]);
|
||||
|
||||
if (field_type == DATA_INT)
|
||||
{
|
||||
for (auto& x : pos.edabits)
|
||||
{
|
||||
// open files
|
||||
get_edabit_buffer(x.first.second);
|
||||
}
|
||||
|
||||
|
||||
int block_size = edabitvec<T>::MAX_SIZE;
|
||||
for (auto& x : edabit_buffers)
|
||||
{
|
||||
int n = pos.edabits[{true, x.first}] + pos.edabits[{false, x.first}];
|
||||
x.second.seekg(n / block_size);
|
||||
edabit<T> eb;
|
||||
for (int i = 0; i < n % block_size; i++)
|
||||
get_edabit_no_count(false, x.first, eb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -262,6 +292,8 @@ void Sub_Data_Files<T>::prune()
|
||||
dabit_buffer.prune();
|
||||
if (part != 0)
|
||||
part->prune();
|
||||
for (auto& x : edabit_buffers)
|
||||
x.second.prune();
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
@@ -285,6 +317,8 @@ void Sub_Data_Files<T>::purge()
|
||||
dabit_buffer.purge();
|
||||
if (part != 0)
|
||||
part->purge();
|
||||
for (auto& x : edabit_buffers)
|
||||
x.second.prune();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -322,34 +356,43 @@ void Sub_Data_Files<T>::get_dabit_no_count(T& a, typename T::bit_type& b)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<int>
|
||||
void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
|
||||
false_type)
|
||||
EdabitBuffer<T>& Sub_Data_Files<T>::get_edabit_buffer(int n_bits)
|
||||
{
|
||||
if (edabit_buffers.empty())
|
||||
insecure("reading edaBits from files");
|
||||
|
||||
if (edabit_buffers.find(n_bits) == edabit_buffers.end())
|
||||
{
|
||||
string filename = PrepBase::get_edabit_filename(prep_data_dir,
|
||||
n_bits, my_num, thread_num);
|
||||
ifstream* f = new ifstream(filename);
|
||||
if (f->fail())
|
||||
throw runtime_error("cannot open " + filename);
|
||||
check_file_signature<T>(*f, filename);
|
||||
edabit_buffers[n_bits] = f;
|
||||
edabit_buffers[n_bits] = n_bits;
|
||||
edabit_buffers[n_bits].setup(filename,
|
||||
T::size() * edabitvec<T>::MAX_SIZE
|
||||
+ n_bits * T::bit_type::part_type::size());
|
||||
}
|
||||
auto& buffer = *edabit_buffers[n_bits];
|
||||
if (buffer.peek() == EOF)
|
||||
return edabit_buffers[n_bits];
|
||||
}
|
||||
|
||||
template<class T>
|
||||
edabitvec<T> Sub_Data_Files<T>::get_edabitvec(bool strict, int n_bits)
|
||||
{
|
||||
if (my_edabits[n_bits].empty())
|
||||
return get_edabit_buffer(n_bits).read();
|
||||
else
|
||||
{
|
||||
buffer.seekg(0);
|
||||
check_file_signature<T>(buffer, "");
|
||||
auto res = my_edabits[n_bits];
|
||||
my_edabits[n_bits] = {};
|
||||
this->fill(res, strict, n_bits);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Preprocessing<T>::fill(edabitvec<T>& res, bool strict, int n_bits)
|
||||
{
|
||||
edabit<T> eb;
|
||||
while (res.size() < res.MAX_SIZE)
|
||||
{
|
||||
get_edabit_no_count(strict, n_bits, eb);
|
||||
res.push_back(eb);
|
||||
}
|
||||
edabitvec<T> eb;
|
||||
eb.input(n_bits, buffer);
|
||||
this->edabits[{strict, n_bits}].push_back(eb);
|
||||
if (buffer.fail())
|
||||
throw runtime_error("error reading edaBits");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -362,4 +405,10 @@ typename Sub_Data_Files<T>::part_type& Sub_Data_Files<T>::get_part()
|
||||
return *part;
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
TimerWithComm Data_Files<sint, sgf2n>::total_time()
|
||||
{
|
||||
return DataFp.prep_timer + DataF2.prep_timer + DataFb.prep_timer;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
50
Processor/EdabitBuffer.h
Normal file
50
Processor/EdabitBuffer.h
Normal file
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* EdabitBuffer.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_EDABITBUFFER_H_
|
||||
#define PROCESSOR_EDABITBUFFER_H_
|
||||
|
||||
#include "Tools/Buffer.h"
|
||||
|
||||
template<class T>
|
||||
class EdabitBuffer : public BufferOwner<T, T>
|
||||
{
|
||||
int n_bits;
|
||||
|
||||
int element_length()
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
public:
|
||||
EdabitBuffer(int n_bits = 0) :
|
||||
n_bits(n_bits)
|
||||
{
|
||||
}
|
||||
|
||||
edabitvec<T> read()
|
||||
{
|
||||
if (not BufferBase::file)
|
||||
{
|
||||
if (this->open()->fail())
|
||||
throw runtime_error("error opening " + this->filename);
|
||||
}
|
||||
|
||||
assert(BufferBase::file);
|
||||
auto& buffer = *BufferBase::file;
|
||||
if (buffer.peek() == EOF)
|
||||
{
|
||||
this->try_rewind();
|
||||
}
|
||||
|
||||
edabitvec<T> eb;
|
||||
eb.input(n_bits, buffer);
|
||||
if (buffer.fail())
|
||||
throw runtime_error("error reading edaBits");
|
||||
return eb;
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_EDABITBUFFER_H_ */
|
||||
@@ -70,6 +70,7 @@ enum
|
||||
PLAYERID = 0xE4,
|
||||
USE_EDABIT = 0xE5,
|
||||
USE_MATMUL = 0x1F,
|
||||
ACTIVE = 0xE9,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
|
||||
@@ -311,6 +311,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case PRIVATEOUTPUT:
|
||||
case TRUNC_PR:
|
||||
case RUN_TAPE:
|
||||
case CONV2DS:
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
@@ -322,10 +323,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
get_ints(r, s, 3);
|
||||
get_vector(9, start, s);
|
||||
break;
|
||||
case CONV2DS:
|
||||
get_ints(r, s, 3);
|
||||
get_vector(12, start, s);
|
||||
break;
|
||||
|
||||
// read from file, input is opcode num_args,
|
||||
// start_file_posn (read), end_file_posn(write) var1, var2, ...
|
||||
@@ -425,6 +422,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
break;
|
||||
case ACTIVE:
|
||||
n = get_int(s);
|
||||
BaseMachine::s().active(n);
|
||||
break;
|
||||
case XORM:
|
||||
case ANDM:
|
||||
case XORCB:
|
||||
@@ -720,7 +721,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
case MATMULSM:
|
||||
return r[0] + start[0] * start[2];
|
||||
case CONV2DS:
|
||||
return r[0] + start[0] * start[1] * start[11];
|
||||
{
|
||||
unsigned res = 0;
|
||||
for (size_t i = 0; i < start.size(); i += 15)
|
||||
{
|
||||
unsigned tmp = start[i]
|
||||
+ start[i + 3] * start[i + 4] * start.at(i + 14);
|
||||
res = max(res, tmp);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
case OPEN:
|
||||
skip = 2;
|
||||
break;
|
||||
@@ -1164,6 +1174,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
break;
|
||||
case REQBL:
|
||||
case GREQBL:
|
||||
case ACTIVE:
|
||||
case USE:
|
||||
case USE_INP:
|
||||
case USE_EDABIT:
|
||||
|
||||
@@ -109,6 +109,7 @@ class Machine : public BaseMachine
|
||||
string prep_dir_prefix();
|
||||
|
||||
void reqbl(int n);
|
||||
void active(int n);
|
||||
|
||||
typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; }
|
||||
typename sint::mac_key_type get_sint_mac_key() { return alphapi; }
|
||||
|
||||
@@ -415,6 +415,9 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
|
||||
|
||||
auto comm_stats = total_comm();
|
||||
|
||||
if (OnlineOptions::singleton.verbose)
|
||||
queues.print_breakdown();
|
||||
|
||||
for (auto& queue : queues)
|
||||
delete queue;
|
||||
|
||||
@@ -477,20 +480,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
print_timers();
|
||||
|
||||
if (sint::is_real)
|
||||
{
|
||||
size_t rounds = 0;
|
||||
for (auto& x : comm_stats)
|
||||
rounds += x.second.rounds;
|
||||
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
|
||||
<< " rounds (party " << my_number;
|
||||
if (threads.size() > 1)
|
||||
cerr << "; rounds counted double due to multi-threading";
|
||||
cerr << "; use '-v' for more details";
|
||||
cerr << ")" << endl;
|
||||
|
||||
auto& P = *this->P;
|
||||
this->print_global_comm(P, comm_stats);
|
||||
}
|
||||
this->print_comm(*this->P, comm_stats);
|
||||
|
||||
#ifdef VERBOSE_OPTIONS
|
||||
if (opening_sum < N.num_players() && !direct)
|
||||
@@ -521,23 +511,6 @@ void Machine<sint, sgf2n>::run(const string& progname)
|
||||
|
||||
bit_memories.write_memory(N.my_num());
|
||||
|
||||
#ifdef OLD_USAGE
|
||||
for (int dtype = 0; dtype < N_DTYPE; dtype++)
|
||||
{
|
||||
cerr << "Num " << DataPositions::dtype_names[dtype] << "\t=";
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
cerr << " " << pos.files[field_type][dtype];
|
||||
cerr << endl;
|
||||
}
|
||||
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
|
||||
{
|
||||
cerr << "Num " << DataPositions::field_names[field_type] << " Inputs\t=";
|
||||
for (int i = 0; i < N.num_players(); i++)
|
||||
cerr << " " << pos.inputs[i][field_type];
|
||||
cerr << endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (opts.verbose)
|
||||
{
|
||||
cerr << "Actual cost of program:" << endl;
|
||||
@@ -586,6 +559,17 @@ void Machine<sint, sgf2n>::reqbl(int n)
|
||||
sint::clear::reqbl(n);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Machine<sint, sgf2n>::active(int n)
|
||||
{
|
||||
|
||||
if (sint::malicious and n == 0)
|
||||
{
|
||||
cerr << "Program requires a semi-honest protocol" << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Machine<sint, sgf2n>::suggest_optimizations()
|
||||
{
|
||||
@@ -599,8 +583,8 @@ void Machine<sint, sgf2n>::suggest_optimizations()
|
||||
optimizations.append("\tprogram.use_edabit(True)\n");
|
||||
if (not optimizations.empty())
|
||||
cerr << "This program might benefit from some protocol options." << endl
|
||||
<< "Consider adding the following at the beginning of '" << progname
|
||||
<< ".mpc':" << endl << optimizations;
|
||||
<< "Consider adding the following at the beginning of your code:"
|
||||
<< endl << optimizations;
|
||||
#ifndef __clang__
|
||||
cerr << "This virtual machine was compiled with GCC. Recompile with "
|
||||
"'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl;
|
||||
|
||||
@@ -172,7 +172,7 @@ void OfflineMachine<W>::generate()
|
||||
auto& opts = OnlineOptions::singleton;
|
||||
opts.batch_size = DIV_CEIL(opts.batch_size, batch) * batch;
|
||||
for (int i = 0; i < buffered_total(total, batch) / batch; i++)
|
||||
preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits,
|
||||
preprocessing.get_edabitvec(true, n_bits).output(n_bits,
|
||||
out);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -44,6 +44,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
|
||||
auto& queues = machine.queues[num];
|
||||
queues->next();
|
||||
ThreadQueue::thread_queue = queues;
|
||||
|
||||
#ifdef DEBUG_THREADS
|
||||
fprintf(stderr, "\tI am in thread %d\n",num);
|
||||
@@ -118,6 +119,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
DataPositions actual_usage(P.num_players());
|
||||
Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer;
|
||||
thread_timer.start();
|
||||
TimerWithComm timer, online_timer, online_prep_timer;
|
||||
timer.start();
|
||||
|
||||
while (flag)
|
||||
{ // Wait until I have a program to run
|
||||
@@ -262,6 +265,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
#ifdef DEBUG_THREADS
|
||||
printf("\tClient %d about to run %d\n",num,program);
|
||||
#endif
|
||||
online_timer.start(P.total_comm());
|
||||
online_prep_timer -= Proc.DataF.total_time();
|
||||
Proc.reset(progs[program], job.arg);
|
||||
|
||||
// Bits, Triples, Squares, and Inverses skipping
|
||||
@@ -290,6 +295,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
printf("\tSignalling I have finished with program %d"
|
||||
"in thread %d\n", program, num);
|
||||
#endif
|
||||
online_timer.stop(P.total_comm());
|
||||
online_prep_timer += Proc.DataF.total_time();
|
||||
wait_timer.start();
|
||||
queues->finished(job, P.total_comm());
|
||||
wait_timer.stop();
|
||||
@@ -297,7 +304,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
}
|
||||
|
||||
// final check
|
||||
online_timer.start(P.total_comm());
|
||||
online_prep_timer -= Proc.DataF.total_time();
|
||||
Proc.check();
|
||||
online_timer.stop(P.total_comm());
|
||||
online_prep_timer += Proc.DataF.total_time();
|
||||
|
||||
if (machine.opts.file_prep_per_thread)
|
||||
Proc.DataF.prune();
|
||||
@@ -330,6 +341,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
|
||||
|
||||
// wind down thread by thread
|
||||
machine.stats += Proc.stats;
|
||||
queues->timers["wait"] = wait_timer + queues->wait_timer;
|
||||
timer.stop(P.total_comm());
|
||||
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
|
||||
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
|
||||
|
||||
// prevent faulty usage message
|
||||
Proc.DataF.set_usage(actual_usage);
|
||||
delete processor;
|
||||
|
||||
@@ -69,7 +69,7 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
cerr << " edaBits of size " << n_bits << " left" << endl;
|
||||
}
|
||||
|
||||
if (n > used / 10)
|
||||
if (n * n_batch > used / 10)
|
||||
cerr << "Significant amount of unused edaBits of size " << n_bits
|
||||
<< ". For more accurate benchmarks, "
|
||||
<< "consider reducing the batch size with --batch-size "
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
using namespace std;
|
||||
|
||||
#include "Math/field_types.h"
|
||||
#include "Tools/TimerWithComm.h"
|
||||
|
||||
class PrepBase
|
||||
{
|
||||
@@ -28,6 +29,8 @@ public:
|
||||
const string& type_string, size_t used);
|
||||
static void print_left_edabits(size_t n, size_t n_batch, bool strict,
|
||||
int n_bits, size_t used);
|
||||
|
||||
TimerWithComm prep_timer;
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PREPBASE_H_ */
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "Processor/Program.h"
|
||||
#include "GC/square64.h"
|
||||
#include "SpecificPrivateOutput.h"
|
||||
#include "Conv2dTuple.h"
|
||||
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
@@ -31,6 +32,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
|
||||
DataF.set_proc(this);
|
||||
protocol.init(DataF, MC);
|
||||
DataF.set_protocol(protocol);
|
||||
MC.set_prep(DataF);
|
||||
bit_usage.set_num_players(P.num_players());
|
||||
personal_bit_preps.resize(P.num_players());
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
@@ -40,6 +42,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
|
||||
template<class T>
|
||||
SubProcessor<T>::~SubProcessor()
|
||||
{
|
||||
DataF.set_proc(0);
|
||||
for (size_t i = 0; i < personal_bit_preps.size(); i++)
|
||||
{
|
||||
auto& x = personal_bit_preps[i];
|
||||
@@ -391,7 +394,7 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
|
||||
return;
|
||||
|
||||
string filename;
|
||||
filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data";
|
||||
filename = binary_file_io.filename(P.my_num());
|
||||
|
||||
unsigned int size = data_registers.size();
|
||||
|
||||
@@ -652,21 +655,35 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
{
|
||||
protocol.init_dotprod();
|
||||
auto& args = instruction.get_start();
|
||||
int output_h = args[0], output_w = args[1];
|
||||
int inputs_h = args[2], inputs_w = args[3];
|
||||
int weights_h = args[4], weights_w = args[5];
|
||||
int stride_h = args[6], stride_w = args[7];
|
||||
int n_channels_in = args[8];
|
||||
int padding_h = args[9];
|
||||
int padding_w = args[10];
|
||||
int batch_size = args[11];
|
||||
size_t r0 = instruction.get_r(0);
|
||||
size_t r1 = instruction.get_r(1);
|
||||
int r2 = instruction.get_r(2);
|
||||
int lengths[batch_size][output_h][output_w];
|
||||
memset(lengths, 0, sizeof(lengths));
|
||||
int filter_stride_h = 1;
|
||||
int filter_stride_w = 1;
|
||||
vector<Conv2dTuple> tuples;
|
||||
for (size_t i = 0; i < args.size(); i += 15)
|
||||
tuples.push_back(Conv2dTuple(args, i));
|
||||
for (auto& tuple : tuples)
|
||||
tuple.pre(S, protocol);
|
||||
protocol.exchange();
|
||||
for (auto& tuple : tuples)
|
||||
tuple.post(S, protocol);
|
||||
}
|
||||
|
||||
inline
|
||||
Conv2dTuple::Conv2dTuple(const vector<int>& arguments, int start)
|
||||
{
|
||||
assert(arguments.size() >= start + 15ul);
|
||||
auto args = arguments.data() + start + 3;
|
||||
output_h = args[0], output_w = args[1];
|
||||
inputs_h = args[2], inputs_w = args[3];
|
||||
weights_h = args[4], weights_w = args[5];
|
||||
stride_h = args[6], stride_w = args[7];
|
||||
n_channels_in = args[8];
|
||||
padding_h = args[9];
|
||||
padding_w = args[10];
|
||||
batch_size = args[11];
|
||||
r0 = arguments[start];
|
||||
r1 = arguments[start + 1];
|
||||
r2 = arguments[start + 2];
|
||||
lengths.resize(batch_size, vector<vector<int>>(output_h, vector<int>(output_w)));
|
||||
filter_stride_h = 1;
|
||||
filter_stride_w = 1;
|
||||
if (stride_h < 0)
|
||||
{
|
||||
filter_stride_h = -stride_h;
|
||||
@@ -677,7 +694,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
filter_stride_w = -stride_w;
|
||||
stride_w = 1;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
|
||||
{
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
|
||||
{
|
||||
size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in;
|
||||
@@ -714,9 +735,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protocol.exchange();
|
||||
|
||||
template<class T>
|
||||
void Conv2dTuple::post(vector<T>& S, typename T::Protocol& protocol)
|
||||
{
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
|
||||
{
|
||||
size_t base = r0 + i_batch * output_h * output_w;
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
#include "ThreadQueue.h"
|
||||
|
||||
thread_local ThreadQueue* ThreadQueue::thread_queue = 0;
|
||||
|
||||
void ThreadQueue::schedule(const ThreadJob& job)
|
||||
{
|
||||
lock.lock();
|
||||
@@ -14,7 +16,11 @@ void ThreadQueue::schedule(const ThreadJob& job)
|
||||
cerr << this << ": " << left << " left" << endl;
|
||||
#endif
|
||||
lock.unlock();
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.start();
|
||||
in.push(job);
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.stop();
|
||||
}
|
||||
|
||||
ThreadJob ThreadQueue::next()
|
||||
@@ -42,7 +48,11 @@ void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats)
|
||||
|
||||
ThreadJob ThreadQueue::result()
|
||||
{
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.start();
|
||||
auto res = out.pop();
|
||||
if (thread_queue)
|
||||
thread_queue->wait_timer.stop();
|
||||
lock.lock();
|
||||
left--;
|
||||
#ifdef DEBUG_THREAD_QUEUE
|
||||
|
||||
@@ -16,6 +16,11 @@ class ThreadQueue
|
||||
NamedCommStats comm_stats;
|
||||
|
||||
public:
|
||||
static thread_local ThreadQueue* thread_queue;
|
||||
|
||||
map<string, TimerWithComm> timers;
|
||||
Timer wait_timer;
|
||||
|
||||
ThreadQueue() :
|
||||
left(0)
|
||||
{
|
||||
|
||||
@@ -85,3 +85,32 @@ void ThreadQueues::wrap_up(ThreadJob job)
|
||||
}
|
||||
available.clear();
|
||||
}
|
||||
|
||||
TimerWithComm ThreadQueues::sum(const string& phase)
|
||||
{
|
||||
TimerWithComm res;
|
||||
for (auto& x : *this)
|
||||
res += x->timers[phase];
|
||||
return res;
|
||||
}
|
||||
|
||||
void ThreadQueues::print_breakdown()
|
||||
{
|
||||
if (size() > 0)
|
||||
{
|
||||
if (size() == 1)
|
||||
{
|
||||
cerr << "Spent " << (*this)[0]->timers["online"].full()
|
||||
<< " on the online phase and "
|
||||
<< (*this)[0]->timers["prep"].full()
|
||||
<< " on the preprocessing/offline phase." << endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << size() << " threads spent a total of " << sum("online").full()
|
||||
<< " on the online phase, " << sum("prep").full()
|
||||
<< " on the preprocessing/offline phase, and "
|
||||
<< sum("wait").full() << " idling." << endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ public:
|
||||
int distribute_no_setup(ThreadJob job, int n_items, int base = 0,
|
||||
int granularity = 1, const vector<void*>* supplies = 0);
|
||||
void wrap_up(ThreadJob job);
|
||||
|
||||
TimerWithComm sum(const string& phase);
|
||||
|
||||
void print_breakdown();
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_THREADQUEUES_H_ */
|
||||
|
||||
@@ -387,6 +387,7 @@
|
||||
X(GENSECSHUFFLE, throw not_implemented(),) \
|
||||
X(APPLYSHUFFLE, throw not_implemented(),) \
|
||||
X(DELSHUFFLE, throw not_implemented(),) \
|
||||
X(ACTIVE, throw not_implemented(),) \
|
||||
|
||||
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
|
||||
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS
|
||||
|
||||
114
Programs/Source/alex.mpc
Normal file
114
Programs/Source/alex.mpc
Normal file
@@ -0,0 +1,114 @@
|
||||
from Compiler.ml import keras
|
||||
import Compiler.ml as tf
|
||||
|
||||
try:
|
||||
n_epochs = int(program.args[1])
|
||||
except (ValueError, IndexError):
|
||||
n_epochs = 20
|
||||
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except (ValueError, IndexError):
|
||||
batch_size = 128
|
||||
|
||||
try:
|
||||
n_threads = int(program.args[3])
|
||||
except (ValueError, IndexError):
|
||||
n_threads = 36
|
||||
|
||||
#Instantiation
|
||||
AlexNet = []
|
||||
|
||||
padding = 1
|
||||
batchnorm = 'batchnorm' in program.args
|
||||
bn1 = 'bn1' in program.args
|
||||
bn2 = 'bn2' in program.args
|
||||
|
||||
MultiArray.disable_index_checks()
|
||||
|
||||
#1st Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=64, input_shape=(32,32,3), kernel_size=3, strides=1, padding=2))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
if batchnorm:
|
||||
AlexNet.append(keras.layers.BatchNormalization())
|
||||
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding=0))
|
||||
|
||||
#2nd Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=3, strides=1, padding=2))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
if batchnorm or bn2:
|
||||
AlexNet.append(keras.layers.BatchNormalization())
|
||||
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='same'))
|
||||
|
||||
#3rd Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=(3,3), strides=(1,1), padding=padding))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
if batchnorm:
|
||||
AlexNet.append(keras.layers.BatchNormalization())
|
||||
|
||||
#4th Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
if batchnorm or bn1:
|
||||
AlexNet.append(keras.layers.BatchNormalization())
|
||||
|
||||
#5th Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
if batchnorm or bn2:
|
||||
AlexNet.append(keras.layers.BatchNormalization())
|
||||
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(3,3), strides=(2,2), padding=0))
|
||||
|
||||
#Passing it to a Fully Connected layer
|
||||
# 1st Fully Connected Layer
|
||||
AlexNet.append(keras.layers.Dense(128))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
|
||||
if 'dropout' in program.args:
|
||||
AlexNet.append(keras.layers.Dropout(0.5))
|
||||
|
||||
#2nd Fully Connected Layer
|
||||
AlexNet.append(keras.layers.Dense(256))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
|
||||
if 'dropout' in program.args:
|
||||
AlexNet.append(keras.layers.Dropout(0.5))
|
||||
|
||||
#Output Layer
|
||||
AlexNet.append(keras.layers.Dense(10))
|
||||
|
||||
tf.set_n_threads(n_threads)
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program, adapt_ring=True)
|
||||
|
||||
training_samples = MultiArray([50000, 32, 32, 3], sfix)
|
||||
training_labels = MultiArray([50000, 10], sint)
|
||||
|
||||
test_samples = MultiArray([10000, 32, 32, 3], sfix)
|
||||
test_labels = MultiArray([10000, 10], sint)
|
||||
|
||||
training_labels.input_from(0)
|
||||
training_samples.input_from(0, binary='binary_samples' in program.args)
|
||||
|
||||
test_labels.input_from(0)
|
||||
test_samples.input_from(0, binary='binary_samples' in program.args)
|
||||
|
||||
model = tf.keras.models.Sequential(AlexNet)
|
||||
|
||||
model.compile_by_args(program)
|
||||
|
||||
model.build(training_samples.sizes)
|
||||
model.summary()
|
||||
|
||||
model.opt.output_diff = 'output_diff' in program.args
|
||||
model.opt.output_grad = 'output_grad' in program.args
|
||||
model.opt.output_stats = 100 if 'output_stats' in program.args else 0
|
||||
model.opt.shuffle = not 'noshuffle' in program.args
|
||||
|
||||
opt = model.fit(
|
||||
training_samples,
|
||||
training_labels,
|
||||
epochs=n_epochs,
|
||||
batch_size=batch_size,
|
||||
validation_data=(test_samples, test_labels)
|
||||
)
|
||||
@@ -25,6 +25,9 @@ n_threads = 2
|
||||
if len(program.args) > 1:
|
||||
n_rounds = int(program.args[1])
|
||||
|
||||
if len(program.args) > 2:
|
||||
program.active = bool(int(program.args[2]))
|
||||
|
||||
def accept_client():
|
||||
client_socket_id = accept_client_connection(PORTNUM)
|
||||
last = regint.read_from_socket(client_socket_id)
|
||||
|
||||
@@ -4,7 +4,7 @@ import Compiler.ml as tf
|
||||
try:
|
||||
n_epochs = int(program.args[1])
|
||||
except (ValueError, IndexError):
|
||||
n_epochs = 10
|
||||
n_epochs = 20
|
||||
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
|
||||
72
Programs/Source/keras_mnist_lenet_avgpool.mpc
Normal file
72
Programs/Source/keras_mnist_lenet_avgpool.mpc
Normal file
@@ -0,0 +1,72 @@
|
||||
# this trains LeNet on MNIST with a dropout layer
|
||||
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
|
||||
|
||||
program.options_from_args()
|
||||
|
||||
if 'torch' in program.args:
|
||||
import torchvision
|
||||
data = []
|
||||
for train in True, False:
|
||||
ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
|
||||
# normalize to [0,1] before input
|
||||
samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
|
||||
labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
|
||||
data += [(labels, samples)]
|
||||
|
||||
(training_labels, training_samples), (test_labels, test_samples) = data
|
||||
else:
|
||||
training_samples = sfix.Tensor([60000, 28, 28])
|
||||
training_labels = sint.Tensor([60000, 10])
|
||||
|
||||
test_samples = sfix.Tensor([10000, 28, 28])
|
||||
test_labels = sint.Tensor([10000, 10])
|
||||
|
||||
training_labels.input_from(0)
|
||||
training_samples.input_from(0)
|
||||
|
||||
test_labels.input_from(0)
|
||||
test_samples.input_from(0)
|
||||
|
||||
from Compiler import ml
|
||||
tf = ml
|
||||
|
||||
layers = [
|
||||
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
|
||||
]
|
||||
|
||||
if 'batchnorm' in program.args:
|
||||
layers += [tf.keras.layers.BatchNormalization()]
|
||||
|
||||
layers += [
|
||||
tf.keras.layers.AveragePooling2D(2),
|
||||
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
|
||||
]
|
||||
|
||||
|
||||
if 'batchnorm' in program.args:
|
||||
layers += [tf.keras.layers.BatchNormalization()]
|
||||
|
||||
layers += [
|
||||
tf.keras.layers.AveragePooling2D(2),
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dropout(0.5),
|
||||
tf.keras.layers.Dense(500, activation='relu'),
|
||||
tf.keras.layers.Dense(10, activation='softmax')
|
||||
]
|
||||
|
||||
model = tf.keras.models.Sequential(layers)
|
||||
|
||||
optim = tf.keras.optimizers.Adam(amsgrad=True)
|
||||
|
||||
model.compile(optimizer=optim)
|
||||
|
||||
opt = model.fit(
|
||||
training_samples,
|
||||
training_labels,
|
||||
epochs=10,
|
||||
batch_size=128,
|
||||
validation_data=(test_samples, test_labels)
|
||||
)
|
||||
|
||||
for var in model.trainable_variables:
|
||||
var.write_to_file()
|
||||
49
Programs/Source/torch_mnist_lenet_avgpool.mpc
Normal file
49
Programs/Source/torch_mnist_lenet_avgpool.mpc
Normal file
@@ -0,0 +1,49 @@
|
||||
# this trains a dense neural network on MNIST
|
||||
|
||||
program.options_from_args()
|
||||
|
||||
import torchvision
|
||||
|
||||
data = []
|
||||
for train in True, False:
|
||||
ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
|
||||
# normalize to [0,1] before input
|
||||
samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
|
||||
labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
|
||||
data += [(labels, samples)]
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
net = nn.Sequential(
|
||||
nn.Conv2d(1, 20, 5),
|
||||
nn.ReLU(),
|
||||
nn.AvgPool2d(2),
|
||||
nn.Conv2d(20, 50, 5),
|
||||
nn.ReLU(),
|
||||
nn.AvgPool2d(2),
|
||||
nn.Flatten(),
|
||||
nn.ReLU(),
|
||||
nn.Linear(800, 500),
|
||||
nn.ReLU(),
|
||||
nn.Linear(500, 10)
|
||||
)
|
||||
|
||||
# test network
|
||||
ds = torchvision.datasets.MNIST(
|
||||
root='/tmp', transform=torchvision.transforms.ToTensor())
|
||||
inputs = next(iter(torch.utils.data.DataLoader(ds)))[0]
|
||||
print(inputs.shape)
|
||||
outputs = net(inputs)
|
||||
|
||||
from Compiler import ml
|
||||
|
||||
ml.set_n_threads(int(program.args[2]))
|
||||
|
||||
layers = ml.layers_from_torch(net, data[0][1].shape, 128)
|
||||
layers[0].X = data[0][1]
|
||||
layers[-1].Y = data[0][0]
|
||||
|
||||
optimizer = ml.SGD(layers)
|
||||
optimizer.run_by_args(program, int(program.args[1]), 128,
|
||||
data[1][1], data[1][0])
|
||||
@@ -11,6 +11,8 @@ DealerMatrixPrep<T>::DealerMatrixPrep(int n_rows, int n_inner, int n_cols,
|
||||
super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols),
|
||||
prep(&prep)
|
||||
{
|
||||
assert(prep.proc);
|
||||
this->P = &prep.proc->P;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -20,9 +20,6 @@ class Hemi : public T::BasicProtocol
|
||||
|
||||
MatrixMC<T> mc;
|
||||
|
||||
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
|
||||
SubProcessor<T>& processor);
|
||||
|
||||
public:
|
||||
Hemi(Player& P) :
|
||||
T::BasicProtocol(P)
|
||||
@@ -33,6 +30,9 @@ public:
|
||||
typename T::MatrixPrep& get_matrix_prep(const array<int, 3>& dimensions,
|
||||
SubProcessor<T>& processor);
|
||||
|
||||
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
|
||||
SubProcessor<T>& processor);
|
||||
|
||||
void matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
|
||||
const Instruction& instruction, int a, int b);
|
||||
void conv2ds(SubProcessor<T>& processor, const Instruction& instruction);
|
||||
|
||||
@@ -130,37 +130,23 @@ void Hemi<T>::conv2ds(SubProcessor<T>& processor,
|
||||
}
|
||||
|
||||
auto& args = instruction.get_start();
|
||||
int output_h = args[0], output_w = args[1];
|
||||
int inputs_h = args[2], inputs_w = args[3];
|
||||
int weights_h = args[4], weights_w = args[5];
|
||||
int stride_h = args[6], stride_w = args[7];
|
||||
int n_channels_in = args[8];
|
||||
int padding_h = args[9];
|
||||
int padding_w = args[10];
|
||||
int batch_size = args[11];
|
||||
size_t r0 = instruction.get_r(0);
|
||||
size_t r1 = instruction.get_r(1);
|
||||
int r2 = instruction.get_r(2);
|
||||
int filter_stride_h = 1;
|
||||
int filter_stride_w = 1;
|
||||
if (stride_h < 0)
|
||||
{
|
||||
filter_stride_h = -stride_h;
|
||||
stride_h = 1;
|
||||
}
|
||||
if (stride_w < 0)
|
||||
{
|
||||
filter_stride_w = -stride_w;
|
||||
stride_w = 1;
|
||||
}
|
||||
vector<Conv2dTuple> tuples;
|
||||
for (size_t i = 0; i < args.size(); i += 15)
|
||||
tuples.push_back(Conv2dTuple(args, i));
|
||||
for (auto& tuple : tuples)
|
||||
tuple.run_matrix(processor);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Conv2dTuple::run_matrix(SubProcessor<T>& processor)
|
||||
{
|
||||
auto& S = processor.get_S();
|
||||
array<int, 3> dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}});
|
||||
ShareMatrix<T> A(dim[0], dim[1]), B(dim[1], dim[2]);
|
||||
|
||||
if (not T::real_shares(processor.P))
|
||||
{
|
||||
matrix_multiply(A, B, processor);
|
||||
processor.protocol.matrix_multiply(A, B, processor);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -208,7 +194,7 @@ void Hemi<T>::conv2ds(SubProcessor<T>& processor,
|
||||
}
|
||||
}
|
||||
|
||||
auto C = matrix_multiply(A, B, processor);
|
||||
auto C = processor.protocol.matrix_multiply(A, B, processor);
|
||||
|
||||
for (int i_batch = 0; i_batch < batch_size; i_batch ++)
|
||||
{
|
||||
|
||||
@@ -37,6 +37,8 @@ public:
|
||||
if (swapped)
|
||||
std::swap(this->n_rows, this->n_cols);
|
||||
assert(this->n_cols >= this->n_rows);
|
||||
assert(prep.proc);
|
||||
this->P = &prep.proc->P;
|
||||
}
|
||||
|
||||
void set_protocol(typename ShareMatrix<T>::Protocol&)
|
||||
|
||||
@@ -21,11 +21,7 @@ void MaliciousShamirMC<T>::init_open(const Player& P, int n)
|
||||
reconstructions.resize(2 * threshold + 2);
|
||||
for (int i = threshold + 1; i <= 2 * threshold + 1; i++)
|
||||
{
|
||||
reconstructions[i].resize(i);
|
||||
for (int j = 0; j < i; j++)
|
||||
reconstructions[i][j] =
|
||||
Shamir<T>::get_rec_factor(P.get_player(j),
|
||||
P.num_players(), P.my_num(), i);
|
||||
reconstructions[i] = ShamirMC<T>::get_reconstruction(P, i);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user