Maintenance.

This commit is contained in:
Marcel Keller
2023-05-09 14:49:52 +10:00
parent c62ab2ca1e
commit 6cc3fccef0
135 changed files with 1658 additions and 1062 deletions

3
.gitmodules vendored
View File

@@ -1,9 +1,6 @@
[submodule "SimpleOT"] [submodule "SimpleOT"]
path = deps/SimpleOT path = deps/SimpleOT
url = https://github.com/mkskeller/SimpleOT url = https://github.com/mkskeller/SimpleOT
[submodule "mpir"]
path = deps/mpir
url = https://github.com/wbhart/mpir
[submodule "Programs/Circuits"] [submodule "Programs/Circuits"]
path = Programs/Circuits path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion url = https://github.com/mkskeller/bristol-fashion

View File

@@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec<T>
bool one_shot; bool one_shot;
size_t data_sent;
public: public:
static RealProgramParty& s(); static RealProgramParty& s();

View File

@@ -154,7 +154,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK); while (next != GC::DONE_BREAK);
MC->Check(*P); MC->Check(*P);
data_sent = P->total_comm().sent;
if (online_opts.verbose) if (online_opts.verbose)
P->total_comm().print(); P->total_comm().print();
@@ -216,7 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
delete prep; delete prep;
delete garble_inputter; delete garble_inputter;
delete garble_protocol; delete garble_protocol;
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; garble_machine.print_comm(*this->P, this->P->total_comm());
T::MAC_Check::teardown(); T::MAC_Check::teardown();
} }

View File

@@ -62,11 +62,13 @@ private:
#endif #endif
}; };
#else #else
class BaseKeyVector : public vector<Key> class BaseKeyVector : public CheckVector<Key>
{ {
typedef CheckVector<Key> super;
public: public:
BaseKeyVector(int size = 0) : vector<Key>(size, Key(0)) {} BaseKeyVector(int size = 0) : super(size, Key(0)) {}
void resize(int size) { vector<Key>::resize(size, Key(0)); } void resize(int size) { super::resize(size, Key(0)); }
}; };
#endif #endif
@@ -296,7 +298,8 @@ public:
static void andm(GC::Processor<U>&, const BaseInstruction&) static void andm(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("andm not implemented"); } { throw runtime_error("andm not implemented"); }
static void run_tapes(const vector<int>&) { throw not_implemented(); } static void run_tapes(const vector<int>&)
{ throw runtime_error("multi-threading not implemented"); }
// most BMR phases don't need actual input // most BMR phases don't need actual input
template<class T> template<class T>

View File

@@ -1,5 +1,20 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.6 (May 9, 2023)
- More extensive benchmarking outputs
- Replace MPIR by GMP
- Secure reading of edaBits from files
- Semi-honest client communication
- Back-propagation for average pooling
- Parallelized convolution
- Probabilistic truncation as in ABY3
- More balanced communication in Shamir secret sharing
- Avoid unnecessary communication in Dealer protocol
- Linear solver using Cholesky decomposition
- Accept .py files for compilation
- Fixed security bug: proper accounting for random elements
## 0.3.5 (Feb 16, 2023) ## 0.3.5 (Feb 16, 2023)
- Easier-to-use machine learning interface - Easier-to-use machine learning interface

21
CONFIG
View File

@@ -35,15 +35,32 @@ ARM := $(shell uname -m | grep x86; echo $$?)
OS := $(shell uname -s) OS := $(shell uname -s)
ifeq ($(MACHINE), x86_64) ifeq ($(MACHINE), x86_64)
ifeq ($(OS), Linux) ifeq ($(OS), Linux)
ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0)
AVX_OT = 1 AVX_OT = 1
else else
AVX_OT = 0 AVX_OT = 0
endif endif
else else
AVX_OT = 0
endif
else
ARCH = ARCH =
AVX_OT = 0 AVX_OT = 0
endif endif
ifeq ($(OS), Darwin)
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
endif
ifeq ($(OS), Linux)
ifeq ($(ARM), 1)
ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0)
ARCH = -march=armv8.2-a+crypto
endif
endif
endif
USE_KOS = 0 USE_KOS = 0
# allow to set compiler in CONFIG.mine # allow to set compiler in CONFIG.mine
@@ -66,7 +83,8 @@ endif
# Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols # Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5 # MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
LDLIBS += $(BREW_LDLIBS)
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
LDLIBS += -lboost_system -lssl -lcrypto LDLIBS += -lboost_system -lssl -lcrypto
@@ -88,6 +106,7 @@ BOOST = -lboost_thread $(MY_BOOST)
endif endif
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(BREW_CFLAGS)
CPPFLAGS = $(CFLAGS) CPPFLAGS = $(CFLAGS)
LD = $(CXX) LD = $(CXX)

View File

@@ -17,8 +17,10 @@ import math
class SecretBitsAF(base.RegisterArgFormat): class SecretBitsAF(base.RegisterArgFormat):
reg_type = 'sb' reg_type = 'sb'
name = 'sbit'
class ClearBitsAF(base.RegisterArgFormat): class ClearBitsAF(base.RegisterArgFormat):
reg_type = 'cb' reg_type = 'cb'
name = 'cbit'
base.ArgFormats['sb'] = SecretBitsAF base.ArgFormats['sb'] = SecretBitsAF
base.ArgFormats['sbw'] = SecretBitsAF base.ArgFormats['sbw'] = SecretBitsAF

View File

@@ -338,16 +338,19 @@ class Merger:
d[j] = d[i] d[j] = d[i]
def read(reg, n): def read(reg, n):
last_read[reg] = n
for dup in reg.duplicates: for dup in reg.duplicates:
if last_def[dup] != -1: if last_def[dup] not in (-1, n):
add_edge(last_def[dup], n) add_edge(last_def[dup], n)
last_read[reg] = n
def write(reg, n): def write(reg, n):
last_def[reg] = n
for dup in reg.duplicates: for dup in reg.duplicates:
if last_read[dup] not in (-1, n): if last_read[dup] not in (-1, n):
add_edge(last_read[dup], n) add_edge(last_read[dup], n)
if id(dup) in [id(x) for x in block.instructions[n].get_used()] and \
last_read[dup] not in (-1, n):
add_edge(last_read[dup], n)
last_def[reg] = n
def handle_mem_access(addr, reg_type, last_access_this_kind, def handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind): last_access_other_kind):
@@ -434,13 +437,6 @@ class Merger:
# if options.debug: # if options.debug:
# col = colordict[instr.__class__.__name__] # col = colordict[instr.__class__.__name__]
# G.add_node(n, color=col, label=str(instr)) # G.add_node(n, color=col, label=str(instr))
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
for reg in outputs: for reg in outputs:
if reg.vector and instr.is_vec(): if reg.vector and instr.is_vec():
for i in reg.vector: for i in reg.vector:
@@ -448,6 +444,13 @@ class Merger:
else: else:
write(reg, n) write(reg, n)
for reg in inputs:
if reg.vector and instr.is_vec():
for i in reg.vector:
read(i, n)
else:
read(reg, n)
# will be merged # will be merged
if isinstance(instr, TextInputInstruction): if isinstance(instr, TextInputInstruction):
keep_text_order(instr, n) keep_text_order(instr, n)
@@ -556,18 +559,6 @@ class Merger:
if unused_result: if unused_result:
eliminate(i) eliminate(i)
count += 1 count += 1
# remove unnecessary stack instructions
# left by optimization with budget
if isinstance(inst, popint_class) and \
(not G.degree(i) or (G.degree(i) == 1 and
isinstance(instructions[list(G[i])[0]], StackInstruction))) \
and \
inst.args[0].can_eliminate and \
len(G.pred[i]) == 1 and \
isinstance(instructions[list(G.pred[i])[0]], pushint_class):
eliminate(list(G.pred[i])[0])
eliminate(i)
count += 2
if count > 0 and self.block.parent.program.verbose: if count > 0 and self.block.parent.program.verbose:
print('Eliminated %d dead instructions, among which %d opens: %s' \ print('Eliminated %d dead instructions, among which %d opens: %s' \
% (count, open_count, dict(stats))) % (count, open_count, dict(stats)))

View File

@@ -50,6 +50,9 @@ def set_variant(options):
do_precomp = False do_precomp = False
elif variant is not None: elif variant is not None:
raise CompilerError('Unknown comparison variant: %s' % variant) raise CompilerError('Unknown comparison variant: %s' % variant)
if const_rounds and instructions_base.program.options.binary:
raise CompilerError(
'Comparison variant choice incompatible with binary circuits')
def ld2i(c, n): def ld2i(c, n):
""" Load immediate 2^n into clear GF(p) register c """ """ Load immediate 2^n into clear GF(p) register c """

View File

@@ -22,6 +22,7 @@ class Compiler:
self.custom_args = custom_args self.custom_args = custom_args
self.build_option_parser() self.build_option_parser()
self.VARS = {} self.VARS = {}
self.root = os.path.dirname(__file__) + '/..'
def build_option_parser(self): def build_option_parser(self):
parser = OptionParser(usage=self.usage) parser = OptionParser(usage=self.usage)
@@ -269,7 +270,7 @@ class Compiler:
self.prog = Program(self.args, self.options, name=name) self.prog = Program(self.args, self.options, name=name)
if self.execute: if self.execute:
if self.options.execute in \ if self.options.execute in \
("emulate", "ring", "rep-field", "semi2k"): ("emulate", "ring", "rep-field"):
self.prog.use_trunc_pr = True self.prog.use_trunc_pr = True
if self.options.execute in ("ring",): if self.options.execute in ("ring",):
self.prog.use_split(3) self.prog.use_split(3)
@@ -405,7 +406,7 @@ class Compiler:
infile = open(self.prog.infile) infile = open(self.prog.infile)
# make compiler modules directly accessible # make compiler modules directly accessible
sys.path.insert(0, "Compiler") sys.path.insert(0, "%s/Compiler" % self.root)
# create the tapes # create the tapes
exec(compile(infile.read(), infile.name, "exec"), self.VARS) exec(compile(infile.read(), infile.name, "exec"), self.VARS)
@@ -477,15 +478,15 @@ class Compiler:
def local_execution(self, args=[]): def local_execution(self, args=[]):
executable = self.executable_from_protocol(self.options.execute) executable = self.executable_from_protocol(self.options.execute)
if not os.path.exists(executable): if not os.path.exists("%s/%s" % (self.root, executable)):
print("Creating binary for virtual machine...") print("Creating binary for virtual machine...")
try: try:
subprocess.run(["make", executable], check=True) subprocess.run(["make", executable], check=True, cwd=self.root)
except: except:
raise CompilerError( raise CompilerError(
"Cannot produce %s. " % executable + \ "Cannot produce %s. " % executable + \
"Note that compilation requires a few GB of RAM.") "Note that compilation requires a few GB of RAM.")
vm = 'Scripts/%s.sh' % self.options.execute vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
os.execl(vm, vm, self.prog.name, *args) os.execl(vm, vm, self.prog.name, *args)
def remote_execution(self, args=[]): def remote_execution(self, args=[]):
@@ -496,7 +497,7 @@ class Compiler:
from fabric import Connection from fabric import Connection
import subprocess import subprocess
print("Creating static binary for virtual machine...") print("Creating static binary for virtual machine...")
subprocess.run(["make", "static/%s" % vm], check=True) subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
# transfer files # transfer files
import glob import glob
@@ -519,7 +520,7 @@ class Compiler:
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \ "mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
dest) dest)
# executable # executable
connection.put("static/%s" % vm, dest) connection.put("%s/static/%s" % (self.root, vm), dest)
# program # program
dest += "/" dest += "/"
connection.put("Programs/Schedules/%s.sch" % self.prog.name, connection.put("Programs/Schedules/%s.sch" % self.prog.name,

View File

@@ -289,7 +289,7 @@ def BitDecRingRaw(a, k, m):
def BitDecRing(a, k, m): def BitDecRing(a, k, m):
bits = BitDecRingRaw(a, k, m) bits = BitDecRingRaw(a, k, m)
# reversing to reduce number of rounds # reversing to reduce number of rounds
return [types.sint.conv(bit) for bit in reversed(bits)][::-1] return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
instructions_base.set_global_vector_size(a.size) instructions_base.set_global_vector_size(a.size)
@@ -306,7 +306,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
def BitDecField(a, k, m, kappa, bits_to_compute=None): def BitDecField(a, k, m, kappa, bits_to_compute=None):
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute) res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
return [types.sint.conv(bit) for bit in res] return [types.sintbit.conv(bit) for bit in res]
@instructions_base.ret_cisc @instructions_base.ret_cisc

View File

@@ -356,7 +356,17 @@ class reqbl(base.Instruction):
code = base.opcodes['REQBL'] code = base.opcodes['REQBL']
arg_format = ['int'] arg_format = ['int']
class active(base.Instruction):
""" Indicate whether program is compatible with malicious-security
protocols.
:param: 0 for no, 1 for yes
"""
code = base.opcodes['ACTIVE']
arg_format = ['int']
class time(base.IOInstruction): class time(base.IOInstruction):
""" Output time since start of computation. """ """ Output time since start of computation. """
code = base.opcodes['TIME'] code = base.opcodes['TIME']
arg_format = [] arg_format = []
@@ -2418,9 +2428,10 @@ class matmulsm(matmul_base):
super(matmulsm, self).add_usage(req_node) super(matmulsm, self).add_usage(req_node)
req_node.increment(('matmul', tuple(self.args[3:6])), 1) req_node.increment(('matmul', tuple(self.args[3:6])), 1)
class conv2ds(base.DataInstruction): class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
""" Secret 2D convolution. """ Secret 2D convolution.
:param: number of arguments to follow (int)
:param: result (sint vector in row-first order) :param: result (sint vector in row-first order)
:param: inputs (sint vector in row-first order) :param: inputs (sint vector in row-first order)
:param: weights (sint vector in row-first order) :param: weights (sint vector in row-first order)
@@ -2436,10 +2447,12 @@ class conv2ds(base.DataInstruction):
:param: padding height (int) :param: padding height (int)
:param: padding width (int) :param: padding width (int)
:param: batch size (int) :param: batch size (int)
:param: repeat from result...
""" """
code = base.opcodes['CONV2DS'] code = base.opcodes['CONV2DS']
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int', arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
'int','int','int','int'] 'int','int','int','int','int','int','int'])
data_type = 'triple' data_type = 'triple'
is_vec = lambda self: True is_vec = lambda self: True
@@ -2450,12 +2463,14 @@ class conv2ds(base.DataInstruction):
assert args[2].size == args[7] * args[8] * args[11] assert args[2].size == args[7] * args[8] * args[11]
def get_repeat(self): def get_repeat(self):
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \ args = self.args
self.args[11] * self.args[14] return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
args[i+11] * args[i+14] for i in range(0, len(args), 15))
def add_usage(self, req_node): def add_usage(self, req_node):
super(conv2ds, self).add_usage(req_node) super(conv2ds, self).add_usage(req_node)
args = self.args for i in range(0, len(self.args), 15):
args = self.args[i:i + 15]
req_node.increment(('matmul', (1, args[7] * args[8] * args[11], req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
args[14] * args[3] * args[4])), 1) args[14] * args[3] * args[4])), 1)

View File

@@ -66,6 +66,7 @@ opcodes = dict(
PLAYERID = 0xE4, PLAYERID = 0xE4,
USE_EDABIT = 0xE5, USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F, USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
# Addition # Addition
ADDC = 0x20, ADDC = 0x20,
ADDS = 0x21, ADDS = 0x21,
@@ -700,18 +701,23 @@ class RegisterArgFormat(ArgFormat):
class ClearModpAF(RegisterArgFormat): class ClearModpAF(RegisterArgFormat):
reg_type = RegType.ClearModp reg_type = RegType.ClearModp
name = 'cint'
class SecretModpAF(RegisterArgFormat): class SecretModpAF(RegisterArgFormat):
reg_type = RegType.SecretModp reg_type = RegType.SecretModp
name = 'sint'
class ClearGF2NAF(RegisterArgFormat): class ClearGF2NAF(RegisterArgFormat):
reg_type = RegType.ClearGF2N reg_type = RegType.ClearGF2N
name = 'cgf2n'
class SecretGF2NAF(RegisterArgFormat): class SecretGF2NAF(RegisterArgFormat):
reg_type = RegType.SecretGF2N reg_type = RegType.SecretGF2N
name = 'sgf2n'
class ClearIntAF(RegisterArgFormat): class ClearIntAF(RegisterArgFormat):
reg_type = RegType.ClearInt reg_type = RegType.ClearInt
name = 'regint'
class IntArgFormat(ArgFormat): class IntArgFormat(ArgFormat):
n_bits = 32 n_bits = 32

View File

@@ -1226,7 +1226,7 @@ def while_loop(loop_body, condition, arg=None, g=None):
result = loop_body(arg) result = loop_body(arg)
if isinstance(result, MemValue): if isinstance(result, MemValue):
result = result.read() result = result.read()
result.link(arg) arg.update(result)
return condition(result) return condition(result)
if not isinstance(pre_condition, (bool,int)) or pre_condition: if not isinstance(pre_condition, (bool,int)) or pre_condition:
if_statement(pre_condition, lambda: do_while(loop_fn, g=g)) if_statement(pre_condition, lambda: do_while(loop_fn, g=g))

View File

@@ -372,6 +372,7 @@ class Output(NoVariableLayer):
n = self.X.sizes[0] n = self.X.sizes[0]
if Y is None: if Y is None:
Y = self.Y Y = self.Y
assert isinstance(Y, Array)
n_correct = MemValue(0) n_correct = MemValue(0)
n_printed = MemValue(0) n_printed = MemValue(0)
@for_range_opt(n) @for_range_opt(n)
@@ -1109,14 +1110,7 @@ class Square(ElementWiseLayer):
f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x) f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x)
prime_type = sfix prime_type = sfix
class MaxPool(NoVariableLayer): class PoolBase(NoVariableLayer):
""" Fixed-point MaxPool layer.
:param shape: input shape (tuple/list of four int)
:param strides: strides (tuple/list of four int, first and last must be 1)
:param ksize: kernel size (tuple/list of four int, first and last must be 1)
:param padding: :py:obj:`'VALID'` (default) or :py:obj:`'SAME'`
"""
def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1),
padding='VALID'): padding='VALID'):
assert len(shape) == 4 assert len(shape) == 4
@@ -1152,38 +1146,6 @@ class MaxPool(NoVariableLayer):
(type(self).__name__, self.X.sizes, self.strides, (type(self).__name__, self.X.sizes, self.strides,
self.ksize, self.padding) self.ksize, self.padding)
def forward(self, batch=None, training=False):
if batch is None:
batch = Array.create_from(regint(0))
def process(pool, bi, k, i, j):
def m(a, b):
c = a[0] > b[0]
l = [c * x for x in a[1]]
l += [(1 - c) * x for x in b[1]]
return c.if_else(a[0], b[0]), l
red = util.tree_reduce(m, [(x[0], [1] if training else [])
for x in pool])
self.Y[bi][i][j][k] = red[0]
for ii, x in enumerate(red[1]):
self.comparisons[bi][k][i][j][ii] = x
self.traverse(batch, process)
def backward(self, compute_nabla_X=True, batch=None):
if compute_nabla_X:
self.nabla_X.alloc()
self.nabla_X.assign_all(0)
break_point()
def process(pool, bi, k, i, j):
for (x, h_in, w_in, h, w), c \
in zip(pool, self.comparisons[bi][k][i][j]):
hh = h * h_in
ww = w * w_in
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
get_program().protect_memory(True)
self.nabla_X[bi][hh][ww][k] += res
get_program().protect_memory(False)
self.traverse(batch, process)
def traverse(self, batch, process): def traverse(self, batch, process):
need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] > need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] >
self.X.sizes[i] for i in range(4)] self.X.sizes[i] for i in range(4)]
@@ -1221,6 +1183,47 @@ class MaxPool(NoVariableLayer):
h_in, w_in, h, w]) h_in, w_in, h, w])
process(pool, bi, k, i, j) process(pool, bi, k, i, j)
class MaxPool(PoolBase):
""" Fixed-point MaxPool layer.
:param shape: input shape (tuple/list of four int)
:param strides: strides (tuple/list of four int, first and last must be 1)
:param ksize: kernel size (tuple/list of four int, first and last must be 1)
:param padding: :py:obj:`'VALID'` (default), :py:obj:`'SAME'`, integer, or
list/tuple of integers
"""
def forward(self, batch=None, training=False):
if batch is None:
batch = Array.create_from(regint(0))
def process(pool, bi, k, i, j):
def m(a, b):
c = a[0] > b[0]
l = [c * x for x in a[1]]
l += [(1 - c) * x for x in b[1]]
return c.if_else(a[0], b[0]), l
red = util.tree_reduce(m, [(x[0], [1] if training else [])
for x in pool])
self.Y[bi][i][j][k] = red[0]
for ii, x in enumerate(red[1]):
self.comparisons[bi][k][i][j][ii] = x
self.traverse(batch, process)
def backward(self, compute_nabla_X=True, batch=None):
if compute_nabla_X:
self.nabla_X.alloc()
self.nabla_X.assign_all(0)
break_point()
def process(pool, bi, k, i, j):
for (x, h_in, w_in, h, w), c \
in zip(pool, self.comparisons[bi][k][i][j]):
hh = h * h_in
ww = w * w_in
res = h_in * w_in * c * self.nabla_Y[bi][i][j][k]
get_program().protect_memory(True)
self.nabla_X[bi][hh][ww][k] += res
get_program().protect_memory(False)
self.traverse(batch, process)
class Argmax(NoVariableLayer): class Argmax(NoVariableLayer):
""" Fixed-point Argmax layer. """ Fixed-point Argmax layer.
@@ -2058,6 +2061,12 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
or tuple/list of two int or tuple/list of two int
""" """
kernel_size, stride, padding = \
_standardize_pool_options(kernel_size, stride, padding)
return MaxPool(input_shape, [1] + list(stride) + [1],
[1] + list(kernel_size) + [1], padding)
def _standardize_pool_options(kernel_size, stride, padding):
if isinstance(kernel_size, int): if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size) kernel_size = (kernel_size, kernel_size)
if isinstance(stride, int): if isinstance(stride, int):
@@ -2066,8 +2075,7 @@ def easyMaxPool(input_shape, kernel_size, stride=None, padding=0):
stride = kernel_size stride = kernel_size
padding = padding.upper() if isinstance(padding, str) \ padding = padding.upper() if isinstance(padding, str) \
else padding else padding
return MaxPool(input_shape, [1] + list(stride) + [1], return kernel_size, stride, padding
[1] + list(kernel_size) + [1], padding)
class QuantAveragePool2d(QuantBase, AveragePool2d): class QuantAveragePool2d(QuantBase, AveragePool2d):
def input_params_from(self, player): def input_params_from(self, player):
@@ -2075,14 +2083,47 @@ class QuantAveragePool2d(QuantBase, AveragePool2d):
for s in self.input_squant, self.output_squant: for s in self.input_squant, self.output_squant:
s.get_params_from(player) s.get_params_from(player)
class FixAveragePool2d(FixBase, AveragePool2d): class FixAveragePool2d(PoolBase, FixBase):
""" Fixed-point 2D AvgPool layer. """ Fixed-point 2D AvgPool layer.
:param input_shape: input shape (tuple/list of four int) :param input_shape: input shape (tuple/list of four int)
:param output_shape: output shape (tuple/list of four int) :param output_shape: output shape (tuple/list of four int)
:param filter_size: filter size (tuple/list of two int) :param filter_size: filter size (int or tuple/list of two int)
:param strides: strides (tuple/list of two int) :param strides: strides (int or tuple/list of two int)
:param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int,
or tuple/list of two int
""" """
def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1),
padding=0):
filter_size, strides, padding = \
_standardize_pool_options(filter_size, strides, padding)
PoolBase.__init__(self, input_shape, [1] + list(strides) + [1],
[1] + list(filter_size) + [1], padding)
self.pool_size = reduce(operator.mul, filter_size)
if output_shape:
assert self.Y.shape == list(output_shape)
def _forward(self, batch):
def process(pool, bi, k, i, j):
self.Y[bi][i][j][k] = sum(x[0] for x in pool) * (1 / self.pool_size)
self.traverse(batch, process)
def backward(self, compute_nabla_X=True, batch=None):
if compute_nabla_X:
self.nabla_X.alloc()
self.nabla_X.assign_all(0)
break_point()
def process(pool, bi, k, i, j):
part = self.nabla_Y[bi][i][j][k] * (1 / self.pool_size)
for x, h_in, w_in, h, w in pool:
hh = h * h_in
ww = w * w_in
res = h_in * w_in * part
get_program().protect_memory(True)
self.nabla_X[bi][hh][ww][k] += res
get_program().protect_memory(False)
self.traverse(batch, process)
class QuantReshape(QuantBase, BaseLayer): class QuantReshape(QuantBase, BaseLayer):
def __init__(self, input_shape, _, output_shape): def __init__(self, input_shape, _, output_shape):
@@ -2265,6 +2306,8 @@ class Optimizer:
:param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample) :param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample)
:param top: return top prediction instead of probability distribution :param top: return top prediction instead of probability distribution
:returns: sfix/sint Array (depening on :py:obj:`top`)
""" """
if isinstance(self.layers[-1].Y, Array) or top: if isinstance(self.layers[-1].Y, Array) or top:
if top: if top:
@@ -2540,6 +2583,8 @@ class Optimizer:
@_no_mem_warnings @_no_mem_warnings
def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, def run_by_args(self, program, n_runs, batch_size, test_X, test_Y,
acc_batch_size=None, reset=True): acc_batch_size=None, reset=True):
MultiArray.disable_index_checks()
Array.check_indices = False
if acc_batch_size is None: if acc_batch_size is None:
acc_batch_size = batch_size acc_batch_size = batch_size
depreciation = None depreciation = None
@@ -2943,6 +2988,10 @@ class keras:
return 'maxpool', {'pool_size': pool_size, 'strides': strides, return 'maxpool', {'pool_size': pool_size, 'strides': strides,
'padding': padding} 'padding': padding}
def AveragePooling2D(pool_size=2, strides=None, padding='valid'):
return 'avgpool', {'filter_size': pool_size, 'strides': strides,
'padding': padding}
def Dropout(rate): def Dropout(rate):
l = math.log(rate, 2) l = math.log(rate, 2)
if int(l) != l: if int(l) != l:
@@ -3014,9 +3063,12 @@ class keras:
n_units = reduce(operator.mul, n_units = reduce(operator.mul,
layers[-1].Y.sizes[1:]) layers[-1].Y.sizes[1:])
if i == len(self.layers) - 1: if i == len(self.layers) - 1:
if layer[2].get('activation', 'softmax') in \ activation = layer[2].get('activation', None)
('softmax', 'sigmoid'): if activation in ('softmax', 'sigmoid'):
layer[2].pop('activation', None) layer[2].pop('activation', None)
if activation == 'softmax' and layer[1][0] == 1:
raise CompilerError(
'softmax requires more than one output neuron')
layers.append(Dense(N, n_units, layer[1][0], layers.append(Dense(N, n_units, layer[1][0],
**layer[2])) **layer[2]))
input_shape = layers[-1].Y.sizes input_shape = layers[-1].Y.sizes
@@ -3041,6 +3093,9 @@ class keras:
layers.append(easyMaxPool(input_shape, pool_size, layers.append(easyMaxPool(input_shape, pool_size,
strides, padding)) strides, padding))
input_shape = layers[-1].Y.sizes input_shape = layers[-1].Y.sizes
elif name == 'avgpool':
layers.append(FixAveragePool2d(input_shape, None, **layer[1]))
input_shape = layers[-1].Y.sizes
elif name == 'dropout': elif name == 'dropout':
layers.append(Dropout(batch_size, reduce( layers.append(Dropout(batch_size, reduce(
operator.mul, layers[-1].Y.sizes[1:]), operator.mul, layers[-1].Y.sizes[1:]),
@@ -3192,6 +3247,10 @@ def layers_from_torch(sequence, data_input_shape, batch_size, input_via=None):
layers.append(easyMaxPool(input_shape, item.kernel_size, layers.append(easyMaxPool(input_shape, item.kernel_size,
item.stride, item.padding)) item.stride, item.padding))
input_shape = layers[-1].Y.shape input_shape = layers[-1].Y.shape
elif name == 'AvgPool2d':
layers.append(FixAveragePool2d(input_shape, None, item.kernel_size,
item.stride, item.padding))
input_shape = layers[-1].Y.shape
elif name == 'ReLU': elif name == 'ReLU':
layers.append(Relu(input_shape)) layers.append(Relu(input_shape))
elif name == 'Flatten': elif name == 'Flatten':
@@ -3295,7 +3354,7 @@ class SGDLogistic(OneLayerSGD):
return super(SGDLogistic, self).predict(X) return super(SGDLogistic, self).predict(X)
class SGDLinear(OneLayerSGD): class SGDLinear(OneLayerSGD):
""" Logistic regression using SGD. """ Linear regression using SGD.
:param n_epochs: number of epochs :param n_epochs: number of epochs
:param batch_size: batch size :param batch_size: batch size
@@ -3415,11 +3474,16 @@ def var(x):
return res.read() return res.read()
def cholesky(A, reveal_diagonal=False): def cholesky(A, reveal_diagonal=False):
""" Cholesky decomposition. """ """ Cholesky decomposition.
:returns: lower triangular matrix
"""
assert len(A.shape) == 2 assert len(A.shape) == 2
assert A.shape[0] == A.shape[1] assert A.shape[0] == A.shape[1]
L = A.same_shape() L = A.same_shape()
L.assign_all(0) L.assign_all(0)
diag_inv = A.value_type.Array(A.shape[0])
@for_range(A.shape[0]) @for_range(A.shape[0])
def _(i): def _(i):
@for_range(i + 1) @for_range(i + 1)
@@ -3429,10 +3493,47 @@ def cholesky(A, reveal_diagonal=False):
@if_e(i == j) @if_e(i == j)
def _(): def _():
L[i][j] = mpc_math.sqrt(A[i][i] - sum) L[i][j] = mpc_math.sqrt(A[i][i] - sum)
diag_inv[i] = 1 / L[i][j]
if reveal_diagonal: if reveal_diagonal:
print_ln('L[%s][%s] = %s = sqrt(%s - %s)', i, j, print_ln('L[%s][%s] = %s = sqrt(%s - %s)', i, j,
L[i][j].reveal(), A[i][j].reveal(), sum.reveal()) L[i][j].reveal(), A[i][j].reveal(), sum.reveal())
@else_ @else_
def _(): def _():
L[i][j] = (1.0 / L[j][j] * (A[i][j] - sum)) L[i][j] = (diag_inv[j] * (A[i][j] - sum))
return L return L
def solve_lower(A, b):
""" Linear solver where :py:obj:`A` is lower triangular quadratic. """
assert len(A.shape) == 2
assert A.shape[0] == A.shape[1]
assert len(b) == A.shape[0]
b = Array.create_from(b)
res = sfix.Array(len(b))
@for_range(len(b))
def _(i):
res[i] = b[i] / A[i][i]
b[:] -= res[i] * A.get_column(i)
return res
def solve_upper(A, b):
""" Linear solver where :py:obj:`A` is upper triangular quadratic. """
assert len(A.shape) == 2
assert A.shape[0] == A.shape[1]
assert len(b) == A.shape[0]
b = Array.create_from(b)
res = sfix.Array(len(b))
@for_range(len(b) - 1, -1, -1)
def _(i):
res[i] = b[i] / A[i][i]
b[:] -= res[i] * A.get_column(i)
return res
def solve_cholesky(A, b, debug=False):
""" Linear solver using Cholesky decomposition. """
L = cholesky(A, reveal_diagonal=debug)
if debug:
Optimizer.stat('L', L)
x = solve_lower(L, b)
if debug:
Optimizer.stat('intermediate', x)
return solve_upper(L.transpose(), x)

View File

@@ -661,7 +661,7 @@ def sqrt_simplified_fx(x):
h = h * r h = h * r
H = 4 * (h * h) H = 4 * (h * h)
if not x.round_nearest or (2 * f < k - 1): if not x.round_nearest or (2 * x.f < x.k - 1):
H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H) H = (h < 2 ** (-x.f / 2) / 2).if_else(0, H)
H = H * x H = H * x
@@ -806,9 +806,7 @@ def sqrt_fx(x_l, k, f):
@instructions_base.sfix_cisc @instructions_base.sfix_cisc
def sqrt(x, k=None, f=None): def sqrt(x, k=None, f=None):
""" """
Returns the square root (sfix) of any given fractional Square root.
value as long as it can be rounded to a integral value
with :py:obj:`f` bits of decimal precision.
:param x: fractional input (sfix). :param x: fractional input (sfix).

View File

@@ -186,6 +186,8 @@ class Program(object):
self.input_files = {} self.input_files = {}
self.base_addresses = {} self.base_addresses = {}
self._protect_memory = False self._protect_memory = False
self._always_active = True
self.active = True
if not self.options.cisc: if not self.options.cisc:
self.options.cisc = not self.options.optimize_hard self.options.cisc = not self.options.optimize_hard
@@ -207,16 +209,14 @@ class Program(object):
return self.n_threads return self.n_threads
def init_names(self, args): def init_names(self, args):
# ignore path to file - source must be in Programs/Source
if "Programs" in os.listdir(os.getcwd()):
# compile prog in ./Programs/Source directory
self.programs_dir = "Programs" self.programs_dir = "Programs"
else:
# assume source is in main SPDZ directory
self.programs_dir = sys.path[0] + "/Programs"
if self.verbose: if self.verbose:
print("Compiling program in", self.programs_dir) print("Compiling program in", self.programs_dir)
for dirname in (self.programs_dir, "Player-Data"):
if not os.path.exists(dirname):
os.mkdir(dirname)
# create extra directories if needed # create extra directories if needed
for dirname in ["Public-Input", "Bytecode", "Schedules"]: for dirname in ["Public-Input", "Bytecode", "Schedules"]:
if not os.path.exists(self.programs_dir + "/" + dirname): if not os.path.exists(self.programs_dir + "/" + dirname):
@@ -224,13 +224,29 @@ class Program(object):
if self.name is None: if self.name is None:
self.name = args[0].split("/")[-1] self.name = args[0].split("/")[-1]
if self.name.endswith(".mpc"): exts = ".mpc", ".py"
self.name = self.name[:-4] for ext in exts:
if self.name.endswith(ext):
self.name = self.name[:-len(ext)]
if os.path.exists(args[0]): if os.path.exists(args[0]):
self.infile = args[0] self.infile = args[0]
else: else:
self.infile = self.programs_dir + "/Source/" + self.name + ".mpc" infiles = []
for x in (self.programs_dir, sys.path[0] + "/Programs"):
for ext in exts:
filename = args[0]
if not filename.endswith(ext):
filename += ext
infiles += [x + "/Source/" + filename]
for f in infiles:
if os.path.exists(f):
self.infile = f
break
else:
raise CompilerError(
"found none of the potential input files: " +
", ".join("'%s'" % x for x in [args[0]] + infiles))
""" """
self.name is input file name (minus extension) + any optional arguments. self.name is input file name (minus extension) + any optional arguments.
Used to generate output filenames Used to generate output filenames
@@ -479,6 +495,9 @@ class Program(object):
# finalize the memory # finalize the memory
self.finalize_memory() self.finalize_memory()
# communicate protocol compability
Compiler.instructions.active(self._always_active)
self.write_bytes() self.write_bytes()
if self.options.asmoutfile: if self.options.asmoutfile:
@@ -672,6 +691,19 @@ class Program(object):
logp = int(round(math.log(p, 2))) logp = int(round(math.log(p, 2)))
return abs(p - 2 ** logp) / p < 2 ** -self.security return abs(p - 2 ** logp) / p < 2 ** -self.security
@property
def active(self):
""" Whether to use actively secure protocols. """
return self._active
@active.setter
def active(self, change):
self._always_active &= change
self._active = change
def semi_honest(self):
self._always_active = False
@staticmethod @staticmethod
def read_tapes(schedule): def read_tapes(schedule):
m = re.search(r"([^/]*)\.mpc", schedule) m = re.search(r"([^/]*)\.mpc", schedule)
@@ -1454,6 +1486,9 @@ class Tape:
return Tape.Register(self.reg_type, Program.prog.curr_tape) return Tape.Register(self.reg_type, Program.prog.curr_tape)
def link(self, other): def link(self, other):
if Program.prog.options.noreallocate:
raise CompilerError("reallocation necessary for linking, "
"remove option -u")
self.duplicates |= other.duplicates self.duplicates |= other.duplicates
for dup in self.duplicates: for dup in self.duplicates:
dup.duplicates = self.duplicates dup.duplicates = self.duplicates
@@ -1466,12 +1501,15 @@ class Tape:
:param other: any convertible type :param other: any convertible type
""" """
if isinstance(other, Tape.Register) and other.block != Program.prog.curr_block:
other = type(self)(other) other = type(self)(other)
else:
other = self.conv(other)
if Program.prog.curr_block in [x.block for x in self.duplicates]:
self.program.start_new_basicblock()
if self.program != other.program: if self.program != other.program:
raise CompilerError( raise CompilerError(
'cannot update register with one from another thread') 'cannot update register with one from another thread')
if other.block in [x.block for x in self.duplicates]:
self.program.start_new_basicblock()
self.link(other) self.link(other)
@property @property

View File

@@ -659,6 +659,7 @@ class _secret_structure(_structure):
traverse(x, level + 1) traverse(x, level + 1)
traverse(content, 0) traverse(content, 0)
f.write('\n') f.write('\n')
f.flush()
if requested_shape is not None and \ if requested_shape is not None and \
list(shape) != list(requested_shape): list(shape) != list(requested_shape):
raise CompilerError('content contradicts shape') raise CompilerError('content contradicts shape')
@@ -2415,26 +2416,34 @@ class sint(_secret, _int):
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of values input by a client. """ Securely obtain shares of values input by a client.
This uses the triple-based input protocol introduced by This uses the triple-based input protocol introduced by
`Damgård et al. <http://eprint.iacr.org/2015/1006>`_ `Damgård et al. <http://eprint.iacr.org/2015/1006>`_ unless
:py:obj:`program.active` is set to false, in which case
it uses random values to mask the clients' input.
:param n: number of inputs (int) :param n: number of inputs (int)
:param client_id: regint :param client_id: regint
:param size: vector size (default 1) :param size: vector size (default 1)
:returns: list of sint :returns: list of sint
""" """
if program.active:
# send shares of a triple to client # send shares of a triple to client
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
else:
triples = [sint.get_random() for i in range(n)]
sint.write_shares_to_socket(client_id, triples, message_type) sint.write_shares_to_socket(client_id, triples, message_type)
received = util.tuplify(cint.read_from_socket(client_id, n)) received = util.tuplify(cint.read_from_socket(client_id, n))
y = [0] * n y = [0] * n
for i in range(n): for i in range(n):
y[i] = received[i] - triples[i * 3] y[i] = received[i] - triples[i * 3 if program.active else i]
return y return y
@classmethod @classmethod
def reveal_to_clients(cls, clients, values): def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients. """ Reveal securely to clients.
Uses :py:obj:`program.active` to determine whether to use
triples for active security.
:param clients: client ids (list or array) :param clients: client ids (list or array)
:param values: list of sint to reveal :param values: list of sint to reveal
@@ -2445,8 +2454,11 @@ class sint(_secret, _int):
for value in values: for value in values:
assert(value.size == values[0].size) assert(value.size == values[0].size)
if program.active:
r = sint.get_random() r = sint.get_random()
to_send += [value, r, value * r] to_send += [value, r, value * r]
else:
to_send += [value]
if isinstance(clients, Array): if isinstance(clients, Array):
n_clients = clients.length n_clients = clients.length
@@ -2844,7 +2856,7 @@ class sint(_secret, _int):
privateoutput(self.size, player, res._v, self) privateoutput(self.size, player, res._v, self)
return res return res
def private_division(self, divisor, active=True, dividend_length=None, def private_division(self, divisor, active=None, dividend_length=None,
divisor_length=None): divisor_length=None):
""" Private integer division as per `Veugen and Abspoel """ Private integer division as per `Veugen and Abspoel
<https://doi.org/10.2478/popets-2021-0073>`_ <https://doi.org/10.2478/popets-2021-0073>`_
@@ -2878,6 +2890,9 @@ class sint(_secret, _int):
z_shared = ((self << (l + sigma)) + h + r_pprime) z_shared = ((self << (l + sigma)) + h + r_pprime)
z = z_shared.reveal_to(0) z = z_shared.reveal_to(0)
if active is None:
active = program.active
if active: if active:
z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)] z_prime = [sint(x) for x in (z // d).bit_decompose(min_length)]
check = [(x * (1 - x)).reveal() == 0 for x in z_prime] check = [(x * (1 - x)).reveal() == 0 for x in z_prime]
@@ -2893,6 +2908,7 @@ class sint(_secret, _int):
y_prime = sint.bit_compose(z_prime[:l + sigma]) y_prime = sint.bit_compose(z_prime[:l + sigma])
y = sint.bit_compose(z_prime[l + sigma:]) y = sint.bit_compose(z_prime[l + sigma:])
else: else:
program.semi_honest()
y = sint(z // (d << (l + sigma))) y = sint(z // (d << (l + sigma)))
y_prime = sint((z // d) % (2 ** (l + sigma))) y_prime = sint((z // d) % (2 ** (l + sigma)))
@@ -3147,7 +3163,9 @@ class sgf2n(_secret, _gf2n):
for i in range(0, bit_length, step)] for i in range(0, bit_length, step)]
one = cgf2n(1) one = cgf2n(1)
masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal() masked = sum([b * (one << (i * step))
for i,b in enumerate(random_bits)], self).reveal(
check=False)
masked_bits = masked.bit_decompose(bit_length,step=step) masked_bits = masked.bit_decompose(bit_length,step=step)
return [m + r for m,r in zip(masked_bits, random_bits)] return [m + r for m,r in zip(masked_bits, random_bits)]
@@ -3157,7 +3175,9 @@ class sgf2n(_secret, _gf2n):
for i in range(8)] for i in range(8)]
one = cgf2n(1) one = cgf2n(1)
wanted_positions = [0, 5, 10, 15, 20, 25, 30, 35] wanted_positions = [0, 5, 10, 15, 20, 25, 30, 35]
masked = sum([b * (one << wanted_positions[i]) for i,b in enumerate(random_bits)], self).reveal() masked = sum([b * (one << wanted_positions[i])
for i,b in enumerate(random_bits)], self).reveal(
check=False)
return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)] return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)]
for t in (sint, sgf2n): for t in (sint, sgf2n):
@@ -4080,7 +4100,8 @@ class _single(_number, _secret_structure):
@vectorized_classmethod @vectorized_classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" """
Securely obtain shares of values input by a client. Assumes client Securely obtain shares of values input by a client via
:py:func:`sint.receive_from_client`. Assumes client
has already converted values to integer representation. has already converted values to integer representation.
:param n: number of inputs (int) :param n: number of inputs (int)
@@ -4095,7 +4116,7 @@ class _single(_number, _secret_structure):
@classmethod @classmethod
def reveal_to_clients(cls, clients, values): def reveal_to_clients(cls, clients, values):
""" Reveal securely to clients. """ Reveal securely to clients via :py:func:`sint.reveal_to_clients`.
:param clients: client ids (list or array) :param clients: client ids (list or array)
:param values: list of values of this class :param values: list of values of this class
@@ -4556,7 +4577,7 @@ class sfix(_fix):
:py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``), :py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``),
returning :py:class:`sbitint`. The other operand can be any of returning :py:class:`sbitint`. The other operand can be any of
sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()`` sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()``
and ``**``, the latter for integer exponents. and ``**``.
Note that the default precision (16 bits after the dot, 31 bits in Note that the default precision (16 bits after the dot, 31 bits in
total) only allows numbers up to :math:`2^{31-16-1} \\approx total) only allows numbers up to :math:`2^{31-16-1} \\approx
@@ -4669,6 +4690,8 @@ class sfix(_fix):
return self.v return self.v
def mul_no_reduce(self, other, res_params=None): def mul_no_reduce(self, other, res_params=None):
if not isinstance(other, type(self)):
return self * other
assert self.f == other.f assert self.f == other.f
assert self.k == other.k assert self.k == other.k
return self.unreduced(self.v * other.v) return self.unreduced(self.v * other.v)
@@ -4734,6 +4757,11 @@ class unreduced_sfix(_single):
nearest=sfix.round_nearest, signed=True) nearest=sfix.round_nearest, signed=True)
return sfix._new(v, k=self.k - self.m, f=self.m) return sfix._new(v, k=self.k - self.m, f=self.m)
def update(self, other):
assert self.k == other.k
assert self.m == other.m
self.v.update(other.v)
sfix.unreduced_type = unreduced_sfix sfix.unreduced_type = unreduced_sfix
sfix.set_precision(16, 31) sfix.set_precision(16, 31)
@@ -4953,6 +4981,8 @@ class sfloat(_number, _secret_structure):
This uses integer operations internally, see :py:class:`sint` for security This uses integer operations internally, see :py:class:`sint` for security
considerations. considerations.
See `Aliasgari et al. <https://eprint.iacr.org/2012/405.pdf>`_ for
details.
The type supports basic arithmetic (``+, -, *, /``), returning The type supports basic arithmetic (``+, -, *, /``), returning
:py:class:`sfloat`, and comparisons (``==, !=, <, <=, >, >=``), :py:class:`sfloat`, and comparisons (``==, !=, <, <=, >, >=``),
@@ -5459,6 +5489,9 @@ class Array(_vectorizable):
b.input_from(1) b.input_from(1)
a[:] += b[:] a[:] += b[:]
Arrays aren't initialized on creation, you need to call
:py:func:`assign_all` to initialize them to a constant value.
""" """
check_indices = True check_indices = True
@@ -5708,7 +5741,7 @@ class Array(_vectorizable):
mem_value = MemValue(value) mem_value = MemValue(value)
self.address = MemValue.if_necessary(self.address) self.address = MemValue.if_necessary(self.address)
n_threads = 8 if use_threads and util.is_constant(self.length) and \ n_threads = 8 if use_threads and util.is_constant(self.length) and \
len(self) > 2**20 else None len(self) > 2**20 and not program.options.garbled else None
@library.multithread(n_threads, self.length) @library.multithread(n_threads, self.length)
def _(base, size): def _(base, size):
if use_vector: if use_vector:
@@ -5896,7 +5929,7 @@ class Array(_vectorizable):
self.assign_vector(self.get_vector().secure_shuffle()) self.assign_vector(self.get_vector().secure_shuffle())
def secure_permute(self, *args, **kwargs): def secure_permute(self, *args, **kwargs):
""" Secure permutate in place according to the security model. """ Secure permute in place according to the security model.
See :py:func:`MultiArray.secure_shuffle` for references. See :py:func:`MultiArray.secure_shuffle` for references.
:param permutation: output of :py:func:`sint.get_secure_shuffle()` :param permutation: output of :py:func:`sint.get_secure_shuffle()`
@@ -6227,6 +6260,9 @@ class SubMultiArray(_vectorizable):
def same_shape(self): def same_shape(self):
""" :return: new multidimensional array with same shape and basic type """ """ :return: new multidimensional array with same shape and basic type """
if len(self.sizes) == 2:
return Matrix(*self.sizes, self.value_type)
else:
return MultiArray(self.sizes, self.value_type) return MultiArray(self.sizes, self.value_type)
def get_part(self, start, size): def get_part(self, start, size):
@@ -6400,7 +6436,7 @@ class SubMultiArray(_vectorizable):
pass pass
t.params = res_params t.params = res_params
else: else:
if issubclass(self.value_type, _secret_structure): if self.value_type == other.value_type:
t = self.value_type t = self.value_type
else: else:
t = type(self.value_type(0) * other.value_type(0)) t = type(self.value_type(0) * other.value_type(0))
@@ -6435,10 +6471,12 @@ class SubMultiArray(_vectorizable):
# fallback for binary circuits # fallback for binary circuits
@library.for_range_opt(other.sizes[1]) @library.for_range_opt(other.sizes[1])
def _(j): def _(j):
res_matrix[i][j] = 0 tmp = self[i][0].mul_no_reduce(other[0][j])
@library.for_range_opt(self.sizes[1]) @library.for_range_opt(1, self.sizes[1])
def _(k): def _(k):
res_matrix[i][j] += self[i][k] * other[k][j] prod = self[i][k].mul_no_reduce(other[k][j])
tmp.iadd(prod)
res_matrix[i][j] = tmp.reduce_after_mul()
return res_matrix return res_matrix
elif isinstance(other, self.value_type): elif isinstance(other, self.value_type):
return self * Array.create_from(other) return self * Array.create_from(other)
@@ -6780,6 +6818,9 @@ class MultiArray(SubMultiArray):
a[1].input_from(1) a[1].input_from(1)
a[2][:] = a[0][:] * a[1][:] a[2][:] = a[0][:] * a[1][:]
Arrays aren't initialized on creation, you need to call
:py:func:`assign_all` to initialize them to a constant value.
""" """
@staticmethod @staticmethod
def disable_index_checks(): def disable_index_checks():
@@ -6817,6 +6858,9 @@ class Matrix(MultiArray):
:param columns: compile-time (int) :param columns: compile-time (int)
:param value_type: basic type of entries :param value_type: basic type of entries
Matrices aren't initialized on creation, you need to call
:py:func:`assign_all` to initialize them to a constant value.
""" """
def __init__(self, rows, columns, value_type, debug=None, address=None): def __init__(self, rows, columns, value_type, debug=None, address=None):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \

View File

@@ -47,23 +47,16 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libboost-dev \ libboost-dev \
libboost-thread-dev \ libboost-thread-dev \
libclang-dev \ libclang-dev \
libgmp-dev \
libntl-dev \ libntl-dev \
libsodium-dev \ libsodium-dev \
libssl-dev \ libssl-dev \
libtool \ libtool \
m4 \
texinfo \
yasm \
vim \ vim \
gdb \ gdb \
valgrind \ valgrind \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# mpir
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/include/* /usr/local/include/
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/lib/* /usr/local/lib/
COPY --from=initc3/mpir:55fe6a9 /usr/local/mpir/share/info/* /usr/local/share/info/
ENV MP_SPDZ_HOME /usr/src/MP-SPDZ ENV MP_SPDZ_HOME /usr/src/MP-SPDZ
WORKDIR $MP_SPDZ_HOME WORKDIR $MP_SPDZ_HOME

View File

@@ -46,6 +46,7 @@ void Client::send_private_inputs(const vector<T>& values)
octetStream os; octetStream os;
vector< vector<T> > triples(num_inputs, vector<T>(3)); vector< vector<T> > triples(num_inputs, vector<T>(3));
vector<T> triple_shares(3); vector<T> triple_shares(3);
bool active = true;
// Receive num_inputs triples from SPDZ // Receive num_inputs triples from SPDZ
for (size_t j = 0; j < sockets.size(); j++) for (size_t j = 0; j < sockets.size(); j++)
@@ -61,9 +62,21 @@ void Client::send_private_inputs(const vector<T>& values)
cerr << "received " << os.get_length() << " from " << j << endl << flush; cerr << "received " << os.get_length() << " from " << j << endl << flush;
#endif #endif
if (j == 0)
{
if (os.get_length() == 3 * values.size() * T::size())
active = true;
else
active = false;
}
int n_expected = active ? 3 : 1;
if (os.get_length() != n_expected * T::size() * values.size())
throw runtime_error("unexpected data length in sending");
for (int j = 0; j < num_inputs; j++) for (int j = 0; j < num_inputs; j++)
{ {
for (int k = 0; k < 3; k++) for (int k = 0; k < n_expected; k++)
{ {
triple_shares[k].unpack(os); triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k]; triples[j][k] += triple_shares[k];
@@ -71,6 +84,7 @@ void Client::send_private_inputs(const vector<T>& values)
} }
} }
if (active)
// Check triple relations (is a party cheating?) // Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++) for (int i = 0; i < num_inputs; i++)
{ {
@@ -81,6 +95,7 @@ void Client::send_private_inputs(const vector<T>& values)
throw mac_fail(); throw mac_fail();
} }
} }
// Send inputs + triple[0], so SPDZ can compute shares of each value // Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head(); os.reset_write_head();
for (int i = 0; i < num_inputs; i++) for (int i = 0; i < num_inputs; i++)
@@ -100,6 +115,7 @@ vector<U> Client::receive_outputs(int n)
{ {
vector<T> triples(3 * n); vector<T> triples(3 * n);
octetStream os; octetStream os;
bool active = true;
for (auto& socket : sockets) for (auto& socket : sockets)
{ {
os.reset_write_head(); os.reset_write_head();
@@ -107,7 +123,20 @@ vector<U> Client::receive_outputs(int n)
#ifdef VERBOSE_COMM #ifdef VERBOSE_COMM
cout << "received " << os.get_length() << endl << flush; cout << "received " << os.get_length() << endl << flush;
#endif #endif
for (int j = 0; j < 3 * n; j++)
if (socket == sockets[0])
{
if (os.get_length() == (size_t) 3 * n * T::size())
active = true;
else
active = false;
}
int n_expected = n * (active ? 3 : 1);
if (os.get_length() != (size_t) n_expected * T::size())
throw runtime_error("unexpected data length in receiving");
for (int j = 0; j < n_expected; j++)
{ {
T value; T value;
value.unpack(os); value.unpack(os);
@@ -115,6 +144,8 @@ vector<U> Client::receive_outputs(int n)
} }
} }
if (active)
{
vector<U> output_values; vector<U> output_values;
for (int i = 0; i < 3 * n; i += 3) for (int i = 0; i < 3 * n; i += 3)
{ {
@@ -128,3 +159,9 @@ vector<U> Client::receive_outputs(int n)
return output_values; return output_values;
} }
else
{
triples.resize(n);
return triples;
}
}

View File

@@ -34,12 +34,20 @@ class Client:
os = octetStream() os = octetStream()
for socket in self.sockets: for socket in self.sockets:
os.Receive(socket) os.Receive(socket)
if socket == self.sockets[0]:
active = os.get_length() == 3 * n * T.size()
n_expected = 3 if active else 1
if os.get_length() != n_expected * T.size() * n:
import sys
print (os.get_length(), n_expected, T.size(), n, active, file=sys.stderr)
raise Exception('unexpected data length')
for triple in triples: for triple in triples:
for i in range(3): for i in range(n_expected):
t = T() t = T()
t.unpack(os) t.unpack(os)
triple[i] += t triple[i] += t
res = [] res = []
if active:
for triple in triples: for triple in triples:
prod = triple[0] * triple[1] prod = triple[0] * triple[1]
if prod != triple[2]: if prod != triple[2]:
@@ -68,6 +76,9 @@ class octetStream:
if value is not None: if value is not None:
self.buf += value self.buf += value
def get_length(self):
return len(self.buf)
def reset_write_head(self): def reset_write_head(self):
self.buf = b'' self.buf = b''
self.ptr = 0 self.ptr = 0

View File

@@ -27,6 +27,10 @@ class Domain:
def __neq__(self, other): def __neq__(self, other):
return self.v != other.v return self.v != other.v
@classmethod
def size(cls):
return cls.n_bytes
def unpack(self, os): def unpack(self, os):
self.v = 0 self.v = 0
buf = os.consume(self.n_bytes) buf = os.consume(self.n_bytes)

View File

@@ -1,5 +1,4 @@
#include "Ciphertext.h" #include "Ciphertext.h"
#include "PPData.h"
#include "P2Data.h" #include "P2Data.h"
#include "Tools/Exceptions.h" #include "Tools/Exceptions.h"
@@ -143,6 +142,5 @@ void Ciphertext::rerandomize(const FHE_PK& pk)
template void mul(Ciphertext& ans,const Plaintext<gfp,FFT_Data,bigint>& a,const Ciphertext& c); template void mul(Ciphertext& ans,const Plaintext<gfp,FFT_Data,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans,const Plaintext<gfp,PPData,bigint>& a,const Ciphertext& c);
template void mul(Ciphertext& ans, const Plaintext<gf2n_short, P2Data, int>& a, template void mul(Ciphertext& ans, const Plaintext<gf2n_short, P2Data, int>& a,
const Ciphertext& c); const Ciphertext& c);

View File

@@ -259,6 +259,3 @@ void BFFT(vector<modp>& ans,const vector<modp>& a,const FFT_Data& FFTD,bool forw
else else
{ throw crash_requested(); } { throw crash_requested(); }
} }

View File

@@ -83,6 +83,8 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD)
for (int r=0; r<2; r++) for (int r=0; r<2; r++)
{ FFT_Iter(b[r],twop,two_root[0],PrD); } { FFT_Iter(b[r],twop,two_root[0],PrD); }
} }
else
throw bad_value();
} }
} }

View File

@@ -2,7 +2,6 @@
#include "FHE_Keys.h" #include "FHE_Keys.h"
#include "Ciphertext.h" #include "Ciphertext.h"
#include "P2Data.h" #include "P2Data.h"
#include "PPData.h"
#include "FFT_Data.h" #include "FFT_Data.h"
#include "Math/modp.hpp" #include "Math/modp.hpp"
@@ -406,29 +405,17 @@ bigint FHE_SK::get_noise(const Ciphertext& c)
} }
#define X(FD) \
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FD>& mess, \
const Random_Coins& rc) const; \
template Ciphertext FHE_PK::encrypt(const Plaintext_<FD>& mess) const; \
template Plaintext_<FD> FHE_SK::decrypt(const Ciphertext& c, \
const FD& FieldD); \
template void FHE_SK::decrypt(Plaintext_<FD>& res, \
const Ciphertext& c) const; \
template void FHE_SK::decrypt_any(Plaintext_<FD>& res, \
const Ciphertext& c); \
template void FHE_SK::check(const FHE_PK& pk, const FD&);
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<FFT_Data>& mess, X(FFT_Data)
const Random_Coins& rc) const; X(P2Data)
template void FHE_PK::encrypt(Ciphertext&, const Plaintext_<P2Data>& mess,
const Random_Coins& rc) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess,
const Random_Coins& rc) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<FFT_Data>& mess) const;
template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
template void FHE_SK::decrypt(Plaintext_<FFT_Data>&, const Ciphertext& c) const;
template void FHE_SK::decrypt(Plaintext_<P2Data>&, const Ciphertext& c) const;
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
const FFT_Data& FieldD);
template Plaintext_<P2Data> FHE_SK::decrypt(const Ciphertext& c,
const P2Data& FieldD);
template void FHE_SK::decrypt_any(Plaintext_<FFT_Data>& res,
const Ciphertext& c);
template void FHE_SK::decrypt_any(Plaintext_<P2Data>& res,
const Ciphertext& c);
template void FHE_SK::check(const FHE_PK& pk, const FFT_Data&);
template void FHE_SK::check(const FHE_PK& pk, const P2Data&);

View File

@@ -119,12 +119,6 @@ const P2Data& FHE_Params::get_plaintext_field_data() const
throw not_implemented(); throw not_implemented();
} }
template<>
const PPData& FHE_Params::get_plaintext_field_data() const
{
throw not_implemented();
}
bigint FHE_Params::get_plaintext_modulus() const bigint FHE_Params::get_plaintext_modulus() const
{ {
return fd.get_prime(); return fd.get_prime();

View File

@@ -248,49 +248,6 @@ matrix inv(const matrix& A)
} }
vector<modp> solve(modp_matrix& A,const Zp_Data& PrD)
{
unsigned int n=A.size();
if ((n+1)!=A[0].size()) { throw invalid_params(); }
modp t,ti;
for (unsigned int r=0; r<n; r++)
{ // Find pivot
unsigned int p=r;
while (isZero(A[p][r],PrD)) { p++; }
// Do pivoting
if (p!=r)
{ for (unsigned int c=0; c<n+1; c++)
{ t=A[p][c]; A[p][c]=A[r][c]; A[r][c]=t; }
}
// Make Lcoeff=1
Inv(ti,A[r][r],PrD);
for (unsigned int c=0; c<n+1; c++)
{ Mul(A[r][c],A[r][c],ti,PrD); }
// Now kill off other entries in this column
for (unsigned int rr=0; rr<n; rr++)
{ if (r!=rr)
{ for (unsigned int c=0; c<n+1; c++)
{ Mul(t,A[rr][c],A[r][r],PrD);
Sub(A[rr][c],A[rr][c],t,PrD);
}
}
}
}
// Sanity check and extract answer
vector<modp> ans;
ans.resize(n);
for (unsigned int i=0; i<n; i++)
{ for (unsigned int j=0; j<n; j++)
{ if (i!=j && !isZero(A[i][j],PrD)) { throw bad_value(); }
else if (!isOne(A[i][j],PrD)) { throw bad_value(); }
}
ans[i]=A[i][n];
}
return ans;
}
/* Input matrix is assumed to have more rows than columns */ /* Input matrix is assumed to have more rows than columns */
void pinv(imatrix& Ai,const imatrix& B) void pinv(imatrix& Ai,const imatrix& B)

View File

@@ -10,7 +10,6 @@ using namespace std;
#include "Tools/BitVector.h" #include "Tools/BitVector.h"
typedef vector< vector<bigint> > matrix; typedef vector< vector<bigint> > matrix;
typedef vector< vector<modp> > modp_matrix;
class imatrix : public vector< BitVector > class imatrix : public vector< BitVector >
{ {
@@ -39,13 +38,6 @@ void print(const imatrix& S);
// requires column operations to create the inverse // requires column operations to create the inverse
matrix inv(const matrix& A); matrix inv(const matrix& A);
// Another special routine for modp matrices.
// Solves
// Ax=b
// Assumes A is unimodular, square and only requires row operations to
// create the inverse. In put is C=(A||b) and the routines alters A
vector<modp> solve(modp_matrix& C,const Zp_Data& PrD);
// Finds a pseudo-inverse of a matrix A modulo 2 // Finds a pseudo-inverse of a matrix A modulo 2
// - Input matrix is assumed to have more rows than columns // - Input matrix is assumed to have more rows than columns
void pinv(imatrix& Ai,const imatrix& A); void pinv(imatrix& Ai,const imatrix& A);

View File

@@ -742,135 +742,3 @@ void load_or_generate(P2Data& P2D, const Ring& R)
P2D.store(R); P2D.store(R);
} }
} }
#ifdef USE_NTL
/*
* Create FHE parameters for a general plaintext modulus p
* Basically this is for general large primes only
*/
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params)
{
cout << "Setting up parameters" << endl;
int lgp=numBits(p);
int mm,idx;
// mm is the minimum value of m we will accept
if (lgp<48)
{ mm=100; // Test case
idx=0;
}
else if (lgp <96)
{ mm=8192;
idx=1;
}
else if (lgp<192)
{ mm=16384;
idx=2;
}
else if (lgp<384)
{ mm=16384;
idx=3;
}
else if (lgp<768)
{ mm=32768;
idx=4;
}
else
{ throw invalid_params(); }
// Now find the small factors of p-1 and their exponents
bigint t=p-1;
vector<long> primes(100),exp(100);
PrimeSeq s;
long pr;
pr=s.next();
int len=0;
while (pr<2*mm)
{ int e=0;
while ((t%pr)==0)
{ e++;
t=t/pr;
}
if (e!=0)
{ primes[len]=pr;
exp[len]=e;
if (len!=0) { cout << " * "; }
cout << pr << "^" << e << flush;
len++;
}
pr=s.next();
}
cout << endl;
// We want to find the best m which divides pr-1, such that
// - 2*m > phi(m) > mm
// - m has the smallest number of factors
vector<int> ee;
ee.resize(len);
for (int i=0; i<len; i++) { ee[i]=0; }
int min_hwt=-1,m=-1,bphi_m=-1,bmx=-1;
bool flag=true;
while (flag)
{ int cand_m=1,hwt=0,mx=0;
for (int i=0; i<len; i++)
{ //cout << ee[i] << " ";
if (ee[i]!=0)
{ hwt++;
for (int j=0; j<ee[i]; j++)
{ cand_m*=primes[i]; }
if (ee[i]>mx) { mx=ee[i]; }
}
}
// Put "if" here to stop searching for things which will never work
if (cand_m>1 && cand_m<4*mm)
{ //cout << " : " << cand_m << " : " << hwt << flush;
int phim=phi_N(cand_m);
//cout << " : " << phim << " : " << mm << endl;
if (phim>mm && phim<3*mm)
{ if (m==-1 || hwt<min_hwt || (hwt==min_hwt && mx<bmx) || (hwt==min_hwt && mx==bmx && phim<bphi_m))
{ m=cand_m;
min_hwt=hwt;
bphi_m=phim;
bmx=mx;
//cout << "\t OK" << endl;
}
}
}
else
{ //cout << endl;
}
int i=0;
ee[i]=ee[i]+1;
while (ee[i]>exp[i] && flag)
{ ee[i]=0;
i++;
if (i==len) { flag=false; i=0; }
else { ee[i]=ee[i]+1; }
}
}
if (m==-1)
{ throw bad_value(); }
cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl;
Parameters parameters(n, lgp, sec);
parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p,params);
int mx=0;
for (int i=0; i<R.phi_m(); i++)
{ if (mx<R.Phi()[i]) { mx=R.Phi()[i]; } }
cout << "Max Coeff = " << mx << endl;
init(R, m, true);
Zp_Data Zp(p);
PPD.init(R,Zp);
gfp::init_field(p);
pr0 = parameters.pr0;
pr1 = parameters.pr1;
}
#endif

View File

@@ -4,7 +4,6 @@
#include "FHE/Ring.h" #include "FHE/Ring.h"
#include "FHE/FFT_Data.h" #include "FHE/FFT_Data.h"
#include "FHE/P2Data.h" #include "FHE/P2Data.h"
#include "FHE/PPData.h"
#include "FHE/FHE_Params.h" #include "FHE/FHE_Params.h"
/* Routines to set up key sizes given the number of players n /* Routines to set up key sizes given the number of players n
@@ -68,10 +67,6 @@ class GF2X;
NTL::GF2X get_F(const Ring& Rg); NTL::GF2X get_F(const Ring& Rg);
// For use when we want p to be a specific value
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
bigint& pr1, int n, int sec, bigint& p, FHE_Params& params);
// generate moduli according to lengths and other parameters // generate moduli according to lengths and other parameters
void generate_moduli(bigint& pr0, bigint& pr1, const int m, void generate_moduli(bigint& pr0, bigint& pr1, const int m,
const bigint p, const int lg2p0, const int lg2p1); const bigint p, const int lg2p0, const int lg2p1);

View File

@@ -1,99 +0,0 @@
#include "FHE/Subroutines.h"
#include "FHE/PPData.h"
#include "FHE/FFT.h"
#include "FHE/Matrix.h"
#include "Math/modp.hpp"
void PPData::init(const Ring& Rg,const Zp_Data& PrD)
{
R=Rg;
prData=PrD;
root=Find_Primitive_Root_m(Rg.m(),Rg.Phi(),PrD);
}
void PPData::to_eval(vector<modp>& elem) const
{
if (elem.size()!= (unsigned) R.phi_m())
{ throw params_mismatch(); }
throw not_implemented();
/*
vector<modp> ans;
ans.resize(R.phi_m());
modp x=root;
for (int i=0; i<R.phi_m(); i++)
{ ans[i]=elem[R.phi_m()-1];
for (int j=1; j<R.phi_m(); j++)
{ Mul(ans[i],ans[i],x,prData);
Add(ans[i],ans[i],elem[R.phi_m()-j-1],prData);
}
Mul(x,x,root,prData);
}
elem=ans;
*/
}
void PPData::from_eval(vector<modp>&) const
{
// avoid warning
throw not_implemented();
/*
modp_matrix A;
int n=phi_m();
A.resize(n, vector<modp>(n+1) );
modp x=root;
for (int i=0; i<n; i++)
{ assignOne(A[0][i],prData);
for (int j=1; j<n; j++)
{ Mul(A[j][i],A[j-1][i],x,prData); }
Mul(x,x,root,prData);
A[i][n]=elem[i];
}
elem=solve(A,prData);
*/
}
void PPData::reset_iteration()
{
pow = 1;
theta = {root, prData};
thetaPow = theta;
}
void PPData::next_iteration()
{
do
{ thetaPow *= (theta);
pow++;
}
while (gcd(pow,m())!=1);
}
gfp PPData::get_evaluation(const vector<bigint>& mess) const
{
// Uses Horner's rule
gfp ans;
ans = mess[mess.size()-1];
gfp coeff;
for (int j=mess.size()-2; j>=0; j--)
{ ans *= (thetaPow);
coeff = mess[j];
ans += (coeff);
}
return ans;
}

View File

@@ -1,61 +0,0 @@
#ifndef _PPData
#define _PPData
#include "Math/modp.h"
#include "Math/Zp_Data.h"
#include "Math/gfpvar.h"
#include "Math/fixint.h"
#include "FHE/Ring.h"
#include "FHE/FFT_Data.h"
/* Class for holding modular arithmetic data wrt the ring
*
* It also holds the ring
*/
class PPData
{
public:
typedef gfp T;
typedef bigint S;
typedef typename FFT_Data::poly_type poly_type;
Ring R;
Zp_Data prData;
modp root; // m'th Root of Unity mod pr
void init(const Ring& Rg,const Zp_Data& PrD);
PPData() { ; }
PPData(const Ring& Rg,const Zp_Data& PrD)
{ init(Rg,PrD); }
const Zp_Data& get_prD() const { return prData; }
const bigint& get_prime() const { return prData.pr; }
int phi_m() const { return R.phi_m(); }
int m() const { return R.m(); }
int num_slots() const { return R.phi_m(); }
int p(int i) const { return R.p(i); }
int p_inv(int i) const { return R.p_inv(i); }
const vector<int>& Phi() const { return R.Phi(); }
// Convert input vector from poly to evaluation representation
// - Uses naive method and not FFT, we only use this rarely in any case
void to_eval(vector<modp>& elem) const;
void from_eval(vector<modp>& elem) const;
// Following are used to iteratively get slots, as we use PPData when
// we do not have an efficient FFT algorithm
gfp thetaPow,theta;
int pow;
void reset_iteration();
void next_iteration();
gfp get_evaluation(const vector<bigint>& mess) const;
};
#endif

View File

@@ -1,7 +1,6 @@
#include "FHE/Plaintext.h" #include "FHE/Plaintext.h"
#include "FHE/Ring_Element.h" #include "FHE/Ring_Element.h"
#include "FHE/PPData.h"
#include "FHE/P2Data.h" #include "FHE/P2Data.h"
#include "FHE/Rq_Element.h" #include "FHE/Rq_Element.h"
#include "FHE_Keys.h" #include "FHE_Keys.h"
@@ -85,39 +84,6 @@ void Plaintext<gfp,FFT_Data,bigint>::to_poly() const
} }
template<>
void Plaintext<gfp,PPData,bigint>::from_poly() const
{
if (type!=Polynomial) { return; }
vector<modp> aa((*Field_Data).phi_m());
for (unsigned int i=0; i<aa.size(); i++)
{ to_modp(aa[i], bigint::tmp = b[i], (*Field_Data).prData); }
(*Field_Data).to_eval(aa);
a.resize(num_slots());
for (unsigned int i=0; i<aa.size(); i++)
a[i] = {aa[i], Field_Data->get_prD()};
type=Both;
}
template<>
void Plaintext<gfp,PPData,bigint>::to_poly() const
{
if (type!=Evaluation) { return; }
cout << "This is VERY inefficient to convert a plaintext to poly representation" << endl;
vector<modp> bb((*Field_Data).phi_m());
for (unsigned int i=0; i<bb.size(); i++)
{ bb[i]=a[i].get(); }
(*Field_Data).from_eval(bb);
for (unsigned int i=0; i<bb.size(); i++)
{
to_bigint(bigint::tmp,bb[i],(*Field_Data).prData);
b[i] = bigint::tmp;
}
type=Both;
}
template<> template<>
void Plaintext<gf2n_short,P2Data,int>::from_poly() const void Plaintext<gf2n_short,P2Data,int>::from_poly() const
@@ -385,34 +351,6 @@ void add(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
} }
template<>
void add(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
const Plaintext<gfp,PPData,bigint>& y)
{
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
if (x.type==Both && y.type!=Both) { z.type=y.type; }
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
else if (x.type!=y.type) { throw rep_mismatch(); }
else { z.type=x.type; }
if (z.type!=Polynomial)
{
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] + y.a[i]); }
}
if (z.type!=Evaluation)
{ for (unsigned int i=0; i<z.b.size(); i++)
{ z.b[i]=x.b[i]+y.b[i];
if (z.b[i]>(*z.Field_Data).get_prime())
{ z.b[i]-=(*z.Field_Data).get_prime(); }
}
}
}
template<> template<>
@@ -475,36 +413,6 @@ void sub(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data,bigint>&
template<>
void sub(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,
const Plaintext<gfp,PPData,bigint>& y)
{
if (z.Field_Data!=x.Field_Data) { throw field_mismatch(); }
if (z.Field_Data!=y.Field_Data) { throw field_mismatch(); }
if (x.type==Both && y.type!=Both) { z.type=y.type; }
else if (y.type==Both && x.type!=Both) { z.type=x.type; }
else if (x.type!=y.type) { throw rep_mismatch(); }
else { z.type=x.type; }
z.allocate();
if (z.type!=Polynomial)
{
z.a.resize(z.num_slots());
for (unsigned int i=0; i<z.a.size(); i++)
{ z.a[i] = (x.a[i] - y.a[i]); }
}
if (z.type!=Evaluation)
{ for (unsigned int i=0; i<z.b.size(); i++)
{ z.b[i]=x.b[i]-y.b[i];
if (z.b[i]<0)
{ z.b[i]+=(*z.Field_Data).get_prime(); }
}
}
}
template<> template<>
@@ -572,23 +480,6 @@ void Plaintext<gfp,FFT_Data,bigint>::negate()
} }
} }
template<>
void Plaintext<gfp,PPData,bigint>::negate()
{
if (type!=Polynomial)
{
a.resize(num_slots());
for (unsigned int i=0; i<a.size(); i++)
{ a[i].negate(); }
}
if (type!=Evaluation)
{ for (unsigned int i=0; i<b.size(); i++)
{ if (b[i]!=0)
{ b[i]=(*Field_Data).get_prime()-b[i]; }
}
}
}
template<> template<>
@@ -731,12 +622,6 @@ template void mul(Plaintext<gfp,FFT_Data,bigint>& z,const Plaintext<gfp,FFT_Data
template class Plaintext<gfp,PPData,bigint>;
template void mul(Plaintext<gfp,PPData,bigint>& z,const Plaintext<gfp,PPData,bigint>& x,const Plaintext<gfp,PPData,bigint>& y);
template class Plaintext<gf2n_short,P2Data,int>; template class Plaintext<gf2n_short,P2Data,int>;
template void mul(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,int>& x,const Plaintext<gf2n_short,P2Data,int>& y); template void mul(Plaintext<gf2n_short,P2Data,int>& z,const Plaintext<gf2n_short,P2Data,int>& x,const Plaintext<gf2n_short,P2Data,int>& y);

View File

@@ -274,7 +274,7 @@ void EncCommit<T,FD,S>::Create_More() const
(*P).Broadcast_Receive(ctx_Delta); (*P).Broadcast_Receive(ctx_Delta);
// Output the ctx_Delta to a file // Output the ctx_Delta to a file
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread); snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
ofstream outf(filename); ofstream outf(filename);
for (int j=0; j<(*P).num_players(); j++) for (int j=0; j<(*P).num_players(); j++)
{ {
@@ -308,7 +308,7 @@ void EncCommit<T,FD,S>::Create_More() const
octetStream occ,ctx_D; octetStream occ,ctx_D;
for (int i=0; i<2*TT; i++) for (int i=0; i<2*TT; i++)
{ if (open[i]==1) { if (open[i]==1)
{ sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread); { snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,i,thread);
ifstream inpf(filename); ifstream inpf(filename);
for (int j=0; j<(*P).num_players(); j++) for (int j=0; j<(*P).num_players(); j++)
{ {
@@ -386,7 +386,7 @@ void EncCommit<T,FD,S>::Create_More() const
Ciphertext enc1(params),enc2(params),eDelta(params); Ciphertext enc1(params),enc2(params),eDelta(params);
octetStream oe1,oe2; octetStream oe1,oe2;
sprintf(filename,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread); snprintf(filename,1024,"%sctx_Delta-%d-%d-%d",file_prefix,my_num,index[b*i+j],thread);
ifstream inpf(filename); ifstream inpf(filename);
for (int k=0; k<(*P).num_players(); k++) for (int k=0; k<(*P).num_players(); k++)
{ {

View File

@@ -26,6 +26,7 @@ public:
void set_protocol(DealerSecret::Protocol& protocol) void set_protocol(DealerSecret::Protocol& protocol)
{ {
P = &protocol.P; P = &protocol.P;
BufferPrep<DealerSecret>::P = P;
} }
void buffer_triples() void buffer_triples()

View File

@@ -183,6 +183,8 @@ public:
NoShare operator-(const NoShare&) const { fail(); return {}; } NoShare operator-(const NoShare&) const { fail(); return {}; }
NoShare operator*(const NoValue&) const { fail(); return {}; } NoShare operator*(const NoValue&) const { fail(); return {}; }
NoShare operator^(const NoShare&) const { fail(); return {}; }
NoShare operator&(int) const { fail(); return {}; } NoShare operator&(int) const { fail(); return {}; }
NoShare operator>>(int) const { fail(); return {}; } NoShare operator>>(int) const { fail(); return {}; }

View File

@@ -123,7 +123,9 @@ BreakType Program::execute(Processor<T>& Proc, U& dynamic_memory,
} }
time++; time++;
#ifdef DEBUG_COMPLEXITY #ifdef DEBUG_COMPLEXITY
cout << "complexity at " << time << ": " << Proc.complexity << endl; cout << T::part_type::name() << " complexity at " << time << ": " <<
Proc.complexity << " after " << hex <<
instruction.get_opcode() << dec << endl;
#endif #endif
} }
while (Proc.complexity < (size_t) OnlineOptions::singleton.batch_size); while (Proc.complexity < (size_t) OnlineOptions::singleton.batch_size);

View File

@@ -39,6 +39,7 @@ void RepPrep<T>::set_protocol(typename T::Protocol& protocol)
return; return;
this->protocol = new ReplicatedBase(protocol.P); this->protocol = new ReplicatedBase(protocol.P);
this->P = &protocol.P;
} }
template<class T> template<class T>

View File

@@ -89,7 +89,7 @@ void Secret<T>::random(int n_bits, int128 share)
{ {
(void)share; (void)share;
if (n_bits > 128) if (n_bits > 128)
throw not_implemented(); throw runtime_error("too many bits");
resize_regs(n_bits); resize_regs(n_bits);
for (int i = 0; i < n_bits; i++) for (int i = 0; i < n_bits; i++)
get_reg(i).random(); get_reg(i).random();

View File

@@ -37,6 +37,7 @@ void SemiPrep::set_protocol(SemiSecret::Protocol& protocol)
protocol.P.N, -1, OnlineOptions::singleton.batch_size, protocol.P.N, -1, OnlineOptions::singleton.batch_size,
1, params, {}, &protocol.P); 1, params, {}, &protocol.P);
triple_generator->multi_threaded = false; triple_generator->multi_threaded = false;
this->P = &protocol.P;
} }
void SemiPrep::buffer_triples() void SemiPrep::buffer_triples()

View File

@@ -103,9 +103,7 @@ void ThreadMaster<T>::run()
machine.print_timers(); machine.print_timers();
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl; machine.print_comm(*P, stats);
machine.print_global_comm(*P, stats);
delete P; delete P;
} }

View File

@@ -105,6 +105,11 @@ public:
*this = a + b; *this = a + b;
} }
This operator^(const This& other) const
{
return *this + other;
}
This& operator^=(const This& other) This& operator^=(const This& other)
{ {
*this += other; *this += other;

View File

@@ -146,6 +146,7 @@
X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \ X(THRESHOLD, I0 = T::threshold(Thread<T>::s().P->num_players())) \
X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \ X(PLAYERID, I0 = Thread<T>::s().P->my_num()) \
X(CRASH, if (I0.get()) throw crash_requested()) \ X(CRASH, if (I0.get()) throw crash_requested()) \
X(ACTIVE, ) \
#define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS #define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS

View File

@@ -365,11 +365,11 @@ void OTMachine::run()
{ {
BitVector receiver_output, sender_output; BitVector receiver_output, sender_output;
char filename[1024]; char filename[1024];
sprintf(filename, RECEIVER_INPUT, my_num); snprintf(filename, 1024, RECEIVER_INPUT, my_num);
ofstream outf(filename); ofstream outf(filename);
receiverInput.output(outf, false); receiverInput.output(outf, false);
outf.close(); outf.close();
sprintf(filename, RECEIVER_OUTPUT, my_num); snprintf(filename, 1024, RECEIVER_OUTPUT, my_num);
outf.open(filename); outf.open(filename);
for (unsigned int i = 0; i < nOTs; i++) for (unsigned int i = 0; i < nOTs; i++)
{ {
@@ -380,7 +380,7 @@ void OTMachine::run()
for (int i = 0; i < 2; i++) for (int i = 0; i < 2; i++)
{ {
sprintf(filename, SENDER_OUTPUT, my_num, i); snprintf(filename,1024, SENDER_OUTPUT, my_num, i);
outf.open(filename); outf.open(filename);
for (int j = 0; j < nOTs; j++) for (int j = 0; j < nOTs; j++)
{ {

View File

@@ -116,7 +116,7 @@ mascot: mascot-party.x spdz2k mama-party.x
ifeq ($(OS), Darwin) ifeq ($(OS), Darwin)
setup: mac-setup setup: mac-setup
else else
setup: boost mpir linux-machine-setup setup: boost linux-machine-setup
endif endif
tldr: setup tldr: setup
@@ -297,27 +297,6 @@ deps/SimplestOT_C/ref10/Makefile:
Programs/Circuits: Programs/Circuits:
git submodule update --init Programs/Circuits git submodule update --init Programs/Circuits
.PHONY: mpir-setup mpir-global
mpir-setup: deps/mpir/Makefile
deps/mpir/Makefile:
git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir
cd deps/mpir; \
autoreconf -i; \
autoreconf -i
- $(MAKE) -C deps/mpir clean
mpir-global: mpir-setup
cd deps/mpir; \
./configure --enable-cxx;
$(MAKE) -C deps/mpir
sudo $(MAKE) -C deps/mpir install
mpir: local/lib/libmpirxx.so
local/lib/libmpirxx.so: deps/mpir/Makefile
cd deps/mpir; \
./configure --enable-cxx --prefix=$(CURDIR)/local
$(MAKE) -C deps/mpir install
deps/libOTe/libOTe: deps/libOTe/libOTe:
git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe
boost: deps/libOTe/libOTe boost: deps/libOTe/libOTe
@@ -369,26 +348,16 @@ cmake:
./bootstrap --parallel=8 --prefix=../local && make && make install ./bootstrap --parallel=8 --prefix=../local && make && make install
mac-setup: mac-machine-setup mac-setup: mac-machine-setup
brew install openssl boost libsodium mpir yasm ntl cmake brew install openssl boost libsodium gmp yasm ntl cmake
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include >> CONFIG.mine
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib >> CONFIG.mine
# -echo USE_NTL = 1 >> CONFIG.mine
ifeq ($(ARM), 1)
mac-machine-setup:
-echo ARCH = >> CONFIG.mine
linux-machine-setup: linux-machine-setup:
-echo ARCH = -march=armv8.2-a+crypto >> CONFIG.mine
else
mac-machine-setup: mac-machine-setup:
linux-machine-setup:
endif
deps/simde/simde: deps/simde/simde:
git submodule update --init deps/simde || git clone https://github.com/simd-everywhere/simde deps/simde git submodule update --init deps/simde || git clone https://github.com/simd-everywhere/simde deps/simde
clean-deps: clean-deps:
-rm -rf local deps/libOTe/out -rm -rf local/lib/liblibOTe.* deps/libOTe/out
clean: clean-deps clean: clean-deps
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so -rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so

View File

@@ -17,6 +17,10 @@ using namespace std;
#include "ValueInterface.h" #include "ValueInterface.h"
#include "gf2nlong.h" #include "gf2nlong.h"
// Fix false warning
#if __GNUC__ == 10
#pragma GCC diagnostic ignored "-Wstringop-overflow"
#endif
// Functionality shared between integers and bit vectors // Functionality shared between integers and bit vectors
template<class T> template<class T>
@@ -39,6 +43,8 @@ public:
static bool allows(Dtype type) { return type <= DATA_BIT; } static bool allows(Dtype type) { return type <= DATA_BIT; }
static void check_setup(const string&) {}
IntBase() { a = 0; } IntBase() { a = 0; }
IntBase(T a) : a(a) {} IntBase(T a) : a(a) {}

View File

@@ -160,13 +160,13 @@ void check_setup(string dir, bigint pr)
} }
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
const string& type_short) const string& type_short, bool create)
{ {
string res = prep_dir + "/" + to_string(nparties) + "-" + type_short; string res = prep_dir + "/" + to_string(nparties) + "-" + type_short;
if (log2mod > 1) if (log2mod > 1)
res += "-" + to_string(log2mod); res += "-" + to_string(log2mod);
res += "/"; res += "/";
if (mkdir_p(res.c_str()) < 0) if (create and mkdir_p(res.c_str()) < 0)
throw file_error("cannot create " + res); throw file_error("cannot create " + res);
return res; return res;
} }

View File

@@ -38,26 +38,28 @@ bigint generate_prime(int lgp, int m);
int default_m(int& lgp, int& idx); int default_m(int& lgp, int& idx);
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
const string& type_short); const string& type_short, bool create = false);
template<class T> template<class T>
string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod) string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod,
bool create = false)
{ {
if (T::clear::length() > 1) if (T::clear::length() > 1)
log2mod = T::clear::length(); log2mod = T::clear::length();
return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short()); return get_prep_sub_dir(prep_dir, nparties, log2mod, T::type_short(), create);
} }
template<class T> template<class T>
string get_prep_sub_dir(const string& prep_dir, int nparties) string get_prep_sub_dir(const string& prep_dir, int nparties, bool create =
false)
{ {
return get_prep_sub_dir<T>(prep_dir, nparties, T::clear::length()); return get_prep_sub_dir<T>(prep_dir, nparties, T::clear::length(), create);
} }
template<class T> template<class T>
string get_prep_sub_dir(int nparties) string get_prep_sub_dir(int nparties, bool create = false)
{ {
return get_prep_sub_dir<T>(PREP_DIR, nparties); return get_prep_sub_dir<T>(PREP_DIR, nparties, create);
} }
template<class T> template<class T>

18
Math/ValueInterface.cpp Normal file
View File

@@ -0,0 +1,18 @@
/*
* ValueInterface.cpp
*
*/
#include "bigint.h"
#include "ValueInterface.h"
#include <sys/stat.h>
void ValueInterface::check_setup(const string& directory)
{
struct stat sb;
if (stat(directory.c_str(), &sb) != 0)
throw runtime_error(directory + " does not exist");
if (not (sb.st_mode & S_IFDIR))
throw runtime_error(directory + " is not a directory");
}

View File

@@ -7,6 +7,7 @@
#define MATH_VALUEINTERFACE_H_ #define MATH_VALUEINTERFACE_H_
#include "Tools/Exceptions.h" #include "Tools/Exceptions.h"
#include "Math/Setup.h"
class OnlineOptions; class OnlineOptions;
class bigint; class bigint;
@@ -31,9 +32,10 @@ public:
template<class T> template<class T>
static void generate_setup(string, int, int) {} static void generate_setup(string, int, int) {}
template<class T> template<class T>
static void write_setup(int) {} static void write_setup(int nplayers) { get_prep_sub_dir<T>(nplayers, true); }
static void write_setup(string) {} static void write_setup(string) {}
static void check_setup(string) {} static void check_setup(const string& directory);
static const char* fake_opts() { return ""; }
static bigint pr() { throw runtime_error("no prime modulus"); } static bigint pr() { throw runtime_error("no prime modulus"); }

View File

@@ -6,7 +6,7 @@
#ifndef MATH_Z2K_H_ #ifndef MATH_Z2K_H_
#define MATH_Z2K_H_ #define MATH_Z2K_H_
#include <mpirxx.h> #include <gmpxx.h>
#include <string> #include <string>
using namespace std; using namespace std;
@@ -74,6 +74,8 @@ public:
static Z2 power_of_two(bool bit, int exp) { return Z2(bit) << exp; } static Z2 power_of_two(bool bit, int exp) { return Z2(bit) << exp; }
static string fake_opts() { return " -lgp " + to_string(K); }
typedef Z2 next; typedef Z2 next;
typedef Z2 Scalar; typedef Z2 Scalar;

View File

@@ -53,7 +53,7 @@ void Zp_Data::init(const bigint& p,bool mont)
mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t())); mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t()));
if (sizeof(unsigned long)!=sizeof(mp_limb_t)) if (sizeof(unsigned long)!=sizeof(mp_limb_t))
{ cout << "The underlying types of MPIR mean we cannot use our Montgomery code" << endl; { cout << "The underlying types of GMP mean we cannot use our Montgomery code" << endl;
throw not_implemented(); throw not_implemented();
} }
} }
@@ -194,3 +194,37 @@ bool Zp_Data::operator==(const Zp_Data& other) const
{ {
return not (*this != other); return not (*this != other);
} }
void Zp_Data::get_shanks_parameters(bigint& y, bigint& q_half, int& r) const
{
if (shanks_y == 0)
{
auto& p = pr;
bigint n, q, yy, xx, temp;
// Find n such that (n/p)=-1
int leg = 1;
gmp_randclass Gen(gmp_randinit_default);
Gen.seed(0);
while (leg != -1)
{
n = Gen.get_z_range(p);
leg = mpz_legendre(n.get_mpz_t(), p.get_mpz_t());
}
// Split p-1 = 2^e q
q = p - 1;
int e = 0;
while (mpz_even_p(q.get_mpz_t()))
{
e++;
q = q / 2;
}
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
shanks_r = e;
mpz_powm(shanks_y.get_mpz_t(), n.get_mpz_t(), q.get_mpz_t(), p.get_mpz_t());
shanks_q_half = (q - 1) / 2;
}
y = shanks_y;
q_half = shanks_q_half;
r = shanks_r;
}

View File

@@ -38,6 +38,8 @@ class Zp_Data
int t; // More Montgomery data int t; // More Montgomery data
mp_limb_t overhang; mp_limb_t overhang;
Lock lock; Lock lock;
mutable bigint shanks_y, shanks_q_half;
mutable int shanks_r;
template <int T> template <int T>
void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;
@@ -89,6 +91,8 @@ class Zp_Data
bool operator!=(const Zp_Data& other) const; bool operator!=(const Zp_Data& other) const;
bool operator==(const Zp_Data& other) const; bool operator==(const Zp_Data& other) const;
void get_shanks_parameters(bigint& y, bigint& q_half, int& r) const;
template<int L> friend void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD); template<int L> friend void to_modp(modp_<L>& ans,int x,const Zp_Data& ZpD);
template<int L> friend void to_modp(modp_<L>& ans,const mpz_class& x,const Zp_Data& ZpD); template<int L> friend void to_modp(modp_<L>& ans,const mpz_class& x,const Zp_Data& ZpD);

View File

@@ -10,76 +10,10 @@
#include "bigint.hpp" #include "bigint.hpp"
class gmp_random
{
public:
gmp_randclass Gen;
gmp_random() : Gen(gmp_randinit_default)
{
Gen.seed(0);
}
};
thread_local bigint bigint::tmp = 0; thread_local bigint bigint::tmp = 0;
thread_local bigint bigint::tmp2 = 0; thread_local bigint bigint::tmp2 = 0;
thread_local gmp_random bigint::random; thread_local gmp_random bigint::random;
bigint sqrRootMod(const bigint& a,const bigint& p)
{
bigint ans;
if (a==0) { ans=0; return ans; }
if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1)
throw runtime_error("cannot compute square root of non-square");
if (mpz_tstbit(p.get_mpz_t(),1)==1)
{ // First do case with p=3 mod 4
bigint exp=(p+1)/4;
mpz_powm(ans.get_mpz_t(),a.get_mpz_t(),exp.get_mpz_t(),p.get_mpz_t());
}
else
{ // Shanks algorithm
bigint x,y,n,q,t,b,temp;
// Find n such that (n/p)=-1
int leg=1;
while (leg!=-1)
{ n=bigint::random.Gen.get_z_range(p);
leg=mpz_legendre(n.get_mpz_t(),p.get_mpz_t());
}
// Split p-1 = 2^e q
q=p-1;
int e=0;
while (mpz_even_p(q.get_mpz_t()))
{ e++; q=q/2; }
// y=n^q mod p, x=a^((q-1)/2) mod p, r=e
int r=e;
mpz_powm(y.get_mpz_t(),n.get_mpz_t(),q.get_mpz_t(),p.get_mpz_t());
temp=(q-1)/2;
mpz_powm(x.get_mpz_t(),a.get_mpz_t(),temp.get_mpz_t(),p.get_mpz_t());
// b=a*x^2 mod p, x=a*x mod p
b=(a*x*x)%p;
x=(a*x)%p;
// While b!=1 do
while (b!=1)
{ // Find smallest m such that b^(2^m)=1 mod p
int m=1;
temp=(b*b)%p;
while (temp!=1)
{ temp=(temp*temp)%p; m++; }
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
t=y;
for (int i=0; i<r-m-1; i++)
{ t=(t*t)%p; }
y=(t*t)%p;
r=m;
// x=x*t mod p, b=b*y mod p
x=(x*t)%p;
b=(b*y)%p;
}
ans=x;
}
return ans;
}
bigint powerMod(const bigint& x,const bigint& e,const bigint& p) bigint powerMod(const bigint& x,const bigint& e,const bigint& p)
{ {

View File

@@ -5,7 +5,7 @@
using namespace std; using namespace std;
#include <stddef.h> #include <stddef.h>
#include <mpirxx.h> #include <gmpxx.h>
#include "Tools/Exceptions.h" #include "Tools/Exceptions.h"
#include "Tools/int.h" #include "Tools/int.h"
@@ -39,7 +39,7 @@ namespace GC
/** /**
* Type for arbitrarily large integers. * Type for arbitrarily large integers.
* This is a sub-class of ``mpz_class`` from MPIR. As such, it implements * This is a sub-class of ``mpz_class`` from GMP. As such, it implements
* all integers operations and input/output via C++ streams. In addition, * all integers operations and input/output via C++ streams. In addition,
* the ``get_ui()`` member function allows retrieving the least significant * the ``get_ui()`` member function allows retrieving the least significant
* 64 bits. * 64 bits.
@@ -139,8 +139,6 @@ public:
void inline_mpn_zero(mp_limb_t* x, mp_size_t size); void inline_mpn_zero(mp_limb_t* x, mp_size_t size);
void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size); void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t size);
#include "Z2k.h"
inline bigint& bigint::operator=(int n) inline bigint& bigint::operator=(int n)
{ {
@@ -281,11 +279,7 @@ inline int numBytes(const bigint& m)
inline int probPrime(const bigint& x) inline int probPrime(const bigint& x)
{ {
gmp_randstate_t rand_state; int ans = mpz_probab_prime_p(x.get_mpz_t(), max(40, DEFAULT_SECURITY) / 2);
gmp_randinit_default(rand_state);
int ans = mpz_probable_prime_p(x.get_mpz_t(), rand_state,
max(40, DEFAULT_SECURITY), 0);
gmp_randclear(rand_state);
return ans; return ans;
} }
@@ -318,7 +312,8 @@ inline int isOdd(const bigint& x)
} }
bigint sqrRootMod(const bigint& x,const bigint& p); template<class T>
bigint sqrRootMod(const T& x);
bigint powerMod(const bigint& x,const bigint& e,const bigint& p); bigint powerMod(const bigint& x,const bigint& e,const bigint& p);

View File

@@ -26,7 +26,7 @@ bigint& bigint::from_signed(const T& other)
template<class T> template<class T>
mpf_class bigint::get_float(T v, T p, T z, T s) mpf_class bigint::get_float(T v, T p, T z, T s)
{ {
// MPIR can't handle more precision in exponent // GMP can't handle more precision in exponent
Integer exp = Integer(p, 31).get(); Integer exp = Integer(p, 31).get();
bigint tmp; bigint tmp;
tmp.from_signed(v); tmp.from_signed(v);
@@ -59,4 +59,76 @@ void bigint::output_float(U& o, const mpf_class& x, T nan)
o << "NaN"; o << "NaN";
} }
class gmp_random
{
public:
gmp_randclass Gen;
gmp_random() : Gen(gmp_randinit_default)
{
Gen.seed(0);
}
};
template<class T>
bigint sqrRootMod(const T& aa)
{
bigint a = aa;
bigint p = T::pr();
bigint ans;
if (a == 0)
{
ans = 0;
return ans;
}
if (mpz_legendre(a.get_mpz_t(), p.get_mpz_t()) != 1)
throw runtime_error("cannot compute square root of non-square");
if (mpz_tstbit(p.get_mpz_t(), 1) == 1)
{
// First do case with p=3 mod 4
bigint exp = (p + 1) / 4;
mpz_powm(ans.get_mpz_t(), a.get_mpz_t(), exp.get_mpz_t(),
p.get_mpz_t());
}
else
{
// Shanks algorithm
bigint n, q, yy, xx, temp;
int r;
T::get_ZpD().get_shanks_parameters(yy, temp, r);
mpz_powm(xx.get_mpz_t(), a.get_mpz_t(), temp.get_mpz_t(), p.get_mpz_t());
// b=a*x^2 mod p, x=a*x mod p
T x = xx;
T b = (aa * x * x);
x = (aa * x);
T y = yy;
// While b!=1 do
while (b != 1)
{
// Find smallest m such that b^(2^m)=1 mod p
int m = 1;
T temp = (b * b);
while (temp != 1)
{
temp = (temp * temp);
m++;
}
// t=y^(2^(r-m-1)) mod p, y=t^2, r=m
T t = y;
for (int i = 0; i < r - m - 1; i++)
{
t = (t * t);
}
y = (t * t);
r = m;
// x=x*t mod p, b=b*y mod p
x = (x * t);
b = (b * y);
}
ans = x;
}
return ans;
}
#endif /* MATH_BIGINT_HPP_ */ #endif /* MATH_BIGINT_HPP_ */

View File

@@ -17,6 +17,7 @@ class gf2n_short;
class P2Data; class P2Data;
class Bit; class Bit;
class int128; class int128;
template<class T> class IntBase;
template<class T> class Square; template<class T> class Square;
typedef Square<gf2n_short> gf2n_short_square; typedef Square<gf2n_short> gf2n_short_square;
@@ -88,6 +89,8 @@ protected:
static string options(); static string options();
static string fake_opts() { return " -lg2 " + to_string(length()); }
static const true_type invertible; static const true_type invertible;
static const true_type characteristic_two; static const true_type characteristic_two;

View File

@@ -154,6 +154,8 @@ class gf2n_long : public gf2n_<int128>
gf2n_long(int g) : gf2n_long(int128(unsigned(g))) {} gf2n_long(int g) : gf2n_long(int128(unsigned(g))) {}
template<class T> template<class T>
gf2n_long(IntBase<T> g) : super(g.get()) {} gf2n_long(IntBase<T> g) : super(g.get()) {}
template<class T>
gf2n_long(const gf2n_<T>& a) : super(int128(a.get())) {}
}; };
#if defined(__aarch64__) && defined(__clang__) #if defined(__aarch64__) && defined(__clang__)

View File

@@ -105,6 +105,7 @@ class gfp_ : public ValueInterface
static void write_setup(string dir) static void write_setup(string dir)
{ write_online_setup(dir, pr()); } { write_online_setup(dir, pr()); }
static void check_setup(string dir); static void check_setup(string dir);
static string fake_opts() { return " -lgp " + to_string(length()); }
/** /**
* Get the prime modulus * Get the prime modulus
@@ -314,6 +315,8 @@ gfp_<X, L>::gfp_(long x)
{ {
if (x == 0) if (x == 0)
assign_zero(); assign_zero();
else if (x == 1)
assign_one();
else else
*this = bigint::tmp = x; *this = bigint::tmp = x;
} }

View File

@@ -146,8 +146,7 @@ gfp_<X, L> gfp_<X, L>::sqrRoot()
{ {
// Temp move to bigint so as to call sqrRootMod // Temp move to bigint so as to call sqrRootMod
bigint ti; bigint ti;
to_bigint(ti, *this); ti = sqrRootMod(*this);
ti = sqrRootMod(ti, ZpD.pr);
if (!isOdd(ti)) if (!isOdd(ti))
ti = ZpD.pr - ti; ti = ZpD.pr - ti;
gfp_<X, L> temp; gfp_<X, L> temp;

View File

@@ -312,8 +312,8 @@ gfpvar_<X, L> gfpvar_<X, L>::invert() const
template<int X, int L> template<int X, int L>
gfpvar_<X, L> gfpvar_<X, L>::sqrRoot() const gfpvar_<X, L> gfpvar_<X, L>::sqrRoot() const
{ {
bigint ti = *this; bigint ti;
ti = sqrRootMod(ti, ZpD.pr); ti = sqrRootMod(*this);
if (!isOdd(ti)) if (!isOdd(ti))
ti = ZpD.pr - ti; ti = ZpD.pr - ti;
return ti; return ti;

View File

@@ -81,6 +81,7 @@ public:
{ {
write_setup(get_prep_sub_dir<T>(nplayers)); write_setup(get_prep_sub_dir<T>(nplayers));
} }
static string fake_opts() { return " -lgp " + to_string(length()); }
gfpvar_(); gfpvar_();
gfpvar_(int other); gfpvar_(int other);

View File

@@ -2,7 +2,7 @@
#define _Modp #define _Modp
/* /*
* Currently we only support an MPIR based implementation. * Currently we only support an GMP based implementation.
* *
* What ever is type-def'd to bigint is assumed to have * What ever is type-def'd to bigint is assumed to have
* operator overloading for all standard operators, has * operator overloading for all standard operators, has

View File

@@ -6,7 +6,7 @@
#ifndef MATH_MPN_FIXED_H_ #ifndef MATH_MPN_FIXED_H_
#define MATH_MPN_FIXED_H_ #define MATH_MPN_FIXED_H_
#include <mpir.h> #include <gmp.h>
#include <string.h> #include <string.h>
#include <assert.h> #include <assert.h>

View File

@@ -3,7 +3,7 @@
* *
*/ */
#include <mpirxx.h> #include <gmpxx.h>
#include "OT/BitMatrix.h" #include "OT/BitMatrix.h"
#include "Tools/random.h" #include "Tools/random.h"

View File

@@ -78,7 +78,7 @@ void Names::init(int player, int pnb, const string& filename, int nplayers_wante
} }
} }
if (nplayers_wanted > 0 and nplayers_wanted != nplayers) if (nplayers_wanted > 0 and nplayers_wanted != nplayers)
throw runtime_error("not enought hosts in HOSTS"); throw runtime_error("not enough hosts in " + filename);
#ifdef DEBUG_NETWORKING #ifdef DEBUG_NETWORKING
cerr << "Got list of " << nplayers << " players from file: " << endl; cerr << "Got list of " << nplayers << " players from file: " << endl;
for (unsigned int i = 0; i < names.size(); i++) for (unsigned int i = 0; i < names.size(); i++)
@@ -324,7 +324,9 @@ void PlainPlayer::setup_sockets(const vector<string>& names,
template<class T> template<class T>
void MultiPlayer<T>::send_long(int i, long a) const void MultiPlayer<T>::send_long(int i, long a) const
{ {
TimeScope ts(comm_stats["Sending by number"].add(sizeof(long)));
send(sockets[i], (octet*)&a, sizeof(long)); send(sockets[i], (octet*)&a, sizeof(long));
sent += sizeof(long);
} }
template<class T> template<class T>
@@ -716,7 +718,7 @@ size_t VirtualTwoPartyPlayer::send(const PlayerBuffer& buffer, bool block) const
{ {
auto sent = P.send_no_stats(other_player, buffer, block); auto sent = P.send_no_stats(other_player, buffer, block);
lock.lock(); lock.lock();
comm_stats["Sending one-to-one"].add(sent); comm_stats.add_to_last_round("Sending one-to-one", sent);
comm_stats.sent += sent; comm_stats.sent += sent;
lock.unlock(); lock.unlock();
return sent; return sent;
@@ -726,7 +728,7 @@ size_t VirtualTwoPartyPlayer::recv(const PlayerBuffer& buffer, bool block) const
{ {
auto received = P.recv_no_stats(other_player, buffer, block); auto received = P.recv_no_stats(other_player, buffer, block);
lock.lock(); lock.lock();
comm_stats["Receiving one-to-one"].add(received); comm_stats.add_to_last_round("Receiving one-to-one", received);
lock.unlock(); lock.unlock();
return received; return received;
} }
@@ -805,6 +807,17 @@ void NamedCommStats::reset()
sent = 0; sent = 0;
} }
Timer& NamedCommStats::add_to_last_round(const string& name, size_t length)
{
if (name == last)
return (*this)[name].add_length_only(length);
else
{
last = name;
return (*this)[name].add(length);
}
}
void PlayerBase::reset_stats() void PlayerBase::reset_stats()
{ {
comm_stats.reset(); comm_stats.reset();

View File

@@ -136,11 +136,15 @@ struct CommStats
CommStats() : data(0), rounds(0) {} CommStats() : data(0), rounds(0) {}
Timer& add(size_t length) Timer& add(size_t length)
{ {
rounds++;
return add_length_only(length);
}
Timer& add_length_only(size_t length)
{
#ifdef VERBOSE_COMM #ifdef VERBOSE_COMM
cout << "add " << length << endl; cout << "add " << length << endl;
#endif #endif
data += length; data += length;
rounds++;
return timer; return timer;
} }
Timer& add(const octetStream& os) { return add(os.get_length()); } Timer& add(const octetStream& os) { return add(os.get_length()); }
@@ -153,6 +157,7 @@ class NamedCommStats : public map<string, CommStats>
{ {
public: public:
size_t sent; size_t sent;
string last;
NamedCommStats(); NamedCommStats();
@@ -161,6 +166,7 @@ public:
NamedCommStats operator-(const NamedCommStats& other) const; NamedCommStats operator-(const NamedCommStats& other) const;
void print(bool newline = false); void print(bool newline = false);
void reset(); void reset();
Timer& add_to_last_round(const string& name, size_t length);
#ifdef VERBOSE_COMM #ifdef VERBOSE_COMM
CommStats& operator[](const string& name) CommStats& operator[](const string& name)
{ {

View File

@@ -134,7 +134,7 @@ void close_client_socket(int socket)
if (close(socket)) if (close(socket))
{ {
char tmp[1000]; char tmp[1000];
sprintf(tmp, "close(%d)", socket); snprintf(tmp, 1000, "close(%d)", socket);
error(tmp); error(tmp);
} }
} }

View File

@@ -126,7 +126,7 @@ void BaseMachine::time()
void BaseMachine::start(int n) void BaseMachine::start(int n)
{ {
cout << "Starting timer " << n << " at " << timer[n].elapsed() cout << "Starting timer " << n << " at " << timer[n].elapsed()
<< " (" << timer[n].mb_sent() << " MB)" << " (" << timer[n] << ")"
<< " after " << timer[n].idle() << endl; << " after " << timer[n].idle() << endl;
timer[n].start(total_comm()); timer[n].start(total_comm());
} }
@@ -135,7 +135,7 @@ void BaseMachine::stop(int n)
{ {
timer[n].stop(total_comm()); timer[n].stop(total_comm());
cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " (" cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " ("
<< timer[n].mb_sent() << " MB)" << endl; << timer[n] << ")" << endl;
} }
void BaseMachine::print_timers() void BaseMachine::print_timers()
@@ -150,7 +150,7 @@ void BaseMachine::print_timers()
timer.erase(0); timer.erase(0);
for (auto it = timer.begin(); it != timer.end(); it++) for (auto it = timer.begin(); it != timer.end(); it++)
cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds (" cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds ("
<< it->second.mb_sent() << " MB)" << endl; << it->second << ")" << endl;
} }
string BaseMachine::memory_filename(const string& type_short, int my_number) string BaseMachine::memory_filename(const string& type_short, int my_number)
@@ -227,3 +227,19 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
global += os.get_int(8); global += os.get_int(8);
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl; cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
} }
void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats)
{
size_t rounds = 0;
for (auto& x : comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << P.my_num() << " only";
if (nthreads > 1)
cerr << "; rounds counted double due to multi-threading";
if (not OnlineOptions::singleton.verbose)
cerr << "; use '-v' for more details";
cerr << ")" << endl;
print_global_comm(P, comm_stats);
}

View File

@@ -67,6 +67,7 @@ public:
void print_timers(); void print_timers();
virtual void reqbl(int) {} virtual void reqbl(int) {}
virtual void active(int) {}
static OTTripleSetup fresh_ot_setup(Player& P); static OTTripleSetup fresh_ot_setup(Player& P);
@@ -74,6 +75,7 @@ public:
void set_thread_comm(const NamedCommStats& stats); void set_thread_comm(const NamedCommStats& stats);
void print_global_comm(Player& P, const NamedCommStats& stats); void print_global_comm(Player& P, const NamedCommStats& stats);
void print_comm(Player& P, const NamedCommStats& stats);
}; };
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)

41
Processor/Conv2dTuple.h Normal file
View File

@@ -0,0 +1,41 @@
/*
* Conv2dTuple.h
*
*/
#ifndef PROCESSOR_CONV2DTUPLE_H_
#define PROCESSOR_CONV2DTUPLE_H_
#include <vector>
using namespace std;
class Conv2dTuple
{
public:
int output_h, output_w;
int inputs_h, inputs_w;
int weights_h, weights_w;
int stride_h, stride_w;
int n_channels_in;
int padding_h;
int padding_w;
int batch_size;
size_t r0;
size_t r1;
int r2;
vector<vector<vector<int>>> lengths;
int filter_stride_h = 1;
int filter_stride_w = 1;
Conv2dTuple(const vector<int>& args, int start);
template<class T>
void pre(vector<T>& S, typename T::Protocol& protocol);
template<class T>
void post(vector<T>& S, typename T::Protocol& protocol);
template<class T>
void run_matrix(SubProcessor<T>& processor);
};
#endif /* PROCESSOR_CONV2DTUPLE_H_ */

View File

@@ -222,7 +222,8 @@ bool DataPositions::any_more(const DataPositions& other) const
for (auto it = edabits.begin(); it != edabits.end(); it++) for (auto it = edabits.begin(); it != edabits.end(); it++)
{ {
auto x = other.edabits.find(it->first); auto x = other.edabits.find(it->first);
if (x == other.edabits.end() or it->second > x->second) if ((x == other.edabits.end() or it->second > x->second)
and it->second > 0)
return true; return true;
} }

View File

@@ -12,6 +12,8 @@
#include "Networking/Player.h" #include "Networking/Player.h"
#include "Protocols/edabit.h" #include "Protocols/edabit.h"
#include "PrepBase.h" #include "PrepBase.h"
#include "EdabitBuffer.h"
#include "Tools/TimerWithComm.h"
#include <fstream> #include <fstream>
#include <map> #include <map>
@@ -102,9 +104,6 @@ protected:
DataPositions& usage; DataPositions& usage;
map<pair<bool, int>, vector<edabitvec<T>>> edabits;
map<pair<bool, int>, edabitvec<T>> my_edabits;
bool do_count; bool do_count;
void count(Dtype dtype, int n = 1) void count(Dtype dtype, int n = 1)
@@ -120,6 +119,8 @@ protected:
const vector<int>&, true_type) const vector<int>&, true_type)
{ throw not_implemented(); } { throw not_implemented(); }
void fill(edabitvec<T>& res, bool strict, int n_bits);
T get_random_from_inputs(int nplayers); T get_random_from_inputs(int nplayers);
public: public:
@@ -173,12 +174,11 @@ public:
virtual void get_edabits(bool strict, size_t size, T* a, virtual void get_edabits(bool strict, size_t size, T* a,
vector<typename T::bit_type>& Sb, const vector<int>& regs) vector<typename T::bit_type>& Sb, const vector<int>& regs)
{ get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); } { get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
template<int> virtual void get_edabit_no_count(bool, int, edabit<T>&)
void get_edabit_no_count(bool, int n_bits, edabit<T>& eb); { throw runtime_error("no edaBits"); }
template<int>
/// Get fresh edaBit chunk /// Get fresh edaBit chunk
edabitvec<T> get_edabitvec(bool strict, int n_bits); virtual edabitvec<T> get_edabitvec(bool, int)
virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); } { throw runtime_error("no edabitvec"); }
virtual void push_triples(const vector<array<T, 3>>&) virtual void push_triples(const vector<array<T, 3>>&)
{ throw runtime_error("no pushing"); } { throw runtime_error("no pushing"); }
@@ -204,7 +204,8 @@ class Sub_Data_Files : public Preprocessing<T>
BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers; BufferOwner<InputTuple<T>, RefInputTuple<T>> my_input_buffers;
map<DataTag, BufferOwner<T, T> > extended; map<DataTag, BufferOwner<T, T> > extended;
BufferOwner<dabit<T>, dabit<T>> dabit_buffer; BufferOwner<dabit<T>, dabit<T>> dabit_buffer;
map<int, ifstream*> edabit_buffers; map<int, EdabitBuffer<T>> edabit_buffers;
map<int, edabitvec<T>> my_edabits;
int my_num,num_players; int my_num,num_players;
@@ -213,13 +214,11 @@ class Sub_Data_Files : public Preprocessing<T>
part_type* part; part_type* part;
void buffer_edabits_with_queues(bool strict, int n_bits) EdabitBuffer<T>& get_edabit_buffer(int n_bits);
{ buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
template<int> /// Get fresh edaBit chunk
void buffer_edabits_with_queues(bool strict, int n_bits, false_type); edabitvec<T> get_edabitvec(bool strict, int n_bits);
template<int> void get_edabit_no_count(bool strict, int n_bits, edabit<T>& eb);
void buffer_edabits_with_queues(bool, int, true_type)
{ throw not_implemented(); }
public: public:
static string get_filename(const Names& N, Dtype type, int thread_num = -1); static string get_filename(const Names& N, Dtype type, int thread_num = -1);
@@ -317,6 +316,8 @@ class Data_Files
void reset_usage() { usage.reset(); skipped.reset(); } void reset_usage() { usage.reset(); skipped.reset(); }
void set_usage(const DataPositions& pos) { usage = pos; } void set_usage(const DataPositions& pos) { usage = pos; }
TimerWithComm total_time();
}; };
template<class T> inline template<class T> inline

View File

@@ -108,7 +108,21 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
#ifdef DEBUG_FILES #ifdef DEBUG_FILES
cerr << "Setting up Data_Files in: " << prep_data_dir << endl; cerr << "Setting up Data_Files in: " << prep_data_dir << endl;
#endif #endif
try
{
T::clear::check_setup(prep_data_dir); T::clear::check_setup(prep_data_dir);
}
catch (...)
{
cerr << "Something is wrong with the preprocessing data on disk." << endl;
cerr
<< "Have you run the right program for generating it, such as './Fake-Offline.x "
<< num_players
<< T::clear::fake_opts() << "'?" << endl;
throw;
}
string type_short = T::type_short(); string type_short = T::type_short();
string type_string = T::type_string(); string type_string = T::type_string();
@@ -135,7 +149,7 @@ Sub_Data_Files<T>::Sub_Data_Files(int my_num, int num_players,
type_short, i, my_num, thread_num); type_short, i, my_num, thread_num);
if (i == my_num) if (i == my_num)
my_input_buffers.setup(filename, my_input_buffers.setup(filename,
T::size() + T::clear::size(), type_string); InputTuple<T>::size(), type_string);
else else
input_buffers[i].setup(filename, input_buffers[i].setup(filename,
T::size(), type_string); T::size(), type_string);
@@ -179,10 +193,6 @@ Data_Files<sint, sgf2n>::~Data_Files()
template<class T> template<class T>
Sub_Data_Files<T>::~Sub_Data_Files() Sub_Data_Files<T>::~Sub_Data_Files()
{ {
for (auto& x: edabit_buffers)
{
delete x.second;
}
if (part != 0) if (part != 0)
delete part; delete part;
} }
@@ -229,6 +239,26 @@ void Sub_Data_Files<T>::seekg(DataPositions& pos)
extended[it->first].seekg(it->second); extended[it->first].seekg(it->second);
} }
dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]); dabit_buffer.seekg(pos.files[field_type][DATA_DABIT]);
if (field_type == DATA_INT)
{
for (auto& x : pos.edabits)
{
// open files
get_edabit_buffer(x.first.second);
}
int block_size = edabitvec<T>::MAX_SIZE;
for (auto& x : edabit_buffers)
{
int n = pos.edabits[{true, x.first}] + pos.edabits[{false, x.first}];
x.second.seekg(n / block_size);
edabit<T> eb;
for (int i = 0; i < n % block_size; i++)
get_edabit_no_count(false, x.first, eb);
}
}
} }
template<class sint, class sgf2n> template<class sint, class sgf2n>
@@ -262,6 +292,8 @@ void Sub_Data_Files<T>::prune()
dabit_buffer.prune(); dabit_buffer.prune();
if (part != 0) if (part != 0)
part->prune(); part->prune();
for (auto& x : edabit_buffers)
x.second.prune();
} }
template<class sint, class sgf2n> template<class sint, class sgf2n>
@@ -285,6 +317,8 @@ void Sub_Data_Files<T>::purge()
dabit_buffer.purge(); dabit_buffer.purge();
if (part != 0) if (part != 0)
part->purge(); part->purge();
for (auto& x : edabit_buffers)
x.second.prune();
} }
template<class T> template<class T>
@@ -322,34 +356,43 @@ void Sub_Data_Files<T>::get_dabit_no_count(T& a, typename T::bit_type& b)
} }
template<class T> template<class T>
template<int> EdabitBuffer<T>& Sub_Data_Files<T>::get_edabit_buffer(int n_bits)
void Sub_Data_Files<T>::buffer_edabits_with_queues(bool strict, int n_bits,
false_type)
{ {
if (edabit_buffers.empty())
insecure("reading edaBits from files");
if (edabit_buffers.find(n_bits) == edabit_buffers.end()) if (edabit_buffers.find(n_bits) == edabit_buffers.end())
{ {
string filename = PrepBase::get_edabit_filename(prep_data_dir, string filename = PrepBase::get_edabit_filename(prep_data_dir,
n_bits, my_num, thread_num); n_bits, my_num, thread_num);
ifstream* f = new ifstream(filename); edabit_buffers[n_bits] = n_bits;
if (f->fail()) edabit_buffers[n_bits].setup(filename,
throw runtime_error("cannot open " + filename); T::size() * edabitvec<T>::MAX_SIZE
check_file_signature<T>(*f, filename); + n_bits * T::bit_type::part_type::size());
edabit_buffers[n_bits] = f;
} }
auto& buffer = *edabit_buffers[n_bits]; return edabit_buffers[n_bits];
if (buffer.peek() == EOF) }
template<class T>
edabitvec<T> Sub_Data_Files<T>::get_edabitvec(bool strict, int n_bits)
{ {
buffer.seekg(0); if (my_edabits[n_bits].empty())
check_file_signature<T>(buffer, ""); return get_edabit_buffer(n_bits).read();
else
{
auto res = my_edabits[n_bits];
my_edabits[n_bits] = {};
this->fill(res, strict, n_bits);
return res;
}
}
template<class T>
void Preprocessing<T>::fill(edabitvec<T>& res, bool strict, int n_bits)
{
edabit<T> eb;
while (res.size() < res.MAX_SIZE)
{
get_edabit_no_count(strict, n_bits, eb);
res.push_back(eb);
} }
edabitvec<T> eb;
eb.input(n_bits, buffer);
this->edabits[{strict, n_bits}].push_back(eb);
if (buffer.fail())
throw runtime_error("error reading edaBits");
} }
template<class T> template<class T>
@@ -362,4 +405,10 @@ typename Sub_Data_Files<T>::part_type& Sub_Data_Files<T>::get_part()
return *part; return *part;
} }
template<class sint, class sgf2n>
TimerWithComm Data_Files<sint, sgf2n>::total_time()
{
return DataFp.prep_timer + DataF2.prep_timer + DataFb.prep_timer;
}
#endif #endif

50
Processor/EdabitBuffer.h Normal file
View File

@@ -0,0 +1,50 @@
/*
* EdabitBuffer.h
*
*/
#ifndef PROCESSOR_EDABITBUFFER_H_
#define PROCESSOR_EDABITBUFFER_H_
#include "Tools/Buffer.h"
template<class T>
class EdabitBuffer : public BufferOwner<T, T>
{
int n_bits;
int element_length()
{
return -1;
}
public:
EdabitBuffer(int n_bits = 0) :
n_bits(n_bits)
{
}
edabitvec<T> read()
{
if (not BufferBase::file)
{
if (this->open()->fail())
throw runtime_error("error opening " + this->filename);
}
assert(BufferBase::file);
auto& buffer = *BufferBase::file;
if (buffer.peek() == EOF)
{
this->try_rewind();
}
edabitvec<T> eb;
eb.input(n_bits, buffer);
if (buffer.fail())
throw runtime_error("error reading edaBits");
return eb;
}
};
#endif /* PROCESSOR_EDABITBUFFER_H_ */

View File

@@ -70,6 +70,7 @@ enum
PLAYERID = 0xE4, PLAYERID = 0xE4,
USE_EDABIT = 0xE5, USE_EDABIT = 0xE5,
USE_MATMUL = 0x1F, USE_MATMUL = 0x1F,
ACTIVE = 0xE9,
// Addition // Addition
ADDC = 0x20, ADDC = 0x20,
ADDS = 0x21, ADDS = 0x21,

View File

@@ -311,6 +311,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case PRIVATEOUTPUT: case PRIVATEOUTPUT:
case TRUNC_PR: case TRUNC_PR:
case RUN_TAPE: case RUN_TAPE:
case CONV2DS:
num_var_args = get_int(s); num_var_args = get_int(s);
get_vector(num_var_args, start, s); get_vector(num_var_args, start, s);
break; break;
@@ -322,10 +323,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
get_ints(r, s, 3); get_ints(r, s, 3);
get_vector(9, start, s); get_vector(9, start, s);
break; break;
case CONV2DS:
get_ints(r, s, 3);
get_vector(12, start, s);
break;
// read from file, input is opcode num_args, // read from file, input is opcode num_args,
// start_file_posn (read), end_file_posn(write) var1, var2, ... // start_file_posn (read), end_file_posn(write) var1, var2, ...
@@ -425,6 +422,10 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
throw Processor_Error(ss.str()); throw Processor_Error(ss.str());
} }
break; break;
case ACTIVE:
n = get_int(s);
BaseMachine::s().active(n);
break;
case XORM: case XORM:
case ANDM: case ANDM:
case XORCB: case XORCB:
@@ -720,7 +721,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
case MATMULSM: case MATMULSM:
return r[0] + start[0] * start[2]; return r[0] + start[0] * start[2];
case CONV2DS: case CONV2DS:
return r[0] + start[0] * start[1] * start[11]; {
unsigned res = 0;
for (size_t i = 0; i < start.size(); i += 15)
{
unsigned tmp = start[i]
+ start[i + 3] * start[i + 4] * start.at(i + 14);
res = max(res, tmp);
}
return res;
}
case OPEN: case OPEN:
skip = 2; skip = 2;
break; break;
@@ -1164,6 +1174,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break; break;
case REQBL: case REQBL:
case GREQBL: case GREQBL:
case ACTIVE:
case USE: case USE:
case USE_INP: case USE_INP:
case USE_EDABIT: case USE_EDABIT:

View File

@@ -109,6 +109,7 @@ class Machine : public BaseMachine
string prep_dir_prefix(); string prep_dir_prefix();
void reqbl(int n); void reqbl(int n);
void active(int n);
typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; } typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; }
typename sint::mac_key_type get_sint_mac_key() { return alphapi; } typename sint::mac_key_type get_sint_mac_key() { return alphapi; }

View File

@@ -415,6 +415,9 @@ pair<DataPositions, NamedCommStats> Machine<sint, sgf2n>::stop_threads()
auto comm_stats = total_comm(); auto comm_stats = total_comm();
if (OnlineOptions::singleton.verbose)
queues.print_breakdown();
for (auto& queue : queues) for (auto& queue : queues)
delete queue; delete queue;
@@ -477,20 +480,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
print_timers(); print_timers();
if (sint::is_real) if (sint::is_real)
{ this->print_comm(*this->P, comm_stats);
size_t rounds = 0;
for (auto& x : comm_stats)
rounds += x.second.rounds;
cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds
<< " rounds (party " << my_number;
if (threads.size() > 1)
cerr << "; rounds counted double due to multi-threading";
cerr << "; use '-v' for more details";
cerr << ")" << endl;
auto& P = *this->P;
this->print_global_comm(P, comm_stats);
}
#ifdef VERBOSE_OPTIONS #ifdef VERBOSE_OPTIONS
if (opening_sum < N.num_players() && !direct) if (opening_sum < N.num_players() && !direct)
@@ -521,23 +511,6 @@ void Machine<sint, sgf2n>::run(const string& progname)
bit_memories.write_memory(N.my_num()); bit_memories.write_memory(N.my_num());
#ifdef OLD_USAGE
for (int dtype = 0; dtype < N_DTYPE; dtype++)
{
cerr << "Num " << DataPositions::dtype_names[dtype] << "\t=";
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
cerr << " " << pos.files[field_type][dtype];
cerr << endl;
}
for (int field_type = 0; field_type < N_DATA_FIELD_TYPE; field_type++)
{
cerr << "Num " << DataPositions::field_names[field_type] << " Inputs\t=";
for (int i = 0; i < N.num_players(); i++)
cerr << " " << pos.inputs[i][field_type];
cerr << endl;
}
#endif
if (opts.verbose) if (opts.verbose)
{ {
cerr << "Actual cost of program:" << endl; cerr << "Actual cost of program:" << endl;
@@ -586,6 +559,17 @@ void Machine<sint, sgf2n>::reqbl(int n)
sint::clear::reqbl(n); sint::clear::reqbl(n);
} }
template<class sint, class sgf2n>
void Machine<sint, sgf2n>::active(int n)
{
if (sint::malicious and n == 0)
{
cerr << "Program requires a semi-honest protocol" << endl;
exit(1);
}
}
template<class sint, class sgf2n> template<class sint, class sgf2n>
void Machine<sint, sgf2n>::suggest_optimizations() void Machine<sint, sgf2n>::suggest_optimizations()
{ {
@@ -599,8 +583,8 @@ void Machine<sint, sgf2n>::suggest_optimizations()
optimizations.append("\tprogram.use_edabit(True)\n"); optimizations.append("\tprogram.use_edabit(True)\n");
if (not optimizations.empty()) if (not optimizations.empty())
cerr << "This program might benefit from some protocol options." << endl cerr << "This program might benefit from some protocol options." << endl
<< "Consider adding the following at the beginning of '" << progname << "Consider adding the following at the beginning of your code:"
<< ".mpc':" << endl << optimizations; << endl << optimizations;
#ifndef __clang__ #ifndef __clang__
cerr << "This virtual machine was compiled with GCC. Recompile with " cerr << "This virtual machine was compiled with GCC. Recompile with "
"'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl; "'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl;

View File

@@ -172,7 +172,7 @@ void OfflineMachine<W>::generate()
auto& opts = OnlineOptions::singleton; auto& opts = OnlineOptions::singleton;
opts.batch_size = DIV_CEIL(opts.batch_size, batch) * batch; opts.batch_size = DIV_CEIL(opts.batch_size, batch) * batch;
for (int i = 0; i < buffered_total(total, batch) / batch; i++) for (int i = 0; i < buffered_total(total, batch) / batch; i++)
preprocessing.template get_edabitvec<0>(true, n_bits).output(n_bits, preprocessing.get_edabitvec(true, n_bits).output(n_bits,
out); out);
} }
else else

View File

@@ -44,6 +44,7 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
auto& queues = machine.queues[num]; auto& queues = machine.queues[num];
queues->next(); queues->next();
ThreadQueue::thread_queue = queues;
#ifdef DEBUG_THREADS #ifdef DEBUG_THREADS
fprintf(stderr, "\tI am in thread %d\n",num); fprintf(stderr, "\tI am in thread %d\n",num);
@@ -118,6 +119,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
DataPositions actual_usage(P.num_players()); DataPositions actual_usage(P.num_players());
Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer; Timer thread_timer(CLOCK_THREAD_CPUTIME_ID), wait_timer;
thread_timer.start(); thread_timer.start();
TimerWithComm timer, online_timer, online_prep_timer;
timer.start();
while (flag) while (flag)
{ // Wait until I have a program to run { // Wait until I have a program to run
@@ -262,6 +265,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
#ifdef DEBUG_THREADS #ifdef DEBUG_THREADS
printf("\tClient %d about to run %d\n",num,program); printf("\tClient %d about to run %d\n",num,program);
#endif #endif
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
Proc.reset(progs[program], job.arg); Proc.reset(progs[program], job.arg);
// Bits, Triples, Squares, and Inverses skipping // Bits, Triples, Squares, and Inverses skipping
@@ -290,6 +295,8 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
printf("\tSignalling I have finished with program %d" printf("\tSignalling I have finished with program %d"
"in thread %d\n", program, num); "in thread %d\n", program, num);
#endif #endif
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
wait_timer.start(); wait_timer.start();
queues->finished(job, P.total_comm()); queues->finished(job, P.total_comm());
wait_timer.stop(); wait_timer.stop();
@@ -297,7 +304,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
} }
// final check // final check
online_timer.start(P.total_comm());
online_prep_timer -= Proc.DataF.total_time();
Proc.check(); Proc.check();
online_timer.stop(P.total_comm());
online_prep_timer += Proc.DataF.total_time();
if (machine.opts.file_prep_per_thread) if (machine.opts.file_prep_per_thread)
Proc.DataF.prune(); Proc.DataF.prune();
@@ -330,6 +341,11 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
// wind down thread by thread // wind down thread by thread
machine.stats += Proc.stats; machine.stats += Proc.stats;
queues->timers["wait"] = wait_timer + queues->wait_timer;
timer.stop(P.total_comm());
queues->timers["online"] = online_timer - online_prep_timer - queues->wait_timer;
queues->timers["prep"] = timer - queues->timers["wait"] - queues->timers["online"];
// prevent faulty usage message // prevent faulty usage message
Proc.DataF.set_usage(actual_usage); Proc.DataF.set_usage(actual_usage);
delete processor; delete processor;

View File

@@ -69,7 +69,7 @@ void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict,
cerr << " edaBits of size " << n_bits << " left" << endl; cerr << " edaBits of size " << n_bits << " left" << endl;
} }
if (n > used / 10) if (n * n_batch > used / 10)
cerr << "Significant amount of unused edaBits of size " << n_bits cerr << "Significant amount of unused edaBits of size " << n_bits
<< ". For more accurate benchmarks, " << ". For more accurate benchmarks, "
<< "consider reducing the batch size with --batch-size " << "consider reducing the batch size with --batch-size "

View File

@@ -10,6 +10,7 @@
using namespace std; using namespace std;
#include "Math/field_types.h" #include "Math/field_types.h"
#include "Tools/TimerWithComm.h"
class PrepBase class PrepBase
{ {
@@ -28,6 +29,8 @@ public:
const string& type_string, size_t used); const string& type_string, size_t used);
static void print_left_edabits(size_t n, size_t n_batch, bool strict, static void print_left_edabits(size_t n, size_t n_batch, bool strict,
int n_bits, size_t used); int n_bits, size_t used);
TimerWithComm prep_timer;
}; };
#endif /* PROCESSOR_PREPBASE_H_ */ #endif /* PROCESSOR_PREPBASE_H_ */

View File

@@ -5,6 +5,7 @@
#include "Processor/Program.h" #include "Processor/Program.h"
#include "GC/square64.h" #include "GC/square64.h"
#include "SpecificPrivateOutput.h" #include "SpecificPrivateOutput.h"
#include "Conv2dTuple.h"
#include "Processor/ProcessorBase.hpp" #include "Processor/ProcessorBase.hpp"
#include "GC/Processor.hpp" #include "GC/Processor.hpp"
@@ -31,6 +32,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
DataF.set_proc(this); DataF.set_proc(this);
protocol.init(DataF, MC); protocol.init(DataF, MC);
DataF.set_protocol(protocol); DataF.set_protocol(protocol);
MC.set_prep(DataF);
bit_usage.set_num_players(P.num_players()); bit_usage.set_num_players(P.num_players());
personal_bit_preps.resize(P.num_players()); personal_bit_preps.resize(P.num_players());
for (int i = 0; i < P.num_players(); i++) for (int i = 0; i < P.num_players(); i++)
@@ -40,6 +42,7 @@ SubProcessor<T>::SubProcessor(typename T::MAC_Check& MC,
template<class T> template<class T>
SubProcessor<T>::~SubProcessor() SubProcessor<T>::~SubProcessor()
{ {
DataF.set_proc(0);
for (size_t i = 0; i < personal_bit_preps.size(); i++) for (size_t i = 0; i < personal_bit_preps.size(); i++)
{ {
auto& x = personal_bit_preps[i]; auto& x = personal_bit_preps[i];
@@ -391,7 +394,7 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
return; return;
string filename; string filename;
filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; filename = binary_file_io.filename(P.my_num());
unsigned int size = data_registers.size(); unsigned int size = data_registers.size();
@@ -652,21 +655,35 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
{ {
protocol.init_dotprod(); protocol.init_dotprod();
auto& args = instruction.get_start(); auto& args = instruction.get_start();
int output_h = args[0], output_w = args[1]; vector<Conv2dTuple> tuples;
int inputs_h = args[2], inputs_w = args[3]; for (size_t i = 0; i < args.size(); i += 15)
int weights_h = args[4], weights_w = args[5]; tuples.push_back(Conv2dTuple(args, i));
int stride_h = args[6], stride_w = args[7]; for (auto& tuple : tuples)
int n_channels_in = args[8]; tuple.pre(S, protocol);
int padding_h = args[9]; protocol.exchange();
int padding_w = args[10]; for (auto& tuple : tuples)
int batch_size = args[11]; tuple.post(S, protocol);
size_t r0 = instruction.get_r(0); }
size_t r1 = instruction.get_r(1);
int r2 = instruction.get_r(2); inline
int lengths[batch_size][output_h][output_w]; Conv2dTuple::Conv2dTuple(const vector<int>& arguments, int start)
memset(lengths, 0, sizeof(lengths)); {
int filter_stride_h = 1; assert(arguments.size() >= start + 15ul);
int filter_stride_w = 1; auto args = arguments.data() + start + 3;
output_h = args[0], output_w = args[1];
inputs_h = args[2], inputs_w = args[3];
weights_h = args[4], weights_w = args[5];
stride_h = args[6], stride_w = args[7];
n_channels_in = args[8];
padding_h = args[9];
padding_w = args[10];
batch_size = args[11];
r0 = arguments[start];
r1 = arguments[start + 1];
r2 = arguments[start + 2];
lengths.resize(batch_size, vector<vector<int>>(output_h, vector<int>(output_w)));
filter_stride_h = 1;
filter_stride_w = 1;
if (stride_h < 0) if (stride_h < 0)
{ {
filter_stride_h = -stride_h; filter_stride_h = -stride_h;
@@ -677,7 +694,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
filter_stride_w = -stride_w; filter_stride_w = -stride_w;
stride_w = 1; stride_w = 1;
} }
}
template<class T>
void Conv2dTuple::pre(vector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++) for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{ {
size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in; size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in;
@@ -714,9 +735,11 @@ void SubProcessor<T>::conv2ds(const Instruction& instruction)
protocol.next_dotprod(); protocol.next_dotprod();
} }
} }
}
protocol.exchange(); template<class T>
void Conv2dTuple::post(vector<T>& S, typename T::Protocol& protocol)
{
for (int i_batch = 0; i_batch < batch_size; i_batch ++) for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{ {
size_t base = r0 + i_batch * output_h * output_w; size_t base = r0 + i_batch * output_h * output_w;

View File

@@ -6,6 +6,8 @@
#include "ThreadQueue.h" #include "ThreadQueue.h"
thread_local ThreadQueue* ThreadQueue::thread_queue = 0;
void ThreadQueue::schedule(const ThreadJob& job) void ThreadQueue::schedule(const ThreadJob& job)
{ {
lock.lock(); lock.lock();
@@ -14,7 +16,11 @@ void ThreadQueue::schedule(const ThreadJob& job)
cerr << this << ": " << left << " left" << endl; cerr << this << ": " << left << " left" << endl;
#endif #endif
lock.unlock(); lock.unlock();
if (thread_queue)
thread_queue->wait_timer.start();
in.push(job); in.push(job);
if (thread_queue)
thread_queue->wait_timer.stop();
} }
ThreadJob ThreadQueue::next() ThreadJob ThreadQueue::next()
@@ -42,7 +48,11 @@ void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats)
ThreadJob ThreadQueue::result() ThreadJob ThreadQueue::result()
{ {
if (thread_queue)
thread_queue->wait_timer.start();
auto res = out.pop(); auto res = out.pop();
if (thread_queue)
thread_queue->wait_timer.stop();
lock.lock(); lock.lock();
left--; left--;
#ifdef DEBUG_THREAD_QUEUE #ifdef DEBUG_THREAD_QUEUE

View File

@@ -16,6 +16,11 @@ class ThreadQueue
NamedCommStats comm_stats; NamedCommStats comm_stats;
public: public:
static thread_local ThreadQueue* thread_queue;
map<string, TimerWithComm> timers;
Timer wait_timer;
ThreadQueue() : ThreadQueue() :
left(0) left(0)
{ {

View File

@@ -85,3 +85,32 @@ void ThreadQueues::wrap_up(ThreadJob job)
} }
available.clear(); available.clear();
} }
TimerWithComm ThreadQueues::sum(const string& phase)
{
TimerWithComm res;
for (auto& x : *this)
res += x->timers[phase];
return res;
}
void ThreadQueues::print_breakdown()
{
if (size() > 0)
{
if (size() == 1)
{
cerr << "Spent " << (*this)[0]->timers["online"].full()
<< " on the online phase and "
<< (*this)[0]->timers["prep"].full()
<< " on the preprocessing/offline phase." << endl;
}
else
{
cerr << size() << " threads spent a total of " << sum("online").full()
<< " on the online phase, " << sum("prep").full()
<< " on the preprocessing/offline phase, and "
<< sum("wait").full() << " idling." << endl;
}
}
}

View File

@@ -24,6 +24,10 @@ public:
int distribute_no_setup(ThreadJob job, int n_items, int base = 0, int distribute_no_setup(ThreadJob job, int n_items, int base = 0,
int granularity = 1, const vector<void*>* supplies = 0); int granularity = 1, const vector<void*>* supplies = 0);
void wrap_up(ThreadJob job); void wrap_up(ThreadJob job);
TimerWithComm sum(const string& phase);
void print_breakdown();
}; };
#endif /* PROCESSOR_THREADQUEUES_H_ */ #endif /* PROCESSOR_THREADQUEUES_H_ */

View File

@@ -387,6 +387,7 @@
X(GENSECSHUFFLE, throw not_implemented(),) \ X(GENSECSHUFFLE, throw not_implemented(),) \
X(APPLYSHUFFLE, throw not_implemented(),) \ X(APPLYSHUFFLE, throw not_implemented(),) \
X(DELSHUFFLE, throw not_implemented(),) \ X(DELSHUFFLE, throw not_implemented(),) \
X(ACTIVE, throw not_implemented(),) \
#define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \ #define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \
CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS

114
Programs/Source/alex.mpc Normal file
View File

@@ -0,0 +1,114 @@
from Compiler.ml import keras
import Compiler.ml as tf
try:
n_epochs = int(program.args[1])
except (ValueError, IndexError):
n_epochs = 20
try:
batch_size = int(program.args[2])
except (ValueError, IndexError):
batch_size = 128
try:
n_threads = int(program.args[3])
except (ValueError, IndexError):
n_threads = 36
#Instantiation
AlexNet = []
padding = 1
batchnorm = 'batchnorm' in program.args
bn1 = 'bn1' in program.args
bn2 = 'bn2' in program.args
MultiArray.disable_index_checks()
#1st Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=64, input_shape=(32,32,3), kernel_size=3, strides=1, padding=2))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding=0))
#2nd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=3, strides=1, padding=2))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm or bn2:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2), padding='same'))
#3rd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=96, kernel_size=(3,3), strides=(1,1), padding=padding))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
#4th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm or bn1:
AlexNet.append(keras.layers.BatchNormalization())
#5th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=64, kernel_size=(3,3), strides=(1,1), padding=padding))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm or bn2:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(3,3), strides=(2,2), padding=0))
#Passing it to a Fully Connected layer
# 1st Fully Connected Layer
AlexNet.append(keras.layers.Dense(128))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#2nd Fully Connected Layer
AlexNet.append(keras.layers.Dense(256))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#Output Layer
AlexNet.append(keras.layers.Dense(10))
tf.set_n_threads(n_threads)
program.options_from_args()
sfix.set_precision_from_args(program, adapt_ring=True)
training_samples = MultiArray([50000, 32, 32, 3], sfix)
training_labels = MultiArray([50000, 10], sint)
test_samples = MultiArray([10000, 32, 32, 3], sfix)
test_labels = MultiArray([10000, 10], sint)
training_labels.input_from(0)
training_samples.input_from(0, binary='binary_samples' in program.args)
test_labels.input_from(0)
test_samples.input_from(0, binary='binary_samples' in program.args)
model = tf.keras.models.Sequential(AlexNet)
model.compile_by_args(program)
model.build(training_samples.sizes)
model.summary()
model.opt.output_diff = 'output_diff' in program.args
model.opt.output_grad = 'output_grad' in program.args
model.opt.output_stats = 100 if 'output_stats' in program.args else 0
model.opt.shuffle = not 'noshuffle' in program.args
opt = model.fit(
training_samples,
training_labels,
epochs=n_epochs,
batch_size=batch_size,
validation_data=(test_samples, test_labels)
)

View File

@@ -25,6 +25,9 @@ n_threads = 2
if len(program.args) > 1: if len(program.args) > 1:
n_rounds = int(program.args[1]) n_rounds = int(program.args[1])
if len(program.args) > 2:
program.active = bool(int(program.args[2]))
def accept_client(): def accept_client():
client_socket_id = accept_client_connection(PORTNUM) client_socket_id = accept_client_connection(PORTNUM)
last = regint.read_from_socket(client_socket_id) last = regint.read_from_socket(client_socket_id)

View File

@@ -4,7 +4,7 @@ import Compiler.ml as tf
try: try:
n_epochs = int(program.args[1]) n_epochs = int(program.args[1])
except (ValueError, IndexError): except (ValueError, IndexError):
n_epochs = 10 n_epochs = 20
try: try:
batch_size = int(program.args[2]) batch_size = int(program.args[2])

View File

@@ -0,0 +1,72 @@
# this trains LeNet on MNIST with a dropout layer
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
program.options_from_args()
if 'torch' in program.args:
import torchvision
data = []
for train in True, False:
ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
# normalize to [0,1] before input
samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
data += [(labels, samples)]
(training_labels, training_samples), (test_labels, test_samples) = data
else:
training_samples = sfix.Tensor([60000, 28, 28])
training_labels = sint.Tensor([60000, 10])
test_samples = sfix.Tensor([10000, 28, 28])
test_labels = sint.Tensor([10000, 10])
training_labels.input_from(0)
training_samples.input_from(0)
test_labels.input_from(0)
test_samples.input_from(0)
from Compiler import ml
tf = ml
layers = [
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.AveragePooling2D(2),
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.AveragePooling2D(2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(500, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
]
model = tf.keras.models.Sequential(layers)
optim = tf.keras.optimizers.Adam(amsgrad=True)
model.compile(optimizer=optim)
opt = model.fit(
training_samples,
training_labels,
epochs=10,
batch_size=128,
validation_data=(test_samples, test_labels)
)
for var in model.trainable_variables:
var.write_to_file()

View File

@@ -0,0 +1,49 @@
# this trains a dense neural network on MNIST
program.options_from_args()
import torchvision
data = []
for train in True, False:
ds = torchvision.datasets.MNIST(root='/tmp', train=train, download=True)
# normalize to [0,1] before input
samples = sfix.input_tensor_via(0, ds.data / 255., binary=True)
labels = sint.input_tensor_via(0, ds.targets, binary=True, one_hot=True)
data += [(labels, samples)]
import torch
import torch.nn as nn
net = nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.ReLU(),
nn.AvgPool2d(2),
nn.Conv2d(20, 50, 5),
nn.ReLU(),
nn.AvgPool2d(2),
nn.Flatten(),
nn.ReLU(),
nn.Linear(800, 500),
nn.ReLU(),
nn.Linear(500, 10)
)
# test network
ds = torchvision.datasets.MNIST(
root='/tmp', transform=torchvision.transforms.ToTensor())
inputs = next(iter(torch.utils.data.DataLoader(ds)))[0]
print(inputs.shape)
outputs = net(inputs)
from Compiler import ml
ml.set_n_threads(int(program.args[2]))
layers = ml.layers_from_torch(net, data[0][1].shape, 128)
layers[0].X = data[0][1]
layers[-1].Y = data[0][0]
optimizer = ml.SGD(layers)
optimizer.run_by_args(program, int(program.args[1]), 128,
data[1][1], data[1][0])

View File

@@ -11,6 +11,8 @@ DealerMatrixPrep<T>::DealerMatrixPrep(int n_rows, int n_inner, int n_cols,
super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols),
prep(&prep) prep(&prep)
{ {
assert(prep.proc);
this->P = &prep.proc->P;
} }
template<class T> template<class T>

View File

@@ -20,9 +20,6 @@ class Hemi : public T::BasicProtocol
MatrixMC<T> mc; MatrixMC<T> mc;
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
SubProcessor<T>& processor);
public: public:
Hemi(Player& P) : Hemi(Player& P) :
T::BasicProtocol(P) T::BasicProtocol(P)
@@ -33,6 +30,9 @@ public:
typename T::MatrixPrep& get_matrix_prep(const array<int, 3>& dimensions, typename T::MatrixPrep& get_matrix_prep(const array<int, 3>& dimensions,
SubProcessor<T>& processor); SubProcessor<T>& processor);
ShareMatrix<T> matrix_multiply(const ShareMatrix<T>& A, const ShareMatrix<T>& B,
SubProcessor<T>& processor);
void matmulsm(SubProcessor<T>& processor, CheckVector<T>& source, void matmulsm(SubProcessor<T>& processor, CheckVector<T>& source,
const Instruction& instruction, int a, int b); const Instruction& instruction, int a, int b);
void conv2ds(SubProcessor<T>& processor, const Instruction& instruction); void conv2ds(SubProcessor<T>& processor, const Instruction& instruction);

View File

@@ -130,37 +130,23 @@ void Hemi<T>::conv2ds(SubProcessor<T>& processor,
} }
auto& args = instruction.get_start(); auto& args = instruction.get_start();
int output_h = args[0], output_w = args[1]; vector<Conv2dTuple> tuples;
int inputs_h = args[2], inputs_w = args[3]; for (size_t i = 0; i < args.size(); i += 15)
int weights_h = args[4], weights_w = args[5]; tuples.push_back(Conv2dTuple(args, i));
int stride_h = args[6], stride_w = args[7]; for (auto& tuple : tuples)
int n_channels_in = args[8]; tuple.run_matrix(processor);
int padding_h = args[9];
int padding_w = args[10];
int batch_size = args[11];
size_t r0 = instruction.get_r(0);
size_t r1 = instruction.get_r(1);
int r2 = instruction.get_r(2);
int filter_stride_h = 1;
int filter_stride_w = 1;
if (stride_h < 0)
{
filter_stride_h = -stride_h;
stride_h = 1;
}
if (stride_w < 0)
{
filter_stride_w = -stride_w;
stride_w = 1;
} }
template<class T>
void Conv2dTuple::run_matrix(SubProcessor<T>& processor)
{
auto& S = processor.get_S(); auto& S = processor.get_S();
array<int, 3> dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}}); array<int, 3> dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}});
ShareMatrix<T> A(dim[0], dim[1]), B(dim[1], dim[2]); ShareMatrix<T> A(dim[0], dim[1]), B(dim[1], dim[2]);
if (not T::real_shares(processor.P)) if (not T::real_shares(processor.P))
{ {
matrix_multiply(A, B, processor); processor.protocol.matrix_multiply(A, B, processor);
return; return;
} }
@@ -208,7 +194,7 @@ void Hemi<T>::conv2ds(SubProcessor<T>& processor,
} }
} }
auto C = matrix_multiply(A, B, processor); auto C = processor.protocol.matrix_multiply(A, B, processor);
for (int i_batch = 0; i_batch < batch_size; i_batch ++) for (int i_batch = 0; i_batch < batch_size; i_batch ++)
{ {

View File

@@ -37,6 +37,8 @@ public:
if (swapped) if (swapped)
std::swap(this->n_rows, this->n_cols); std::swap(this->n_rows, this->n_cols);
assert(this->n_cols >= this->n_rows); assert(this->n_cols >= this->n_rows);
assert(prep.proc);
this->P = &prep.proc->P;
} }
void set_protocol(typename ShareMatrix<T>::Protocol&) void set_protocol(typename ShareMatrix<T>::Protocol&)

View File

@@ -21,11 +21,7 @@ void MaliciousShamirMC<T>::init_open(const Player& P, int n)
reconstructions.resize(2 * threshold + 2); reconstructions.resize(2 * threshold + 2);
for (int i = threshold + 1; i <= 2 * threshold + 1; i++) for (int i = threshold + 1; i <= 2 * threshold + 1; i++)
{ {
reconstructions[i].resize(i); reconstructions[i] = ShamirMC<T>::get_reconstruction(P, i);
for (int j = 0; j < i; j++)
reconstructions[i][j] =
Shamir<T>::get_rec_factor(P.get_player(j),
P.num_players(), P.my_num(), i);
} }
} }

Some files were not shown because too many files have changed in this diff Show More