Semi-honest computation based on threshold semi-homomorphic encryption.

This commit is contained in:
Marcel Keller
2022-02-17 13:21:19 +11:00
parent 61d40b7d83
commit 0f7020d791
129 changed files with 1973 additions and 539 deletions

View File

@@ -1,6 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.9 (Jan 11, 2021)
## 0.3.0 (Feb 17, 2022)
- Semi-honest computation based on threshold semi-homomorphic encryption
- Batch normalization backward propagation
- AlexNet for CIFAR-10
- Specific private output protocols
- Semi-honest additive secret sharing without communication
- Sending of personal values
- Allow overwriting of persistence files
- Protocol signature in persistence files
## 0.2.9 (Jan 11, 2022)
- Disassembler
- Run-time parameter for probabilistic truncation error

1
CONFIG
View File

@@ -42,6 +42,7 @@ else
AVX_OT = 1
endif
else
ARCH =
AVX_OT = 0
endif

View File

@@ -497,7 +497,7 @@ class movsb(NonVectorInstruction):
code = opcodes['MOVSB']
arg_format = ['sbw','sb']
class trans(base.VarArgsInstruction):
class trans(base.VarArgsInstruction, base.DynFormatInstruction):
""" Secret bit register vector transpose. The first destination vector
will contain the least significant bits of all source vectors etc.
@@ -511,10 +511,22 @@ class trans(base.VarArgsInstruction):
code = opcodes['TRANS']
is_vec = lambda self: True
def __init__(self, *args):
self.arg_format = ['int'] + ['sbw'] * args[0] + \
['sb'] * (len(args) - 1 - args[0])
super(trans, self).__init__(*args)
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
n = next(args)
for i in range(n):
yield 'sbw'
next(args)
while True:
try:
yield 'sb'
next(args)
except StopIteration:
break
class bitb(NonVectorInstruction):
""" Copy fresh secret random bit to secret bit register.
@@ -560,7 +572,7 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1])
class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
base.Mergeable):
base.Mergeable, base.DynFormatInstruction):
""" Copy private input to secret bit registers bit by bit. The input is
read as floating-point number, multiplied by a power of two, rounded to an
integer, and then decomposed into bits.
@@ -577,11 +589,18 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
code = opcodes['INPUTBVEC']
def __init__(self, *args, **kwargs):
self.arg_format = []
for x in self.get_arg_tuples(args):
self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3)
super(inputbvec, self).__init__(*args, **kwargs)
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
for i, n in cls.bases(args):
yield 'int'
yield 'p'
for j in range(n - 3):
yield 'sbw'
yield 'int'
@staticmethod
def get_arg_tuples(args):
i = 0
@@ -590,10 +609,6 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
i += args[i]
assert i == len(args)
def merge(self, other):
self.args += other.args
self.arg_format += other.arg_format
def add_usage(self, req_node):
for x in self.get_arg_tuples(self.args):
req_node.increment(('bit', 'input', x[2]), x[0] - 3)

View File

@@ -41,7 +41,7 @@ class bits(Tape.Register, _structure, _bit):
return cls.types[length]
@classmethod
def conv(cls, other):
if isinstance(other, cls):
if isinstance(other, cls) and cls.n == other.n:
return other
elif isinstance(other, MemValue):
return cls.conv(other.read())
@@ -246,14 +246,20 @@ class cbits(bits):
assert n == res.n
assert n == other.size
cls.conv_cint_vec(cint(other, size=other.size), res)
@classmethod
def conv(cls, other):
if isinstance(other, cbits) and cls.n != None and \
cls.n // cls.unit == other.n // cls.unit:
return other
else:
return super(cbits, cls).conv(other)
types = {}
def load_int(self, value):
if self.n <= 64:
tmp = regint(value)
elif value == self.long_one():
tmp = cint(1, size=self.n)
else:
raise CompilerError('loading long integers to cbits not supported')
n_limbs = math.ceil(self.n / self.unit)
tmp = regint(size=n_limbs)
for i in range(n_limbs):
tmp[i].load_int(value % 2 ** self.unit)
value >>= self.unit
self.load_other(tmp)
def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address))
@@ -1163,14 +1169,14 @@ class cbitfix(object):
@classmethod
def _new(cls, value):
res = cls()
if cls.k < value.unit:
bits = value.bit_decompose(cls.k)
sign = bits[-1]
value += (sign << (cls.k)) * -1
res.v = value
return res
def output(self):
v = self.v
if self.k < v.unit:
bits = self.v.bit_decompose(self.k)
sign = bits[-1]
v += (sign << (self.k)) * -1
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))

View File

@@ -403,6 +403,20 @@ class Merger:
add_edge(last_input[t][1], n)
last_input[t][0] = n
def keep_text_order(inst, n):
if inst.get_players() is None:
# switch
for x in list(last_input.keys()):
if isinstance(x, int):
add_edge(last_input[x][0], n)
del last_input[x]
keep_merged_order(instr, n, None)
elif last_input[None][0] is not None:
keep_merged_order(instr, n, None)
else:
for player in inst.get_players():
keep_merged_order(instr, n, player)
for n,instr in enumerate(block.instructions):
outputs,inputs = instr.get_def(), instr.get_used()
@@ -427,7 +441,7 @@ class Merger:
# will be merged
if isinstance(instr, TextInputInstruction):
keep_merged_order(instr, n, TextInputInstruction)
keep_text_order(instr, n)
elif isinstance(instr, RawInputInstruction):
keep_merged_order(instr, n, RawInputInstruction)
@@ -479,10 +493,6 @@ class Merger:
last_print_str = n
elif isinstance(instr, PublicFileIOInstruction):
keep_order(instr, n, instr.__class__)
elif isinstance(instr, startprivateoutput_class):
keep_order(instr, n, startprivateoutput_class, 2)
elif isinstance(instr, stopprivateoutput_class):
keep_order(instr, n, stopprivateoutput_class, 2)
elif isinstance(instr, prep_class):
keep_order(instr, n, instr.args[0])
elif isinstance(instr, StackInstruction):

View File

@@ -421,6 +421,10 @@ class use_matmul(base.Instruction):
code = base.opcodes['USE_MATMUL']
arg_format = ['int','int','int','int']
@classmethod
def get_usage(cls, args):
return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i}
class run_tape(base.Instruction):
""" Start tape/bytecode file in another thread.
@@ -1229,15 +1233,20 @@ class inverse(base.DataInstruction):
@base.gf2n
@base.vectorize
class inputmask(base.Instruction):
r""" Load secret $s_i$ with the next input mask for player $p$ and
write the mask on player $p$'s private output. """
""" Store fresh random input mask(s) in secret register (vector) and clear
register (vector) of the relevant player.
:param: mask (sint)
:param: mask (cint, player only)
:param: player (int)
"""
__slots__ = []
code = base.opcodes['INPUTMASK']
arg_format = ['sw', 'p']
arg_format = ['sw', 'cw', 'p']
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', self.args[1]), \
req_node.increment((self.field_type, 'input', self.args[2]), \
self.get_size())
@base.vectorize
@@ -1293,10 +1302,8 @@ class asm_input(base.TextInputInstruction):
arg_format = tools.cycle(['sw', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[1::2]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())
def get_players(self):
return self.args[1::2]
@base.vectorize
class inputfix(base.TextInputInstruction):
@@ -1305,10 +1312,8 @@ class inputfix(base.TextInputInstruction):
arg_format = tools.cycle(['sw', 'int', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[2::3]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())
def get_players(self):
return self.args[2::3]
@base.vectorize
class inputfloat(base.TextInputInstruction):
@@ -1322,7 +1327,7 @@ class inputfloat(base.TextInputInstruction):
req_node.increment((self.field_type, 'input', player), \
4 * self.get_size())
class inputmixed_base(base.TextInputInstruction):
class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction):
__slots__ = []
field_type = 'modp'
# the following has to match TYPE: (N_DEST, N_PARAM)
@@ -1341,22 +1346,30 @@ class inputmixed_base(base.TextInputInstruction):
type_id = self.type_ids[name]
super(inputmixed_base, self).__init__(type_id, *args)
@property
def arg_format(self):
for i in self.bases():
t = self.args[i]
yield 'int'
@classmethod
def dynamic_arg_format(self, args):
yield 'int'
for i, t in self.bases(iter(args)):
for j in range(self.types[t][0]):
yield 'sw'
for j in range(self.types[t][1]):
yield 'int'
yield self.player_arg_type
yield 'int'
def bases(self):
@classmethod
def bases(self, args):
i = 0
while i < len(self.args):
yield i
i += sum(self.types[self.args[i]]) + 2
while True:
try:
t = next(args)
except StopIteration:
return
yield i, t
n = sum(self.types[t])
i += n + 2
for j in range(n + 1):
next(args)
@base.vectorize
class inputmixed(inputmixed_base):
@@ -1380,13 +1393,16 @@ class inputmixed(inputmixed_base):
player_arg_type = 'p'
def add_usage(self, req_node):
for i in self.bases():
t = self.args[i]
for i, t in self.bases(iter(self.args)):
player = self.args[i + sum(self.types[t]) + 1]
n_dest = self.types[t][0]
req_node.increment((self.field_type, 'input', player), \
n_dest * self.get_size())
def get_players(self):
for i, t in self.bases(iter(self.args)):
yield self.args[i + sum(self.types[t]) + 1]
@base.vectorize
class inputmixedreg(inputmixed_base):
""" Store private input in secret registers (vectors). The input is
@@ -1412,6 +1428,9 @@ class inputmixedreg(inputmixed_base):
# player 0 as proxy
req_node.increment((self.field_type, 'input', 0), float('inf'))
def get_players(self):
pass
@base.gf2n
@base.vectorize
class rawinput(base.RawInputInstruction, base.Mergeable):
@@ -1433,7 +1452,23 @@ class rawinput(base.RawInputInstruction, base.Mergeable):
req_node.increment((self.field_type, 'input', player), \
self.get_size())
class inputpersonal(base.Instruction, base.Mergeable):
class personal_base(base.Instruction, base.Mergeable):
__slots__ = []
field_type = 'modp'
def __init__(self, *args):
super(personal_base, self).__init__(*args)
for i in range(0, len(args), 4):
assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i]
def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])
class inputpersonal(personal_base):
""" Private input from cint.
:param: vector size (int)
@@ -1445,19 +1480,39 @@ class inputpersonal(base.Instruction, base.Mergeable):
__slots__ = []
code = base.opcodes['INPUTPERSONAL']
arg_format = tools.cycle(['int','p','sw','c'])
field_type = 'modp'
class privateoutput(personal_base):
""" Private input from cint.
:param: vector size (int)
:param: player (int)
:param: destination (cint)
:param: source (sint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['PRIVATEOUTPUT']
arg_format = tools.cycle(['int','p','cw','s'])
class sendpersonal(base.Instruction, base.Mergeable):
""" Private input from cint.
:param: vector size (int)
:param: destination player (int)
:param: destination (cint)
:param: source player (int)
:param: source (cint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['SENDPERSONAL']
arg_format = tools.cycle(['int','p','cw','p','c'])
def __init__(self, *args):
super(inputpersonal, self).__init__(*args)
for i in range(0, len(args), 4):
super(sendpersonal, self).__init__(*args)
for i in range(0, len(args), 5):
assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i]
def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])
assert args[i + 4].size == args[i]
@base.gf2n
@base.vectorize
@@ -1789,27 +1844,6 @@ class floatoutput(base.PublicFileIOInstruction):
code = base.opcodes['FLOATOUTPUT']
arg_format = ['p','c','c','c','c']
@base.gf2n
@base.vectorize
class startprivateoutput(base.Instruction):
r""" Initiate private output to $n$ of $s_j$ via $s_i$. """
__slots__ = []
code = base.opcodes['STARTPRIVATEOUTPUT']
arg_format = ['sw','s','p']
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', self.args[2]), \
self.get_size())
@base.gf2n
@base.vectorize
class stopprivateoutput(base.Instruction):
r""" Previously iniated private output to $n$ via $c_i$. """
__slots__ = []
code = base.opcodes['STOPPRIVATEOUTPUT']
arg_format = ['cw','c','p']
@base.vectorize
class rand(base.Instruction):
""" Store insecure random value of specified length in clear integer
@@ -2210,7 +2244,8 @@ class mulrs(base.VarArgsInstruction, base.DataInstruction):
@base.gf2n
@base.vectorize
class dotprods(base.VarArgsInstruction, base.DataInstruction):
class dotprods(base.VarArgsInstruction, base.DataInstruction,
base.DynFormatInstruction):
""" Dot product of secret registers (vectors).
Note that the vectorized version works element-wise.
@@ -2238,31 +2273,29 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
flat_args += [x, y]
base.Instruction.__init__(self, *flat_args)
@property
def arg_format(self):
@classmethod
def dynamic_arg_format(self, args):
field = 'g' if self.is_gf2n() else ''
for i in self.bases():
yield 'int'
yield 'int'
for i, n in self.bases(args):
yield 's' + field + 'w'
for j in range(self.args[i] - 2):
for j in range(n - 2):
yield 's' + field
yield 'int'
gf2n_arg_format = arg_format
def bases(self):
i = 0
while i < len(self.args):
yield i
i += self.args[i]
@property
def gf2n_arg_format(self):
return self.arg_format()
def get_repeat(self):
return sum(self.args[i] // 2 for i in self.bases()) * self.get_size()
return sum(self.args[i] // 2
for i, n in self.bases(iter(self.args))) * self.get_size()
def get_def(self):
return [self.args[i + 1] for i in self.bases()]
return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
def get_used(self):
for i in self.bases():
for i, n in self.bases(iter(self.args)):
for reg in self.args[i + 2:i + self.args[i]]:
yield reg

View File

@@ -105,6 +105,7 @@ opcodes = dict(
MATMULSM = 0xAB,
CONV2DS = 0xAC,
CHECK = 0xAF,
PRIVATEOUTPUT = 0xAD,
# Data access
TRIPLE = 0x50,
BIT = 0x51,
@@ -128,6 +129,7 @@ opcodes = dict(
INPUTMIXEDREG = 0xF3,
RAWINPUT = 0xF4,
INPUTPERSONAL = 0xF5,
SENDPERSONAL = 0xF6,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
@@ -364,6 +366,7 @@ def gf2n(instruction):
arg_format = copy.deepcopy(instruction_cls.arg_format)
reformat(arg_format)
@classmethod
def is_gf2n(self):
return True
@@ -505,8 +508,12 @@ def cisc(function):
for arg in self.args:
try:
new_regs.append(type(arg)(size=size))
except:
except TypeError:
break
except:
print([call[0][0].size for call in self.calls])
raise
assert len(new_regs) > 1
base = 0
for call in self.calls:
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
@@ -854,6 +861,7 @@ class Instruction(object):
def is_vec(self):
return False
@classmethod
def is_gf2n(self):
return False
@@ -902,6 +910,10 @@ class Instruction(object):
new_args.append(arg)
return new_args
@staticmethod
def get_usage(args):
return {}
# String version of instruction attempting to replicate encoded version
def __str__(self):
@@ -949,9 +961,18 @@ class ParsedInstruction:
if name == 'cisc':
arg_format = itertools.chain(['str'], itertools.repeat('int'))
else:
arg_format = itertools.repeat('int')
self.args = [ArgFormats[next(arg_format)](f)
for i in range(n_args)]
def arg_iter():
i = 0
while True:
try:
yield self.args[i].i
except AttributeError:
yield None
i += 1
arg_format = t.dynamic_arg_format(arg_iter())
self.args = []
for i in range(n_args):
self.args.append(ArgFormats[next(arg_format)](f))
def __str__(self):
name = self.type.__name__
@@ -963,6 +984,9 @@ class ParsedInstruction:
res += ', '.join(str(arg) for arg in self.args)
return res
def get_usage(self):
return self.type.get_usage(self.args)
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
@@ -974,6 +998,26 @@ class VectorInstruction(Instruction):
def get_code(self):
return super(VectorInstruction, self).get_code(len(self.args[0]))
class DynFormatInstruction(Instruction):
__slots__ = []
@property
def arg_format(self):
return self.dynamic_arg_format(iter(self.args))
@classmethod
def bases(self, args):
i = 0
while True:
try:
n = next(args)
except StopIteration:
return
yield i, n
i += n
for j in range(n - 1):
next(args)
###
### Basic arithmetic
###
@@ -1072,6 +1116,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
""" Input from text file or stdin """
__slots__ = []
def add_usage(self, req_node):
for player in self.get_players():
req_node.increment((self.field_type, 'input', player), \
self.get_size())
###
### Data access instructions
###

View File

@@ -223,7 +223,7 @@ def crash(condition=None):
if isinstance(condition, localint):
# allow crash on local values
condition = condition._v
if condition == None:
if condition is None:
condition = regint(1)
instructions.crash(regint.conv(condition))
@@ -284,8 +284,8 @@ def get_arg():
def make_array(l):
if isinstance(l, program.Tape.Register):
res = Array(1, type(l))
res[0] = l
res = Array(len(l), type(l))
res[:] = l
else:
l = list(l)
res = Array(len(l), type(l[0]) if l else cint)
@@ -1032,6 +1032,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
state = tuplify(initializer())
k = 0
block = get_block()
assert not isinstance(n_loops, int) or n_loops > 0
pre = copy.copy(loop_body.__globals__)
while (not util.is_constant(n_loops) or k < n_loops) \
and (len(get_block()) < budget or k == 0) \
@@ -1211,7 +1212,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if t != regint:
raise CompilerError('Not implemented for other than regint')
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
state = tuple(initializer())
state = initializer()
if len(state) == 0:
state_type = cint
elif isinstance(state, (tuple, list)):
state_type = type(state[0])
else:
state_type = type(state)
def f(inc):
base = args[get_arg()][0]
if not util.is_constant(thread_rounds):
@@ -1224,8 +1231,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if thread_mem_req:
thread_mem = Array(thread_mem_req[regint], regint, \
args[get_arg()].address + 2)
mem_state = Array(len(state), type(state[0]) \
if state else cint, args[get_arg()][1])
mem_state = Array(len(state), state_type, args[get_arg()][1])
@map_reduce_single(n_parallel, thread_rounds + inc, \
initializer, reducer, mem_state)
def f(i):
@@ -1257,14 +1263,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
threads = prog.run_tapes(thread_args)
for thread in threads:
prog.join_tape(thread)
if state:
if len(state):
if thread_rounds:
for i in range(n_threads - remainder):
state = reducer(Array(len(state), type(state[0]), \
state = reducer(Array(len(state), state_type, \
args[remainder + i][1]), state)
if remainder:
for i in range(remainder):
state = reducer(Array(len(state), type(state[0]).reg_type, \
state = reducer(Array(len(state), state_type, \
args[i][1]), state)
def returner():
return untuplify(state)
@@ -1300,6 +1306,39 @@ def map_sum_opt(n_threads, n_loops, types):
"""
return map_sum(n_threads, None, n_loops, len(types), types)
def map_sum_simple(n_threads, n_loops, type, size):
""" Vectorized multi-threaded sum reduction. The following computes a
100 sums of ten squares in three threads::
@map_sum_simple(3, 10, sint, 100)
def summer(i):
return sint(regint.inc(100, i, 0)) ** 2
result = summer()
:param n_threads: number of threads (int)
:param n_loops: number of loop runs (regint/cint/int)
:param type: return type, must match the return statement
in the loop
:param size: vector size, must match the return statement
in the loop
"""
initializer = lambda: type(0, size=size)
def summer(*args):
assert len(args) == 2
args = list(args)
for i in (0, 1):
if isinstance(args[i], tuple):
assert len(args[i]) == 1
args[i] = args[i][0]
for i in (0, 1):
assert len(args[i]) == size
if isinstance(args[i], Array):
args[i] = args[i][:]
return args[0] + args[1]
return map_reduce(n_threads, 1, n_loops, initializer, summer)
def tree_reduce_multithread(n_threads, function, vector):
inputs = vector.Array(len(vector))
inputs.assign_vector(vector)

View File

@@ -223,6 +223,7 @@ class Layer:
thetas = lambda self: ()
debug_output = False
back_batch_size = 128
print_random_update = False
@property
def shape(self):
@@ -254,6 +255,9 @@ class Layer:
def __str__(self):
return type(self).__name__ + str(self._Y.sizes)
def __repr__(self):
return '%s(%s)' % (type(self).__name__, self.Y.sizes)
class NoVariableLayer(Layer):
input_from = lambda *args, **kwargs: None
output_weights = lambda *args: None
@@ -459,6 +463,10 @@ class MultiOutput(MultiOutputBase):
self.debug = debug
self.true_X = sfix.Array(N)
def __repr__(self):
return '%s(%s, %s, approx=%s)' % \
(type(self).__name__, self.N, self.d_out, self.approx)
def _forward(self, batch):
N = len(batch)
d_out = self.X.sizes[1]
@@ -609,10 +617,11 @@ class DenseBase(Layer):
N = len(batch)
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
@multithread(self.n_threads, self.d_in)
def _(base, size):
A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
mp = B.direct_trans_mul(A, reduce=False,
indices=(regint.inc(size, base),
batch.get_vector(),
@@ -622,16 +631,24 @@ class DenseBase(Layer):
progress('nabla W (matmul)')
if self.d_in * self.d_out < 200000:
print('reduce at once')
@multithread(self.n_threads, self.d_in * self.d_out)
def _(base, size):
self.nabla_W.assign_vector(
tmp.get_vector(base, size).reduce_after_mul(), base=base)
else:
@for_range_opt(self.d_in)
def _(i):
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()
@multithread(self.n_threads, self.d_in * self.d_out,
max_size=get_program().budget)
def _(base, size):
self.nabla_W.assign_vector(
tmp.get_vector(base, size).reduce_after_mul(), base=base)
if self.print_random_update:
print_ln('backward %s', self)
i = regint.get_random(64) % self.d_in
j = regint.get_random(64) % self.d_out
print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s',
str(self.nabla_W), i, j, tmp[i][j].v.reveal(),
self.nabla_W[i][j].reveal(),
A.get_column(j).reveal(),
B.get_column_by_row_indices(
batch.get_vector(), i).reveal())
print_ln('batch=%s B=%s', batch,
[self.X[bi][0][i].reveal() for bi in batch])
progress('nabla W')
@@ -699,6 +716,7 @@ class Dense(DenseBase):
self.d_in = d_in
self.d_out = d_out
self.d = d
self.activation = activation
self.X = MultiArray([N, d, d_in], sfix)
self.Y = MultiArray([N, d, d_out], sfix)
@@ -721,12 +739,17 @@ class Dense(DenseBase):
else:
self.f_input = self.Y
def __repr__(self):
return '%s(%s, %s, %s, activation=%s)' % \
(type(self).__name__, self.N, self.d_in,
self.d_out, repr(self.activation))
def reset(self):
d_in = self.d_in
d_out = self.d_out
r = math.sqrt(6.0 / (d_in + d_out))
print('Initializing dense weights in [%f,%f]' % (-r, r))
self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size()))
self.W.randomize(-r, r)
self.b.assign_all(0)
def input_from(self, player, raw=False):
@@ -820,6 +843,12 @@ class Dense(DenseBase):
regint.inc(self.d_in))),
base)
if self.print_random_update:
print_ln('backward %s', self)
index = regint.get_random(64) % self.nabla_X.total_size()
print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
index, self.nabla_X.to_array()[index].reveal())
progress('nabla X')
self.backward_params(f_schur_Y, batch=batch)
@@ -890,6 +919,10 @@ class Dropout(NoVariableLayer):
self.alpha = alpha
self.B = MultiArray([N, d1, d2], sint)
def __repr__(self):
return '%s(%s, %s, alpha=%s)' % \
(type(self).__name__, self.N, self.d1, self.alpha)
def forward(self, batch, training=False):
if training:
n_bits = -math.log(self.alpha, 2)
@@ -1022,6 +1055,7 @@ class MaxPool(NoVariableLayer):
def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1),
padding='VALID'):
assert len(shape) == 4
assert min(shape) > 0, shape
for x in strides, ksize:
for i in 0, 3:
assert x[i] == 1
@@ -1033,12 +1067,18 @@ class MaxPool(NoVariableLayer):
self.Y = Tensor(output_shape, sfix)
self.strides = strides
self.ksize = ksize
self.padding = padding
self.nabla_X = Tensor(shape, sfix)
self.nabla_Y = Tensor(output_shape, sfix)
self.N = shape[0]
self.comparisons = MultiArray([self.N, self.X.sizes[3],
ksize[1] * ksize[2]], sint)
def __repr__(self):
return '%s(%s, strides=%s, ksize=%s, padding=%s)' % \
(type(self).__name__, self.X.sizes, self.strides,
self.ksize, self.padding)
def _forward(self, batch):
def process(pool, bi, k, i, j):
def m(a, b):
@@ -1165,7 +1205,7 @@ class Add(NoVariableLayer):
self.Y[batch[0]].assign_vector(tmp, base)
class FusedBatchNorm(Layer):
""" Fixed-point fused batch normalization layer.
""" Fixed-point fused batch normalization layer (inference only).
:param shape: input/output shape (tuple/list of four int)
"""
@@ -1192,6 +1232,153 @@ class FusedBatchNorm(Layer):
self.X[batch[0]][i][j].get_vector() * self.weights.get_vector()
+ self.bias.get_vector())
class BatchNorm(Layer):
""" Fixed-point batch normalization layer.
:param shape: input/output shape (tuple/list of four int)
:param approx: use approximate square root
"""
thetas = lambda self: (self.weights, self.bias)
nablas = lambda self: (self.nabla_weights, self.nabla_bias)
def __init__(self, shape, approx=True, args=None):
assert len(shape) in (2, 3, 4)
if len(shape) == 4:
shape = [shape[0], shape[1] * shape[2], shape[3]]
elif len(shape) == 2:
shape = [shape[0], 1, shape[1]]
tensors = (Tensor(shape, sfix) for i in range(4))
self.X, self.Y, self.nabla_X, self.nabla_Y = tensors
arrays = (sfix.Array(shape[2]) for i in range(4))
self.var, self.mu, self.weights, self.bias = arrays
arrays = (sfix.Array(shape[2]) for i in range(4))
self.mu_hat, self.var_hat, self.nabla_weights, self.nabla_bias = arrays
self.epsilon = 2 ** (-sfix.f + 1)
self.momentum = 0.1
if args != None:
approx = 'precisebn' not in args
self.approx = approx
if approx:
print('Approximate square root inverse in batch normalization')
self.InvertSqrt = mpc_math.InvertSqrt
else:
print('Precise square root inverse in batch normalization')
self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x)
def __repr__(self):
return '%s(%s, approx=%s)' % \
(type(self).__name__, self.X.sizes, self.approx)
def reset(self):
self.bias.assign_all(0)
self.weights.assign_all(1)
self.mu_hat.assign_all(0)
self.var_hat.assign_all(0)
def _output(self, batch, mu, var):
factor = sfix.Array(len(mu))
factor[:] = self.InvertSqrt(var[:] + self.epsilon)
@for_range_opt_multithread(self.n_threads,
[len(batch), self.X.sizes[1]])
def _(i, j):
tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:]
self.Y[i][j][:] = self.bias[:] + tmp
def forward(self, batch, training=False):
if training:
d = self.X.sizes[1]
d_in = self.X.sizes[2]
s = sfix.Array(d_in)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (self.X[batch[i]][j].get_vector())
s.assign(_())
@multithread(self.n_threads, d_in)
def _(base, size):
self.mu.assign_vector(
s.get_vector(base, size) / (len(batch) * d), base)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
item = self.X[batch[i]][j].get_vector()
return ((item - self.mu[:]) ** 2)
self.var.assign(_())
@multithread(self.n_threads, d_in)
def _(base, size):
self.var.assign_vector(
self.var.get_vector(base, size) / (len(batch) * d - 1),
base)
for x, y, in (self.mu_hat, self.mu), (self.var_hat, self.var):
x[:] = self.momentum * y[:] + (1 - self.momentum) * x[:]
self._output(batch, self.mu, self.var)
if self.print_random_update:
i = regint.get_random(64) % len(batch)
j = regint.get_random(64) % d
k = regint.get_random(64) % d_in
for x in self.mu, self.var:
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
print_ln('%s at (%s, %s, %s): in=%s out=%s',
str(self.Y), i, j, k, self.X[i][j][k].reveal(),
self.Y[i][j][k].reveal())
else:
self._output(batch, self.mu_hat, self.var_hat)
def backward(self, batch, compute_nabla_X=True):
factor = Array.create_from(
self.InvertSqrt(self.var[:] + self.epsilon))
mynYf = self.X.same_shape()
gamnY = self.X.same_shape()
gamnYd = self.X.same_shape()
nYdf = self.X.same_shape()
d = self.X.sizes[1]
d_in = self.X.sizes[2]
@for_range_opt_multithread(self.n_threads, [len(batch), d])
def _(i, j):
tmp = self.weights[:] * self.nabla_Y[i][j][:]
gamnY[i][j] = tmp
gamnYd[i][j] = tmp * (self.X[i][j][:] - self.mu[:])
mynYf[i][j] = tmp * factor[:]
nYdf[i][j] = self.nabla_Y[i][j][:] * \
(self.X[i][j][:] - self.mu[:]) * factor[:]
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (self.nabla_Y[i][j][:])
self.nabla_bias.assign(_())
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (nYdf[i][j])
self.nabla_weights.assign(_())
factor3 = Array.create_from(factor[:] ** 3)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (mynYf[i][j])
s1 = Array.create_from(_())
@multithread(self.n_threads, len(s1))
def _(base, size):
s1.assign_vector(s1.get_vector(base, size) / (len(batch) * d), base)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (gamnYd[i][j][:] * factor3[:])
s2 = Array.create_from(_())
@multithread(self.n_threads, len(s2))
def _(base, size):
s2.assign_vector(
s2.get_vector(base, size) / (len(batch) * d - 1), base)
@for_range_opt_multithread(self.n_threads, [len(batch), d])
def _(i, j):
self.nabla_X[i][j][:] = mynYf[i][j][:] \
- s1[:] - (self.X[i][j][:] - self.mu[:]) * s2[:]
if self.print_random_update:
print_ln('backward %s', self)
i = regint.get_random(64) % len(batch)
j = regint.get_random(64) % d
k = regint.get_random(64) % d_in
for x in self.nabla_bias, self.nabla_weights:
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k,
self.nabla_Y[i][j][k].reveal(),
self.nabla_X[i][j][k].reveal())
class QuantBase(object):
bias_before_reduction = True
@@ -1298,6 +1485,8 @@ class ConvBase(BaseLayer):
self.padding.append(pad_total // 2)
elif padding == 'VALID':
self.padding = [0, 0]
elif isinstance(padding, int):
self.padding = [padding, padding]
else:
self.padding = padding
@@ -1323,6 +1512,12 @@ class ConvBase(BaseLayer):
assert(len(output_shape) == 4)
assert(len(weight_shape) == 4)
def __repr__(self):
return '%s(%s, %s, %s, %s, %s, padding=%s, tf_weight_format=%s)' % \
(type(self).__name__, self.X.sizes, self.weight_shape,
self.bias_shape, self.Y.sizes, self.stride, repr(self.padding),
self.tf_weight_format)
def input_from(self, player, raw=False):
self.input_params_from(player)
self.weights.input_from(player, budget=100000, raw=raw)
@@ -1545,20 +1740,20 @@ class FixConv2d(Conv2d, FixBase):
self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i)
if compute_nabla_X:
assert tuple(self.padding) == (0, 0)
assert tuple(self.stride) == (1, 1)
reverse_weights = MultiArray(
[n_channels_in, weights_h, weights_w, n_channels_out], sfix)
@for_range(n_channels_out)
def _(i):
@for_range_opt_multithread(self.n_threads, n_channels_in)
def _(l):
@for_range(weights_h)
def _(j):
@for_range(weights_w)
def _(k):
@for_range(n_channels_in)
def _(l):
reverse_weights[l][weights_h-j-1][k][i] = \
self.weights[i][j][weights_w-k-1][l]
addresses = regint.inc(n_channels_out,
self.weights[0][j][weights_w-k-1].get_address(l),
reduce(operator.mul, self.weights.sizes[1:]))
reverse_weights[l][weights_h-j-1][k].assign_vector(
self.weights.value_type.load_mem(addresses))
padded_w = inputs_w + 2 * padding_w
padded_h = inputs_h + 2 * padding_h
if padding_h or padding_w:
@@ -1579,14 +1774,16 @@ class FixConv2d(Conv2d, FixBase):
unreduced_sfix._new(res).reduce_after_mul(),
i, None, None, j)
if padding_h or padding_w:
@for_range(N)
@for_range_opt_multithread(self.n_threads, N)
def _(i):
@for_range(inputs_h)
def _(j):
@for_range(inputs_w)
def _(k):
jj = j + padding_w
kk = k + padding_w
self.nabla_X[i][j][k].assign_vector(
output[i][j][k].get_vector())
output[i][jj][kk].get_vector())
if self.debug_output:
@for_range(len(batch))
@@ -1806,6 +2003,7 @@ class Optimizer:
self.report_loss = report_loss
self.X_by_label = None
self.print_update_average = False
self.print_random_update = False
self.print_losses = False
self.print_loss_reduction = False
self.i_epoch = MemValue(0)
@@ -1846,6 +2044,7 @@ class Optimizer:
def batch_for(self, layer, batch):
if layer in (self.layers[0], self.layers[-1]):
assert not isinstance(layer, BatchNorm)
return batch
else:
batch = regint.Array(len(batch))
@@ -1876,6 +2075,21 @@ class Optimizer:
if i != len(self.layers) - 1 or run_last:
layer.forward(batch=self.batch_for(layer, batch),
training=training)
if self.print_random_update:
print_ln('forward layer %s', layer)
l = min(100, layer.Y[i].total_size())
i = regint.get_random(64) % len(batch)
if l < 100:
j = 0
else:
j = regint.get_random(64) % \
(layer.Y[i].total_size() - l)
print_ln('forward layer %s at (%s, %s): %s', layer, i, j,
layer.Y[i].to_array().get_vector(j, l).reveal())
i = regint.get_random(64) % layer.Y[0].total_size()
print_ln('forward layer %s vertical at %s: %s', layer, i,
[layer.Y[j].to_array()[i].reveal()
for j in range(len(batch))])
if self.time_layers:
stop_timer(100 + i)
break_point()
@@ -1979,7 +2193,11 @@ class Optimizer:
label * n)
self.forward(batch=batch, training=True)
self.backward(batch=batch)
if self.time_layers:
start_timer(1000)
self.update(i, batch=batch)
if self.time_layers:
stop_timer(1000)
loss_sum.iadd(self.layers[-1].l)
if self.print_loss_reduction:
before = self.layers[-1].average_loss(N)
@@ -2070,6 +2288,8 @@ class Optimizer:
if 'nomom' in program.args:
self.momentum = 0
self.print_losses = 'print_losses' in program.args
self.print_random_update = 'print_random_update' in program.args
Layer.print_random_update = self.print_random_update
self.time_layers = 'time_layers' in program.args
self.revealing_correctness = not 'no_acc' in program.args
self.layers[-1].compute_loss = not 'no_loss' in program.args
@@ -2099,6 +2319,16 @@ class Optimizer:
print_ln('loss %s', self.layers[-1].l.reveal())
self.output_weights()
return
if 'bench10' in program.args or 'bench1' in program.args:
n = 1 if 'bench1' in program.args else 10
print('benchmarking %s iterations' % n)
@for_range(n)
def _(i):
batch = Array.create_from(regint.inc(batch_size))
self.forward(batch=batch, training=True)
self.backward(batch=batch)
self.update(0, batch=batch)
return
@for_range(n_runs)
def _(i):
if not acc_first:
@@ -2115,6 +2345,7 @@ class Optimizer:
cfix(self.n_correct, k=63, f=31) / n_trained,
self.n_correct, n_trained)
if test_X and test_Y:
print('use test set')
n_test = len(test_Y)
n_correct, loss = self.reveal_correctness(test_X, test_Y,
acc_batch_size)
@@ -2211,7 +2442,8 @@ class Adam(Optimizer):
util.max, abs_g.get_vector())
scale = MemValue(sfix._new(library.AppRcr(
max_g.v, max_g.k, max_g.f, simplex_flag=True)))
@multithread(self.n_threads, m.total_size())
@multithread(self.n_threads, m.total_size(),
max_size=get_program().budget)
def _(base, size):
m_part = m.get_vector(base, size)
v_part = v.get_vector(base, size)
@@ -2333,20 +2565,33 @@ class SGD(Optimizer):
print_ln_if((x > limit) + (x < -limit),
'theta epoch=%s %s index=%s %s',
i_epoch.read(), str(theta), i, x)
index = regint.get_random(64) % len(a)
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index,
aa[1][index], aa[0][index], aa[2][index])
if self.print_random_update:
print_ln('update')
l = min(100, nabla.total_size())
if l < 100:
index = 0
else:
index = regint.get_random(64) % (nabla.total_size() - l)
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta),
index, nabla.to_array().get_vector(index, l).reveal(),
delta_theta.to_array().get_vector(index, l).reveal(),
theta.to_array().get_vector(index, l).reveal())
self.gamma.imul(1 - 10 ** - 6)
def apply_padding(input_shape, kernel_size, strides, padding):
if isinstance(padding, int):
input_shape = [x + 2 * padding for x in input_shape]
padding = 'valid'
if padding == 'valid':
return (input_shape[0] - kernel_size[0] + 1) // strides[0], \
res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \
(input_shape[1] - kernel_size[1] + 1) // strides[1],
assert min(res) > 0, (input_shape, kernel_size, strides, padding)
return res
elif padding == 'same':
return (input_shape[1]) // strides[0], \
(input_shape[2]) // strides[1],
return (input_shape[0]) // strides[0], \
(input_shape[1]) // strides[1],
else:
raise Exception('invalid padding: ' + padding)
raise Exception('invalid padding: %s' % padding)
class keras:
class layers:
@@ -2354,7 +2599,7 @@ class keras:
Dense = lambda *args, **kwargs: ('dense', args, kwargs)
def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid',
activation=None):
activation=None, input_shape=None):
return 'conv2d', {'filters': filters, 'kernel_size': kernel_size,
'strides': strides, 'padding': padding,
'activation': activation}
@@ -2369,6 +2614,13 @@ class keras:
raise Exception('rate needs to be a power of two')
return 'dropout', rate
def Activation(activation):
assert(activation == 'relu')
return activation,
def BatchNormalization():
return 'batchnorm',
class optimizers:
SGD = lambda *args, **kwargs: ('sgd', args, kwargs)
Adam = lambda *args, **kwargs: ('adam', args, kwargs)
@@ -2383,12 +2635,25 @@ class keras:
def compile(self, optimizer):
self.optimizer = optimizer
def compile_by_args(self, program):
if 'adam' in program.args:
self.optimizer = 'adam', [], {}
elif 'amsgrad' in program.args:
self.optimizer = 'adam', [], {'amsgrad': True}
else:
self.optimizer = 'sgd', [], {}
@property
def trainable_variables(self):
if self.opt == None:
raise Exception('need to run build() or fit() first')
return list(self.opt.thetas)
def summary(self):
sizes = [var.total_size() for var in self.trainable_variables]
print(sizes)
print('Trainable params:', sum(sizes))
def build(self, input_shape, batch_size=128):
data_input_shape = input_shape
if self.opt != None and \
@@ -2415,12 +2680,11 @@ class keras:
if i == len(self.layers) - 1:
if layer[2].get('activation', 'softmax') in \
('softmax', 'sigmoid'):
del layer[2]['activation']
layer[2].pop('activation', None)
layers.append(Dense(N, n_units, layer[1][0],
**layer[2]))
input_shape = layers[-1].Y.sizes
elif name == 'conv2d':
if len(layers) != 0:
input_shape = layers[-1].Y.sizes
input_shape = list(input_shape) + \
[1] * (4 - len(input_shape))
print (layer[1])
@@ -2437,9 +2701,13 @@ class keras:
output_shape = [batch_size] + list(
apply_padding(input_shape[1:3], kernel_size,
strides, padding)) + [filters]
padding = padding.upper() if isinstance(padding, str) \
else padding
layers.append(FixConv2d(input_shape, weight_shape,
(filters,), output_shape,
strides, padding.upper()))
strides, padding))
input_shape = output_shape
print('conv output shape', output_shape)
elif name == 'maxpool':
pool_size = layer[1]['pool_size']
strides = layer[1]['strides']
@@ -2450,16 +2718,23 @@ class keras:
strides = (strides, strides)
if strides == None:
strides = pool_size
layers.append(MaxPool(layers[-1].Y.sizes,
layers.append(MaxPool(input_shape,
[1] + list(strides) + [1],
[1] + list(pool_size) + [1],
padding.upper()))
padding))
input_shape = layers[-1].Y.sizes
elif name == 'dropout':
layers.append(Dropout(batch_size, reduce(
operator.mul, layers[-1].Y.sizes[1:]),
alpha=layer[1]))
input_shape = layers[-1].Y.sizes
elif name == 'flatten':
pass
elif name == 'relu':
layers.append(Relu(layers[-1].Y.sizes))
elif name == 'batchnorm':
input_shape = layers[-1].Y.sizes
layers.append(BatchNorm(layers[-1].Y.sizes))
else:
raise Exception(layer[0] + ' not supported')
if layers[-1].d_out == 1:

View File

@@ -1493,6 +1493,7 @@ class PackedIndexStructure(object):
self.l[i] = [0] * self.elements_per_block
time()
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
print_ln('packed ORAM init done')
print('index initialized, size', size)
def translate_index(self, index):
""" Bit slicing *index* according parameters. Output is tuple

View File

@@ -580,10 +580,19 @@ class Program(object):
@staticmethod
def read_tapes(schedule):
m = re.search(r'([^/]*)\.mpc', schedule)
if m:
schedule = m.group(1)
if not os.path.exists(schedule):
schedule = 'Programs/Schedules/%s.sch' % schedule
lines = open(schedule).readlines()
try:
lines = open(schedule).readlines()
except FileNotFoundError:
print('%s not found, have you compiled the program?' % schedule,
file=sys.stderr)
sys.exit(1)
for tapename in lines[2].split(' '):
yield tapename.strip()

View File

@@ -1675,6 +1675,13 @@ class localint(Tape._no_truth):
__ne__ = lambda self, other: localint(self._v != other)
class personal(Tape._no_truth):
""" Value known to one player. Supports operations with public
values and personal values known to the same player. Can be used
with :py:func:`~Compiler.library.print_ln_to`.
:param player: player (int)
:param value: cleartext value (cint, cfix, cfloat) or array thereof
"""
def __init__(self, player, value):
assert value is not NotImplemented
assert not isinstance(value, _secret)
@@ -1685,8 +1692,24 @@ class personal(Tape._no_truth):
self._v = value
def binary_output(self):
""" Write binary output to
``Player-Data/Binary-Output-P<playerno>-<threadno>`` if
supported by underlying type. Player must be known at compile time."""
self._v.binary_output(self.player)
def reveal_to(self, player):
""" Pass personal value to another player. """
if isinstance(self._v, Array):
source = self._v[:]
else:
source = self._v
source = cint.conv(source)
res = cint(size=source.size)
sendpersonal(source.size, player, res, self.player, source)
if isinstance(self._v, Array):
res = Array.create_from(res)
return personal(player, res)
def bit_decompose(self, length):
return [personal(self.player, x) for x in self._v.bit_decompose(length)]
@@ -1858,8 +1881,13 @@ class _secret(_register, _secret_structure):
@vectorized_classmethod
@set_instruction_type
def get_random_input_mask_for(cls, player):
res = cls()
inputmask(res, player)
""" Secret random input mask according to security model.
:return: mask (sint), mask (personal cint)
:param size: vector size (int, default 1)
"""
res = cls(), personal(player, cls.clear_type())
inputmask(res[0], res[1]._v, player)
return res
@classmethod
@@ -2071,15 +2099,13 @@ class _secret(_register, _secret_structure):
@set_instruction_type
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
Result written to ``Player-Data/Private-Output-P<player>``
:param player: int
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
:returns: :py:class:`personal`
"""
masked = self.__class__()
res = personal(player, self.clear_type())
startprivateoutput(masked, self, player)
stopprivateoutput(res._v, masked.reveal(), player)
mask = self.get_random_input_mask_for(player)
masked = self + mask[0]
res = personal(player, masked.reveal() - mask[1])
return res
@@ -2633,21 +2659,20 @@ class sint(_secret, _int):
@vectorize
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
Result potentially written to
``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint):
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
:param player: public integer (int/regint/cint)
:returns: :py:class:`personal`
"""
if not util.is_constant(player) or self.size > 1:
if not util.is_constant(player):
secret_mask = sint()
player_mask = cint()
inputmaskreg(secret_mask, player_mask, regint.conv(player))
return personal(player,
(self + secret_mask).reveal() - player_mask)
else:
return super(sint, self).reveal_to(player)
res = personal(player, self.clear_type())
privateoutput(self.size, player, res._v, self)
return res
def private_division(self, divisor, active=True, dividend_length=None,
divisor_length=None):
@@ -4366,12 +4391,9 @@ class sfix(_fix):
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
Raw representation possibly written to
``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint)
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
:returns: :py:class:`personal`
"""
return personal(player, cfix._new(self.v.reveal_to(player)._v,
self.k, self.f))
@@ -5221,6 +5243,9 @@ class Array(_vectorizable):
return self.assign(value, addresses)
self._store(value, self.get_address(index))
def to_array(self):
return self
def get_sub(self, start, stop=None):
if stop is None:
stop = start
@@ -5471,6 +5496,10 @@ class Array(_vectorizable):
""" Insecure shuffle in place. """
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
def randomize(self, *args):
""" Randomize according to data type. """
self.assign_vector(self.value_type.get_random(*args, size=len(self)))
def reveal(self):
""" Reveal the whole array.
@@ -5596,6 +5625,9 @@ class SubMultiArray(_vectorizable):
def __iter__(self):
return (self[i] for i in range(len(self)))
def to_array(self):
return Array(self.total_size(), self.value_type, address=self.address)
def assign_all(self, value):
""" Assign the same value to all entries.
@@ -5958,6 +5990,7 @@ class SubMultiArray(_vectorizable):
"""
assert len(self.sizes) == 2
assert len(other.sizes) == 2
assert other.address != None
if indices is None:
assert self.sizes[1] == other.sizes[1]
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
@@ -6145,6 +6178,16 @@ class SubMultiArray(_vectorizable):
n = self.sizes[0]
return self.array.get(regint.inc(n, 0, n + 1))
def randomize(self, *args):
""" Randomize according to data type. """
if self.total_size() < program.options.budget:
self.assign_vector(
self.value_type.get_random(*args, size=self.total_size()))
else:
@library.for_range(self.sizes[0])
def _(i):
self[i].randomize(*args)
def reveal_list(self):
""" Reveal as list. """
return list(self.get_vector().reveal())
@@ -6251,6 +6294,22 @@ class Matrix(MultiArray):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
address=address)
def get_column(self, index):
""" Get column as vector.
:param index: regint/cint/int
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.sizes[1])
return self.value_type.load_mem(addresses)
def get_column_by_row_indices(self, rows, column):
assert self.value_type.n_elements() == 1
addresses = rows * self.sizes[1] + \
regint.inc(len(rows), self.address + column, 0)
return self.value_type.load_mem(addresses)
def set_column(self, index, vector):
""" Change column.

View File

@@ -47,11 +47,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
}
void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
{
Rq_Element a(*this);
a.randomize(G);
partial_key_gen(sk, a, G, noise_boost);
}
void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
int noise_boost)
{
FHE_PK& PK = *this;
// Generate the main public key
PK.a0.randomize(G);
a0 = a;
// b0=a0*s+p*e0
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
@@ -77,9 +84,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
mul(es,es,PK.pr);
add(PK.Sw_b,PK.Sw_b,es);
// Lowering level as we only decrypt at level 0
sk.lower_level();
// bs=bs-p1*s^2
Rq_Element s2;
mul(s2,sk,sk); // Mult at level 0
@@ -334,7 +338,7 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
template<class FD>
void FHE_SK::check(const FHE_PK& pk, const FD& FieldD)
{
check(*params, pk, pr);
check(*params, pk, FieldD.get_prime());
pk.check_noise(*this);
if (decrypt(pk.encrypt(Plaintext_<FD>(FieldD)), FieldD) !=
Plaintext_<FD>(FieldD))

View File

@@ -150,6 +150,8 @@ class FHE_PK
Rq_Element sample_secret_key(PRNG& G);
void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1);
void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
int noise_boost = 1);
void check_noise(const FHE_SK& sk) const;
void check_noise(const Rq_Element& x, bool check_modulo = false) const;

View File

@@ -3,6 +3,11 @@
#include "FHE/Ring_Element.h"
#include "Tools/Exceptions.h"
FHE_Params::FHE_Params(int n_mults) :
FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1)
{
}
void FHE_Params::set(const Ring& R,
const vector<bigint>& primes)
{
@@ -24,6 +29,14 @@ void FHE_Params::set_sec(int sec)
throw runtime_error("distributed decryption bound is zero");
}
void FHE_Params::set_matrix_dim(int matrix_dim)
{
assert(matrix_dim > 0);
if (FFTData[0].get_prime() != 0)
throw runtime_error("cannot change matrix dimension after parameter generation");
this->matrix_dim = matrix_dim;
}
bigint FHE_Params::Q() const
{
bigint res = FFTData[0].get_prime();
@@ -40,6 +53,7 @@ void FHE_Params::pack(octetStream& o) const
Chi.pack(o);
Bval.pack(o);
o.store(sec_p);
o.store(matrix_dim);
}
void FHE_Params::unpack(octetStream& o)
@@ -52,6 +66,7 @@ void FHE_Params::unpack(octetStream& o)
Chi.unpack(o);
Bval.unpack(o);
o.get(sec_p);
o.get(matrix_dim);
}
bool FHE_Params::operator!=(const FHE_Params& other) const

View File

@@ -26,10 +26,11 @@ class FHE_Params
// Data for distributed decryption
int sec_p;
bigint Bval;
int matrix_dim;
public:
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {}
FHE_Params(int n_mults = 1);
int n_mults() const { return FFTData.size() - 1; }
@@ -37,6 +38,9 @@ class FHE_Params
void set(const vector<bigint>& primes);
void set_sec(int sec);
void set_matrix_dim(int matrix_dim);
int get_matrix_dim() const { return matrix_dim; }
const vector<FFT_Data>& FFTD() const { return FFTData; }
const bigint& p0() const { return FFTData[0].get_prime(); }

View File

@@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2)
template <>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, FFT_Data& FTD, bool round_up)
FHE_Params& params, FFT_Data& FTD, bool round_up, int n)
{
int m = 1024;
int lgp = plaintext_length;
@@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec,
while (true)
{
tmp_params = params;
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec,
SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
bigint p1 = 2 * p * m, p0 = p;
while (nb.min_p0(params.n_mults() > 0, p1) > p0)
@@ -89,14 +89,14 @@ int generate_semi_setup(int plaintext_length, int sec,
template <>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, P2Data& P2D, bool round_up)
FHE_Params& params, P2Data& P2D, bool round_up, int n)
{
if (params.n_mults() > 0)
throw runtime_error("only implemented for 0-level BGV");
gf2n_short::init_field(plaintext_length);
int m;
char_2_dimension(m, plaintext_length);
SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec,
SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params);
int lgp0 = numBits(nb.min_p0(false, 0));
int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up);
@@ -590,6 +590,9 @@ void char_2_dimension(int& m, int& lg2)
m=5797;
lg2=40;
break;
case 16:
m = 13107;
break;
default:
throw runtime_error("field size not supported");
break;

View File

@@ -52,7 +52,7 @@ void generate_setup(int nparties, int lgp, int lg2,
// semi-homomorphic, includes slack
template <class FD>
int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, FD& FieldD, bool round_up);
FHE_Params& params, FD& FieldD, bool round_up, int n = 1);
// field-independent semi-homomorphic setup
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,

View File

@@ -39,6 +39,7 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.));
B_clean = max(B_clean_not_top_gear, B_clean_top_gear);
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
int matrix_dim = params.get_matrix_dim();
#ifdef NOISY
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
cout << "V_s: " << V_s << endl;
@@ -48,9 +49,11 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
cout << "log(slack): " << slack << endl;
cout << "B_clean: " << B_clean << endl;
cout << "B_scale: " << B_scale << endl;
cout << "matrix dimension: " << matrix_dim << endl;
#endif
drown = 1 + n * (bigint(1) << sec);
assert(matrix_dim > 0);
drown = 1 + matrix_dim * n * (bigint(1) << sec);
}
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)

View File

@@ -50,6 +50,7 @@ void Ring_Element::prepare_push()
void Ring_Element::allocate()
{
assert(FFTD);
element.resize(FFTD->phi_m());
}

View File

@@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
}
}
void Rq_Element::add(octetStream& os)
{
Rq_Element tmp(*this);
tmp.unpack(os);
*this += tmp;
}
void Rq_Element::randomize(PRNG& G,int l)
{
set_level(l);
@@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p)
// Now add delta back onto a0
Rq_Element bb(b0,b1);
add(*this,*this,bb);
::add(*this,*this,bb);
// Now divide by p1 mod p0
modp p1_inv,pp;

View File

@@ -93,12 +93,14 @@ protected:
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
void add(octetStream& os);
template<class S>
Rq_Element& operator+=(const vector<S>& other);
Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; }
Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; }
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; }
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; }
Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; }
template <class T>
Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; }
@@ -176,7 +178,7 @@ Rq_Element& Rq_Element::operator+=(const vector<S>& other)
{
Rq_Element tmp = *this;
tmp.from(Iterator<S>(other), lev);
add(*this, *this, tmp);
::add(*this, *this, tmp);
return *this;
}

View File

@@ -203,7 +203,7 @@ template<class FD>
void PartSetup<FD>::secure_init(Player& P, MachineBase& machine,
int plaintext_length, int sec)
{
::secure_init(*this, P, machine, plaintext_length, sec);
::secure_init(*this, P, machine, plaintext_length, sec, params);
}
template<class FD>

View File

@@ -130,6 +130,13 @@ void Multiplier<FD>::report_size(ReportType type, MemoryUsage& res)
res += memory_usage;
}
template<class FD>
const vector<Ciphertext>& Multiplier<FD>::get_multiplicands(
const vector<vector<Ciphertext> >& others_ct, const FHE_PK&)
{
return others_ct[P.get_full_player().get_player(-P.get_offset())];
}
template class Multiplier<FFT_Data>;
template class Multiplier<P2Data>;

View File

@@ -55,6 +55,9 @@ public:
size_t report_size(ReportType type);
void report_size(ReportType type, MemoryUsage& res);
size_t report_volatile() { return volatile_capacity; }
const vector<Ciphertext>& get_multiplicands(
const vector<vector<Ciphertext>>& others_ct, const FHE_PK&);
};
#endif /* FHEOFFLINE_MULTIPLIER_H_ */

View File

@@ -9,6 +9,7 @@
#include "Math/Setup.h"
#include "FHEOffline/Proof.h"
#include "FHEOffline/PairwiseMachine.h"
#include "FHEOffline/TemiSetup.h"
#include "Tools/Commit.h"
#include "Tools/Bundle.h"
#include "Processor/OnlineOptions.h"
@@ -53,7 +54,7 @@ void PairwiseSetup<FD>::init(const Player& P, int sec, int plaintext_length,
template <class FD>
void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec)
{
::secure_init(*this, P, machine, plaintext_length, sec);
::secure_init(*this, P, machine, plaintext_length, sec, params);
alpha = FieldD;
machine.sk = FHE_SK(params, FieldD.get_prime());
for (auto& pk : machine.other_pks)
@@ -62,13 +63,14 @@ void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int pla
template <class T, class U>
void secure_init(T& setup, Player& P, U& machine,
int plaintext_length, int sec)
int plaintext_length, int sec, FHE_Params& params)
{
machine.sec = sec;
sec = max(sec, 40);
machine.drown_sec = sec;
string filename = PREP_DIR + T::name() + "-"
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
+ to_string(params.get_matrix_dim()) + "-"
+ OnlineOptions::singleton.prime.get_str() + "-"
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
+ to_string(P.my_num()) + "-" + to_string(P.num_players());
@@ -85,7 +87,6 @@ void secure_init(T& setup, Player& P, U& machine,
{
cout << "Finding parameters for security " << sec << " and field size ~2^"
<< plaintext_length << endl;
setup.params = setup.params.n_mults();
setup.generate(P, machine, plaintext_length, sec);
setup.check(P, machine);
octetStream os;
@@ -208,5 +209,8 @@ void PairwiseSetup<FD>::set_alphai(T alphai)
template class PairwiseSetup<FFT_Data>;
template class PairwiseSetup<P2Data>;
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int);
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int);
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(TemiSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(TemiSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);

View File

@@ -15,7 +15,7 @@ class MachineBase;
template <class T, class U>
void secure_init(T& setup, Player& P, U& machine,
int plaintext_length, int sec);
int plaintext_length, int sec, FHE_Params& params);
template <class FD>
class PairwiseSetup

View File

@@ -18,7 +18,12 @@ void SimpleDistDecrypt<FD>::reshare(Plaintext<typename FD::T, FD, typename FD::S
EncCommitBase<typename FD::T, FD, typename FD::S>& EC)
{
(void)EC;
m = reshare(cm);
}
template <class FD>
Plaintext_<FD> SimpleDistDecrypt<FD>::reshare(const Ciphertext& cm)
{
PRNG G;
G.ReSeed();
this->f.randomize(G, Full);
@@ -27,10 +32,13 @@ void SimpleDistDecrypt<FD>::reshare(Plaintext<typename FD::T, FD, typename FD::S
this->run(cm);
// Step 4
Plaintext_<FD> m(this->f.get_field());
if (this->P.my_num()==0)
{ sub(m,this->mf,this->f); }
else
{ m=this->f; m.negate(); }
return m;
}

View File

@@ -20,6 +20,7 @@ public:
void reshare(Plaintext<typename FD::T, FD, typename FD::S>& m,
const Ciphertext& cm,
EncCommitBase<typename FD::T, FD, typename FD::S>& EC);
Plaintext_<FD> reshare(const Ciphertext& cm);
};
#endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */

59
FHEOffline/TemiSetup.cpp Normal file
View File

@@ -0,0 +1,59 @@
/*
* TemiSetup.cpp
*
*/
#include "TemiSetup.h"
#include "PairwiseSetup.h"
#include "FHE/NTL-Subs.h"
#include "Protocols/HemiOptions.h"
template<class FD>
TemiSetup<FD>::TemiSetup()
{
this->params = FHE_Params(0);
this->pk = {this->params, 0};
this->sk = {this->params, 0};
this->calpha = this->params;
this->params.set_matrix_dim(
HemiOptions::singleton.plain_matmul ?
1 : OnlineOptions::singleton.batch_size);
}
template<class FD>
void TemiSetup<FD>::secure_init(Player& P, int plaintext_length)
{
MachineBase machine;
::secure_init(*this, P, machine, plaintext_length, 0, this->params);
}
template<class FD>
void TemiSetup<FD>::generate(Player& P, MachineBase&,
int plaintext_length, int sec)
{
generate_semi_setup(plaintext_length, sec, this->params, this->FieldD,
false, P.num_players());
this->sk = {this->params, this->FieldD.get_prime()};
this->pk = {this->params, this->FieldD.get_prime()};
}
template<class FD>
void TemiSetup<FD>::key_and_mac_generation(Player& P, MachineBase&, int,
true_type)
{
Rq_Element a(this->params);
GlobalPRNG GG(P);
a.randomize(GG);
SeededPRNG G;
auto sk = this->pk.sample_secret_key(G);
this->sk.assign(sk);
this->pk.partial_key_gen(sk, a, G);
TreeSum<Rq_Element> ts;
vector<Rq_Element> pks;
pks.push_back(this->pk.b());
ts.run(pks, P);
this->pk.assign(this->pk.a(), pks[0]);
}
template class TemiSetup<FFT_Data>;
template class TemiSetup<P2Data>;

34
FHEOffline/TemiSetup.h Normal file
View File

@@ -0,0 +1,34 @@
/*
* TemiSetup.h
*
*/
#ifndef FHEOFFLINE_TEMISETUP_H_
#define FHEOFFLINE_TEMISETUP_H_
#include "FHE/FHE_Keys.h"
#include "FHEOffline/SimpleMachine.h"
template<class FD>
class TemiSetup : public PartSetup<FD>
{
public:
static string name()
{
return "TemiParams";
}
static string protocol_name(int)
{
return "Temi";
}
TemiSetup();
void secure_init(Player& P, int plaintext_length);
void generate(Player& P, MachineBase&, int plaintext_length, int sec);
void key_and_mac_generation(Player& P, MachineBase&, int, true_type);
};
#endif /* FHEOFFLINE_TEMISETUP_H_ */

View File

@@ -47,11 +47,11 @@ inline void Memory<T>::check_index(Integer index) const
ss << T::type_string() << " memory overflow: " << i << "/" << vector<T>::size();
throw Processor_Error(ss.str());
}
#endif
#ifdef DEBUG_MEMORY
cout << typeid(T).name() << " at " << this << " index " << i << ": "
<< vector<T>::operator[](i) << endl;
#endif
#endif
}
template <class T>

View File

@@ -122,6 +122,7 @@ public:
static const bool dishonest_majority = false;
static const bool variable_players = false;
static const bool needs_ot = false;
static const bool has_mac = false;
static string type_string() { return "replicated secret"; }
static string phase_name() { return "Replicated computation"; }

View File

@@ -49,6 +49,7 @@ public:
static const bool dishonest_majority = T::dishonest_majority;
static const bool variable_players = T::variable_players;
static const bool needs_ot = T::needs_ot;
static const bool has_mac = T::has_mac;
static const bool expensive_triples = false;
static const int default_length = 64;

View File

@@ -55,7 +55,7 @@
X(BITDECC, PROC.bitdecc(EXTRA, C0)) \
X(SHRCBI, C0 = PC1 >> IMM) \
X(SHLCBI, C0 = PC1 << IMM) \
X(LDBITS, S0.load_clear(REG1, IMM)) \
X(LDBITS, S0.load_clear(REG1, int(IMM))) \
X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \
X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \
X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \

View File

@@ -23,6 +23,7 @@
#include "Protocols/Shamir.hpp"
#include "Protocols/ShamirMC.hpp"
#include "Protocols/MaliciousShamirMC.hpp"
#include "Protocols/MaliciousShamirPO.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/Spdz2kPrep.hpp"

37
Machines/temi-party.cpp Normal file
View File

@@ -0,0 +1,37 @@
/*
* temi-party.cpp
*
*/
#include "Protocols/TemiShare.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "FHE/P2Data.h"
#include "Tools/ezOptionParser.h"
#include "GC/SemiSecret.h"
#include "GC/SemiPrep.h"
#include "Processor/FieldMachine.hpp"
#include "Protocols/TemiPrep.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
#include "Protocols/SemiPrep.hpp"
#include "Protocols/SemiInput.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/Hemi.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/SemiHonestRepPrep.h"
#include "Math/gfp.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
HemiOptions::singleton = {opt, argc, argv};
DishonestMajorityFieldMachine<TemiShare, TemiShare, gf2n_short>(argc, argv,
opt);
}

View File

@@ -61,7 +61,7 @@ arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy
binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr
all: overdrive she-offline
arithmetic: hemi-party.x soho-party.x gear
arithmetic: semi-he gear
-include $(DEPS)
include $(wildcard *.d static/*.d)
@@ -87,6 +87,7 @@ she-offline: Check-Offline.x spdz2-offline.x
overdrive: simple-offline.x pairwise-offline.x cnc-offline.x gear
gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x
semi-he: hemi-party.x soho-party.x temi-party.x
rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x
@@ -210,6 +211,7 @@ static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
@@ -217,6 +219,7 @@ lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/Lo
highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
atlas-party.x: GC/AtlasSecret.o
static/hemi-party.x: $(FHEOBJS)
static/temi-party.x: $(FHEOBJS)
static/soho-party.x: $(FHEOBJS)
static/cowgear-party.x: $(FHEOBJS)
static/chaigear-party.x: $(FHEOBJS)

View File

@@ -14,11 +14,6 @@ using namespace std;
#include "Tools/random.h"
#include "field_types.h"
template<class T> class ReplicatedMC;
template<class T> class ReplicatedInput;
template<class T> class ReplicatedPrivateOutput;
template<class T> class Replicated;
template <class T, int L>
class FixedVec
{

View File

@@ -233,7 +233,7 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
if (mpn_cmp(ans+T,prA,T+1)>=0)
{ mpn_sub_fixed_n<T>(z,ans+T,prA); }
else
{ inline_mpn_copyi(z,ans+T,T); }
{ inline_mpn_copyi<T>(z,ans+T); }
#else
Mont_Mult(z, x, y, t);
#endif

View File

@@ -18,15 +18,21 @@ bool gf2n_<U>::useC;
word gf2n_short_table[256][256];
#define num_2_fields 6
#define num_2_fields 7
/* Require
* 2*(n-1)-64+t1<64
*/
int fields_2[num_2_fields][4] = {
{4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10},{63,1,0,0},{128,7,2,1},
};
int fields_2[num_2_fields][4] =
{
{ 4, 1, 0, 0 },
{ 8, 4, 3, 1 },
{ 16, 5, 3, 1 },
{ 28, 1, 0, 0 },
{ 40, 20, 15, 10 },
{ 63, 1, 0, 0 },
{ 128, 7, 2, 1 },
};
template<class U>
void gf2n_<U>::init_tables()

View File

@@ -24,6 +24,12 @@ inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t si
avx_memcpy(dest, src, size * sizeof(mp_limb_t));
}
template<int N>
inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src)
{
avx_memcpy<N * sizeof(mp_limb_t)>(dest, src);
}
inline void debug_print(const char* name, const mp_limb_t* x, int n)
{
(void)name, (void)x, (void)n;

View File

@@ -542,6 +542,7 @@ public:
int other_player_num() const { return P.get_player(offset); }
int num_players() const { return 2; }
int get_offset() const { return offset; }
Player& get_full_player() const { return P; }
void send(octetStream& o) const { P.send_to(P.get_player(offset), o); }
void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o); }

View File

@@ -206,6 +206,18 @@ void BaseOT::exec_base(bool new_receiver_inputs)
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
}
}
#ifdef BASE_OT_DEBUG
for (j = 0; j < 4; j++)
for (k = 0; k < AES_BLK_SIZE; k++)
{
printf("%4d-th receiver key:", i+j);
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
printf("\n");
}
printf("\n");
#endif
}
}
@@ -244,12 +256,6 @@ void BaseOT::exec_base(bool new_receiver_inputs)
for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]);
printf("\n");
}
if (ot_role & RECEIVER)
{
printf("%4d-th receiver key:", i+j);
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
printf("\n");
}
}
printf("\n");

View File

@@ -25,7 +25,7 @@ void Binary_File_IO::write_to_file(const string filename,
if (start_pos != -1)
{
long write_pos = start_pos * T::size();
long write_pos = file_signature<T>().get_total_length() + start_pos * T::size();
// fill with zeros if needed
for (long i = outf.tellp(); i < write_pos; i++)
outf.put(0);
@@ -50,10 +50,13 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
inf.open(filename, ios::in | ios::binary);
if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); }
check_file_signature<T>(inf, filename).get_length();
auto data_start = inf.tellg();
int size_in_bytes = T::size() * buffer.size();
int n_read = 0;
char read_buffer[size_in_bytes];
inf.seekg(start_posn * T::size());
inf.seekg(start_posn * T::size(), iostream::cur);
do
{
inf.read(read_buffer + n_read, size_in_bytes - n_read);
@@ -62,7 +65,9 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
if (inf.eof())
{
stringstream ss;
ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes).";
ss << "Got to EOF when reading from disk (expecting " << size_in_bytes
<< " bytes from " << (long(data_start) + start_posn * T::size())
<< ").";
throw file_error(ss.str());
}
if (inf.fail())
@@ -74,7 +79,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
}
while (n_read < size_in_bytes);
end_posn = inf.tellg() / T::size();
end_posn = (inf.tellg() - data_start) / T::size();
assert (end_posn == start_posn + int(buffer.size()));
//Check if at end of file by getting 1 more char.

View File

@@ -32,6 +32,15 @@ protected:
Buffer<typename T::clear, typename T::clear> buffer;
Timer timer;
// Send my inputs (not generally available)
virtual void send_mine() { throw not_implemented(); }
// Get share for next input of mine (not generally available)
virtual T finalize_mine() { throw not_implemented(); }
// Store share for next input from ``player`` from buffer ``o``
// in ``target`` (not generally available)
virtual void finalize_other(int, T&, octetStream&, int = -1)
{ throw not_implemented(); }
public:
vector<octetStream> os;
int values_input;
@@ -61,18 +70,12 @@ public:
/// Schedule input from other player
virtual void add_other(int player, int n_bits = -1) = 0;
/// Schedule input from all players
void add_from_all(const clear& input, int n_bits = -1);
void add_from_all(const typename T::open_type& input, int n_bits = -1);
/// Send my inputs
virtual void send_mine() = 0;
/// Run input protocol for all players
virtual void exchange();
/// Get share for next input of mine
virtual T finalize_mine() = 0;
/// Store share for next input from ``player`` from buffer ``o`` in ``target``
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
/// Get share for next input from ``player`
/// Get share for next input from ``player``
virtual T finalize(int player, int n_bits = -1);
void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size);

View File

@@ -113,7 +113,7 @@ void Input<T>::add_other(int player, int)
}
template<class T>
void InputBase<T>::add_from_all(const clear& input, int n_bits)
void InputBase<T>::add_from_all(const typename T::open_type& input, int n_bits)
{
for (int i = 0; i < P->num_players(); i++)
if (i == P->my_num())

View File

@@ -106,6 +106,7 @@ enum
MATMULSM = 0xAB,
CONV2DS = 0xAC,
CHECK = 0xAF,
PRIVATEOUTPUT = 0xAD,
// Data access
TRIPLE = 0x50,
BIT = 0x51,
@@ -127,6 +128,7 @@ enum
INPUTMIXEDREG = 0xF3,
RAWINPUT = 0xF4,
INPUTPERSONAL = 0xF5,
SENDPERSONAL = 0xF6,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,

View File

@@ -200,14 +200,17 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case USE:
case USE_INP:
case USE_EDABIT:
case DIGESTC:
case INPUTMASK:
case GINPUTMASK:
get_ints(r, s, 2);
n = get_int(s);
break;
case STARTPRIVATEOUTPUT:
case GSTARTPRIVATEOUTPUT:
case STOPPRIVATEOUTPUT:
case GSTOPPRIVATEOUTPUT:
case DIGESTC:
get_ints(r, s, 2);
n = get_int(s);
break;
throw runtime_error("two-stage private output not supported any more");
case USE_MATMUL:
get_ints(r, s, 3);
n = get_int(s);
@@ -237,8 +240,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case PRINTREGB:
case GPRINTREG:
case LDINT:
case INPUTMASK:
case GINPUTMASK:
case INV2M:
case CONDPRINTSTR:
case CONDPRINTSTRB:
@@ -290,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case RAWINPUT:
case GRAWINPUT:
case INPUTPERSONAL:
case SENDPERSONAL:
case PRIVATEOUTPUT:
case TRUNC_PR:
case RUN_TAPE:
num_var_args = get_int(s);
@@ -599,6 +602,7 @@ int BaseInstruction::get_reg_type() const
case PUBINPUT:
case FLOATOUTPUT:
case READSOCKETC:
case PRIVATEOUTPUT:
return CINT;
default:
if (is_gf2n_instruction())
@@ -738,10 +742,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
skip = 1;
break;
case INPUTPERSONAL:
case PRIVATEOUTPUT:
size_offset = -2;
offset = 2;
skip = 4;
break;
case SENDPERSONAL:
size_offset = -2;
offset = 2;
skip = 5;
break;
case READSOCKETS:
case READSOCKETC:
case READSOCKETINT:
@@ -939,13 +949,11 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break;
case INPUTMASK:
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n);
if (n == Proc.P.my_num())
Proc.temp.rrp.output(Proc.private_output, false);
Proc.write_Cp(r[1], Proc.temp.rrp);
break;
case GINPUTMASK:
Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n);
if (n == Proc.P.my_num())
Proc.temp.ans2.output(Proc.private_output, false);
Proc.write_C2(r[1], Proc.temp.ans2);
break;
case INPUT:
sint::Input::template input<IntInput<typename sint::clear>>(Proc.Procp, start, size);
@@ -974,6 +982,12 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case INPUTPERSONAL:
Proc.Procp.input_personal(start);
return;
case SENDPERSONAL:
Proc.Procp.send_personal(start);
return;
case PRIVATEOUTPUT:
Proc.Procp.private_output(start);
return;
// Note: Fp version has different semantics for NOTC than GNOTC
case NOTC:
to_bigint(Proc.temp.aa, Proc.read_Cp(r[1]));
@@ -1202,18 +1216,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.binary_output.write((char*) &tmp, sizeof(double));
}
break;
case STARTPRIVATEOUTPUT:
Proc.privateOutputp.start(n,r[0],r[1]);
break;
case GSTARTPRIVATEOUTPUT:
Proc.privateOutput2.start(n,r[0],r[1]);
break;
case STOPPRIVATEOUTPUT:
Proc.privateOutputp.stop(n,r[0],r[1]);
break;
case GSTOPPRIVATEOUTPUT:
Proc.privateOutput2.stop(n,r[0],r[1]);
break;
case PREP:
Procp.DataF.get(Proc.Procp.get_S(), r, start, size);
return;

View File

@@ -97,12 +97,19 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
// initialize persistence if necessary
for (auto& prog : progs)
{
if (prog.writes_persistance)
if (prog.writes_persistence)
{
string filename = Binary_File_IO::filename(my_number);
ifstream pers(filename);
if (pers.fail())
ofstream pers(filename, ios::binary);
try
{
check_file_signature<sint>(pers, filename);
}
catch (signature_mismatch&)
{
ofstream pers(filename, ios::binary);
file_signature<sint>().output(pers);
}
break;
}
}
@@ -418,12 +425,14 @@ void Machine<sint, sgf2n>::run()
cerr << "Full broadcast" << endl;
#endif
#ifdef CHOP_MEMORY
// Reduce memory size to speed up
unsigned max_size = 1 << 20;
if (M2.size_s() > max_size)
M2.resize_s(max_size);
if (Mp.size_s() > max_size)
Mp.resize_s(max_size);
#endif
// Write out the memory to use next time
ofstream outf(memory_filename(), ios::out | ios::binary);

View File

@@ -44,9 +44,9 @@ class Memory
static void check_index(const vector<U>& M, size_t i)
{
(void) M, (void) i;
#ifdef NO_CHECK_INDEX
#ifndef NO_CHECK_INDEX
if (i >= M.size())
throw overflow("memory", i, M.size());
throw overflow(U::type_string() + " memory", i, M.size());
#endif
}

View File

@@ -19,6 +19,9 @@ void MemoryPart<T>::minimum_size(size_t size)
{
if (size > this->size())
this->resize(size);
#ifdef DEBUG_MEMORY_SIZE
cerr << T::type_string() << " memory has now size " << this->size() << endl;
#endif
}
catch (bad_alloc&)
{
@@ -58,9 +61,9 @@ istream& operator>>(istream& s,Memory<T>& M)
int len;
s >> len;
M.resize_s(len);
M.MS.minimum_size(len);
s >> len;
M.resize_c(len);
M.MC.minimum_size(len);
s.seekg(1, istream::cur);
for (unsigned int i=0; i<M.MS.size(); i++)

View File

@@ -17,16 +17,16 @@ class PrivateOutput
typedef typename T::open_type open_type;
SubProcessor<T>& proc;
typename T::MAC_Check MC;
deque<open_type> masks;
public:
PrivateOutput(SubProcessor<T>& proc) : proc(proc) { };
PrivateOutput(SubProcessor<T>& proc);
~PrivateOutput();
void start(int player, int target, int source);
void stop(int player, int dest, int source);
T start(int player, const T& source);
typename T::clear stop(int player, const typename T::clear& masked);
void prepare_sending(const T& source, int player);
void exchange();
typename T::clear finalize(int player);
};
#endif /* PROCESSOR_PRIVATEOUTPUT_H_ */

View File

@@ -7,13 +7,21 @@
#include "Processor.h"
template<class T>
void PrivateOutput<T>::start(int player, int target, int source)
PrivateOutput<T>::PrivateOutput(SubProcessor<T>& proc) :
proc(proc), MC(proc.MC.get_alphai())
{
proc.get_S_ref(target) = start(player, proc.get_S_ref(source));
MC.init_open(proc.P);
MC.set_prep(proc.DataF);
}
template<class T>
T PrivateOutput<T>::start(int player, const T& source)
PrivateOutput<T>::~PrivateOutput()
{
MC.Check(proc.P);
}
template<class T>
void PrivateOutput<T>::prepare_sending(const T& source, int player)
{
assert (player < proc.P.num_players());
open_type mask;
@@ -24,26 +32,25 @@ T PrivateOutput<T>::start(int player, const T& source)
if (player == proc.P.my_num())
masks.push_back(mask);
return res;
MC.prepare_open(res);
}
template<class T>
void PrivateOutput<T>::stop(int player, int dest, int source)
void PrivateOutput<T>::exchange()
{
auto& value = proc.get_C_ref(dest);
value = stop(player, proc.get_C_ref(source));
if (proc.Proc)
value.output(proc.Proc->private_output, false);
MC.exchange(proc.P);
}
template<class T>
typename T::clear PrivateOutput<T>::stop(int player, const typename T::clear& source)
typename T::clear PrivateOutput<T>::finalize(int player)
{
typename T::clear value;
auto res = MC.finalize_open();
if (player == proc.P.my_num())
{
value = source - masks.front();
res -= masks.front();
masks.pop_front();
}
return value;
return res;
}

View File

@@ -71,6 +71,8 @@ public:
void conv2ds(const Instruction& instruction);
void input_personal(const vector<int>& args);
void send_personal(const vector<int>& args);
void private_output(const vector<int>& args);
CheckVector<T>& get_S()
{
@@ -110,7 +112,6 @@ public:
ifstream private_input;
ifstream public_input;
ofstream public_output;
ofstream private_output;
ofstream binary_output;
int sent, rounds;
@@ -172,9 +173,6 @@ class Processor : public ArithmeticProcessor
SubProcessor<sgf2n> Proc2;
SubProcessor<sint> Procp;
typename sgf2n::PrivateOutput privateOutput2;
typename sint::PrivateOutput privateOutputp;
unsigned int PC;
TempVars<sint, sgf2n> temp;

View File

@@ -4,9 +4,8 @@
#include "Processor/Processor.h"
#include "Processor/Program.h"
#include "GC/square64.h"
#include "SpecificPrivateOutput.h"
#include "Protocols/ReplicatedInput.hpp"
#include "Protocols/ReplicatedPrivateOutput.hpp"
#include "Processor/ProcessorBase.hpp"
#include "GC/Processor.hpp"
#include "GC/ShareThread.hpp"
@@ -63,7 +62,6 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
share_thread(DataF.DataFb, P, machine.get_bit_mac_key()),
Procb(machine.bit_memories),
Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P),
privateOutput2(Proc2),privateOutputp(Procp),
external_clients(P.my_num()),
binary_file_io(Binary_File_IO())
{
@@ -74,7 +72,6 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
private_input_filename = (get_filename(PREP_DIR "Private-Input-",true));
private_input.open(private_input_filename.c_str());
public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out);
private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out);
binary_output.open(
get_parameterized_filename(P.my_num(), thread_num,
PREP_DIR "Binary-Output"), ios_base::out);
@@ -654,6 +651,37 @@ void SubProcessor<T>::input_personal(const vector<int>& args)
S[args[i + 2] + j] = input.finalize(args[i + 1]);
}
template<class T>
void SubProcessor<T>::private_output(const vector<int>& args)
{
typename T::PrivateOutput output(*this);
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
{
int player = args[i + 1];
output.prepare_sending(S.at(args[i + 3] + j), player);
}
output.exchange();
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
C.at(args[i + 2] + j) = output.finalize(args[i + 1]);
}
template<class T>
void SubProcessor<T>::send_personal(const vector<int>& args)
{
octetStreams to_send(P), to_receive(P);
for (size_t i = 0; i < args.size(); i += 5)
if (args[i + 3] == P.my_num())
for (int j = 0; j < args[i]; j++)
C[args[i + 4] + j].pack(to_send[args[i + 1]]);
P.send_receive_all(to_send, to_receive);
for (size_t i = 0; i < args.size(); i += 5)
if (args[i + 1] == P.my_num())
for (int j = 0; j < args[i]; j++)
C[args[i + 2] + j].unpack(to_receive[args[i + 3]]);
}
template<class sint, class sgf2n>
typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m)
{

View File

@@ -23,7 +23,7 @@ void Program::compute_constants()
max_mem[reg_type] = max(max_mem[reg_type],
p[i].get_mem(RegType(reg_type)));
}
writes_persistance |= p[i].opcode == WRITEFILESHARE;
writes_persistence |= p[i].opcode == WRITEFILESHARE;
}
}

View File

@@ -30,10 +30,10 @@ class Program
public:
bool writes_persistance;
bool writes_persistence;
Program(int nplayers) : offline_data_used(nplayers),
unknown_usage(false), writes_persistance(false)
unknown_usage(false), writes_persistence(false)
{ compute_constants(); }
// Read in a program

View File

@@ -0,0 +1,65 @@
/*
* SpecificPrivateOutput.h
*
*/
#ifndef PROCESSOR_SPECIFICPRIVATEOUTPUT_H_
#define PROCESSOR_SPECIFICPRIVATEOUTPUT_H_
template<class T>
class SpecificPrivateOutput
{
deque<T> secrets;
vector<typename T::PO*> pos;
Player& P;
vector<bool> active;
public:
SpecificPrivateOutput(SubProcessor<T>& proc) :
P(proc.P)
{
for (int i = 0; i < P.num_players(); i++)
pos.push_back(new typename T::PO(proc.P));
active.resize(P.num_players());
}
~SpecificPrivateOutput()
{
for (auto& x : pos)
delete x;
}
void prepare_sending(const T& secret, int player)
{
pos[player]->prepare_sending(secret, player);
if (P.my_num() == player)
secrets.push_back(secret);
active[player] = true;
}
void exchange()
{
for (int i = 0; i < this->P.num_players(); i++)
if (active[i])
{
if (i == this->P.my_num())
pos[i]->receive();
else
pos[i]->send(i);
}
}
typename T::clear finalize(int player)
{
if (player == this->P.my_num())
{
T secret = secrets.front();
secrets.pop_front();
return pos[player]->finalize(secret);
}
else
return {};
}
};
#endif /* PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ */

View File

@@ -0,0 +1,100 @@
from Compiler.ml import keras
import Compiler.ml as tf
try:
n_epochs = int(program.args[1])
except (ValueError, IndexError):
n_epochs = 10
try:
batch_size = int(program.args[2])
except (ValueError, IndexError):
batch_size = 128
try:
n_threads = int(program.args[3])
except (ValueError, IndexError):
n_threads = 36
#Instantiation
AlexNet = []
padding = 'same'
batchnorm = 'batchnorm' in program.args
#1st Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=96, input_shape=(32,32,3), kernel_size=(11,11), strides=(4,4), padding=9))
AlexNet.append(keras.layers.Activation('relu'))
AlexNet.append(keras.layers.MaxPooling2D(pool_size=3, strides=(2,2)))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
#2nd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=1))
#3rd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
#4th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
#5th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
#Passing it to a Fully Connected layer
# 1st Fully Connected Layer
AlexNet.append(keras.layers.Dense(256))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#2nd Fully Connected Layer
AlexNet.append(keras.layers.Dense(256))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#Output Layer
AlexNet.append(keras.layers.Dense(10))
tf.set_n_threads(n_threads)
program.options_from_args()
sfix.set_precision_from_args(program, adapt_ring=True)
training_samples = MultiArray([50000, 32, 32, 3], sfix)
training_labels = MultiArray([50000, 10], sint)
test_samples = MultiArray([10000, 32, 32, 3], sfix)
test_labels = MultiArray([10000, 10], sint)
if 'no_acc' not in program.args:
training_labels.input_from(0)
training_samples.input_from(0)
test_labels.input_from(0)
test_samples.input_from(0)
model = tf.keras.models.Sequential(AlexNet)
model.compile_by_args(program)
model.build(training_samples.sizes)
model.summary()
opt = model.fit(
training_samples,
training_labels,
epochs=n_epochs,
batch_size=batch_size,
validation_data=(test_samples, test_labels)
)

View File

@@ -0,0 +1,45 @@
# this trains LeNet on MNIST with a dropout layer
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
program.options_from_args()
training_samples = MultiArray([50000, 32, 32, 3], sfix)
training_labels = MultiArray([50000, 10], sint)
test_samples = MultiArray([10000, 32, 32, 3], sfix)
test_labels = MultiArray([10000, 10], sint)
training_labels.input_from(0)
training_samples.input_from(0)
test_labels.input_from(0)
test_samples.input_from(0)
from Compiler import ml
tf = ml
ml.set_n_threads(36)
layers = [
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(500, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
]
model = tf.keras.models.Sequential(layers)
optim = tf.keras.optimizers.Adam(amsgrad=True)
model.compile(optimizer=optim)
opt = model.fit(
training_samples,
training_labels,
epochs=10,
batch_size=128,
validation_data=(test_samples, test_labels)
)

View File

@@ -21,7 +21,8 @@ tf = ml
layers = [
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(128),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dense(10, activation='softmax')
]

View File

@@ -20,8 +20,21 @@ tf = ml
layers = [
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5),

View File

@@ -21,6 +21,8 @@ elif 'debug' in program.args:
n_test = 100
elif 'debug5000' in program.args:
N = n_test = 5000
elif 'mini' in program.args:
N = n_test = 10
else:
N = 60000
n_test = 10000
@@ -39,6 +41,7 @@ except:
batch_size = N
N = min(N, 10000)
batch_size = min(batch_size, N)
ml.Layer.back_batch_size = batch_size
try:
@@ -71,6 +74,9 @@ else:
ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml),
ml.Dense(N, n_inner, 10, debug=debug_ml)]
if 'batchnorm' in program.args:
layers.insert(1, ml.BatchNorm([N, n_inner]))
if 'dropout' in program.args:
for i in range(len(layers) - 1, 0, -1):
layers.insert(i, ml.Dropout(N, n_inner))

View File

@@ -53,7 +53,7 @@ except:
ml.Layer.back_batch_size = batch_size
layers = [
ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [n_examples, 24, 24, 20], (1, 1), 'VALID'),
ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [N, 24, 24, 20], (1, 1), 'VALID'),
ml.MaxPool([N, 24, 24, 20]),
ml.Relu([N, 12, 12, 20]),
ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'),
@@ -66,6 +66,12 @@ layers = [
layers += [ml.MultiOutput.from_args(program, n_examples, 10)]
if 'batchnorm' in program.args:
for arg in program.args:
assert not arg.startswith('dropout')
layers.insert(4, ml.BatchNorm([N, 8, 8, 50], args=program.args))
layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args))
if 'dropout' in program.args or 'dropout2' in program.args:
layers.insert(8, ml.Dropout(N, 500))
elif 'dropout.25' in program.args:

View File

@@ -85,6 +85,12 @@ void Atlas<T>::exchange()
resharing.add_mine(e);
}
for (size_t i = 0; i < min(masks.size(), size_t(P.num_players())); i++)
{
int j = (base_king + i) % P.num_players();
resharing.add_sender(j);
}
resharing.exchange();
}

View File

@@ -27,7 +27,7 @@ HemiMatrixPrep<T>& Hemi<T>::get_matrix_prep(const array<int, 3>& dims,
if (matrix_preps.find(dims) == matrix_preps.end())
matrix_preps.insert({dims,
new HemiMatrixPrep<T>(dims[0], dims[1], dims[2],
dynamic_cast<HemiPrep<T>&>(processor.DataF))});
dynamic_cast<typename T::LivePrep&>(processor.DataF))});
return *matrix_preps.at(dims);
}

View File

@@ -18,17 +18,18 @@ template<class T>
class HemiMatrixPrep : public BufferPrep<ShareMatrix<T>>
{
typedef BufferPrep<ShareMatrix<T>> super;
typedef typename T::LivePrep LivePrep;
int n_rows, n_inner, n_cols;
bool swapped;
DataPositions* usage;
HemiPrep<T>* prep;
LivePrep* prep;
HemiMatrixPrep(const HemiMatrixPrep&) = delete;
public:
HemiMatrixPrep(int n_rows, int n_inner, int n_cols, HemiPrep<T>& prep) :
HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) :
super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner),
n_cols(n_cols), prep(&prep)
{

View File

@@ -87,11 +87,10 @@ void HemiMatrixPrep<T>::buffer_triples()
assert(prep);
auto& multipliers = prep->get_multipliers();
assert(prep->pairwise_machine);
auto& FTD = prep->pairwise_machine->setup_p.FieldD;
auto& pk = prep->pairwise_machine->pk;
auto& FTD = prep->get_FTD();
auto& pk = prep->get_pk();
int n_matrices = FTD.num_slots() / n_rows;
#ifdef VERBOSE
#ifdef VERBOSE_HE
fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner,
n_inner, n_cols);
fflush(stderr);
@@ -103,20 +102,23 @@ void HemiMatrixPrep<T>::buffer_triples()
AddableVector<ValueMatrix<gfpvar>> C(n_matrices);
MatrixRandMultJob job(C, A, B);
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
if (T::local_mul)
{
auto& queues = BaseMachine::s().queues;
int start = queues.distribute(job, n_matrices);
job.begin = start;
job.end = n_matrices;
matrix_rand_mult(job);
queues.wrap_up(job);
}
else
{
job.begin = 0;
job.end = n_matrices;
matrix_rand_mult(job);
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
{
auto& queues = BaseMachine::s().queues;
int start = queues.distribute(job, n_matrices);
job.begin = start;
job.end = n_matrices;
matrix_rand_mult(job);
queues.wrap_up(job);
}
else
{
job.begin = 0;
job.end = n_matrices;
matrix_rand_mult(job);
}
}
#ifdef VERBOSE_HE
@@ -130,26 +132,35 @@ void HemiMatrixPrep<T>::buffer_triples()
assert(prep->proc);
auto& P = prep->proc->P;
Bundle<octetStream> bundle(P);
bundle.mine.store(diag.ciphertexts);
P.unchecked_broadcast(bundle);
vector<vector<Ciphertext>> others_ct;
for (auto& os : bundle)
if (T::local_mul or OnlineOptions::singleton.direct)
{
others_ct.push_back({});
os.get(others_ct.back(), Ciphertext(pk));
Bundle<octetStream> bundle(P);
bundle.mine.store(diag.ciphertexts);
P.unchecked_broadcast(bundle);
for (auto& os : bundle)
{
others_ct.push_back({});
os.get(others_ct.back(), Ciphertext(pk));
}
}
else
{
others_ct.push_back(diag.ciphertexts);
TreeSum<Ciphertext>().run(others_ct[0], P);
}
for (int j = 0; j < n_cols; j++)
for (auto m : multipliers)
{
#ifdef VERBOSE
#ifdef VERBOSE_HE
fprintf(stderr, "column %d with party offset %d at %f\n", j,
m->get_offset(), timer.elapsed());
fflush(stderr);
#endif
Ciphertext C(pk);
auto& multiplicands = others_ct[P.get_player(-m->get_offset())];
auto& multiplicands = m->get_multiplicands(others_ct, pk);
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
{
auto& queues = BaseMachine::s().queues;
@@ -160,7 +171,7 @@ void HemiMatrixPrep<T>::buffer_triples()
CipherPlainMultJob job(products, multiplicands, multiplicands2, true);
int start = queues.distribute(job, n_inner);
#ifdef VERBOSE_HE
fprintf(stderr, "from %d in central thread\n", start);
fprintf(stderr, "from %d in central thread at %f\n", start, timer.elapsed());
fflush(stderr);
#endif
for (int i = start; i < n_inner; i++)
@@ -185,7 +196,10 @@ void HemiMatrixPrep<T>::buffer_triples()
m->add(products[j], C, BOTH, n_inner);
}
C += diag.dediag(products, n_matrices);
if (T::local_mul)
C += diag.dediag(products, n_matrices);
else
C = diag.dediag(products, n_matrices);
for (int i = 0; i < n_matrices; i++)
if (swapped)

View File

@@ -34,6 +34,9 @@ public:
static void basic_setup(Player& P);
static void teardown();
static const FHE_PK& get_pk();
static const FD& get_FTD();
HemiPrep(SubProcessor<T>* proc, DataPositions& usage) :
BufferPrep<T>(usage),
BitPrep<T>(proc, usage), RingPrep<T>(proc, usage),

View File

@@ -34,6 +34,20 @@ void HemiPrep<T>::basic_setup(Player& P)
T::clear::template init<typename FD::T>();
}
template<class T>
const FHE_PK& HemiPrep<T>::get_pk()
{
assert(pairwise_machine);
return pairwise_machine->pk;
}
template<class T>
const typename T::clear::FD& HemiPrep<T>::get_FTD()
{
assert(pairwise_machine);
return pairwise_machine->setup<FD>().FieldD;
}
template<class T>
HemiPrep<T>::~HemiPrep()

View File

@@ -27,6 +27,7 @@ public:
typedef HemiPrep<This> LivePrep;
static const bool needs_ot = false;
static const bool local_mul = true;
static true_type triple_matmul;
HemiShare()

View File

@@ -140,12 +140,12 @@ void KeyGenProtocol<X, L>::output_to(int player, vector<open_type>& opened,
vector<share_type>& shares)
{
PrivateOutput<share_type> po(*proc);
vector<share_type> masked;
for (auto& share : shares)
masked.push_back(po.start(player, share));
MC->POpen(opened, masked, P);
po.prepare_sending(share, player);
po.exchange();
opened.resize(shares.size());
for (auto& x : opened)
x = po.stop(player, x);
x = po.finalize(player);
}
template<int L>

View File

@@ -52,6 +52,7 @@ public:
virtual ~TreeSum();
void run(vector<T>& values, const Player& P);
T run(const T& value, const Player& P);
octetStream& get_buffer() { return os; }
@@ -210,6 +211,14 @@ void TreeSum<T>::run(vector<T>& values, const Player& P)
finish(values, P);
}
template<class T>
T TreeSum<T>::run(const T& value, const Player& P)
{
vector<T> values = {value};
run(values, P);
return values[0];
}
template<class T>
size_t TreeSum<T>::report_size(ReportType type)
{
@@ -244,14 +253,6 @@ void add_openings(vector<T>& values, const Player& P, int sum_players, int last_
MC.player_timers[sender].start();
P.wait_receive(sender, oss[j]);
MC.player_timers[sender].stop();
if ((unsigned)oss[j].get_length() < values.size() * T::size())
{
stringstream ss;
ss << "Not enough information received, expected "
<< values.size() * T::size() << " bytes, got "
<< oss[j].get_length();
throw Processor_Error(ss.str());
}
MC.timers[SUM].start();
for (unsigned int i=0; i<values.size(); i++)
{

View File

@@ -127,6 +127,7 @@ void MAC_Check_<U>::Check(const Player& P)
auto& vals = this->vals;
auto& macs = this->macs;
auto& popen_cnt = this->popen_cnt;
assert(int(macs.size()) <= popen_cnt);
if (popen_cnt < 10)
{

View File

@@ -12,6 +12,8 @@ using namespace std;
#include "Networking/Player.h"
#include "Tools/PointerVector.h"
template<class T> class Preprocessing;
/**
* Abstract base class for opening protocols
*/
@@ -61,6 +63,8 @@ public:
virtual void CheckFor(const typename T::open_type& value, const vector<T>& shares, const Player& P);
virtual const Player& get_check_player(const Player& P) const { return P; }
virtual void set_prep(Preprocessing<T>&) {}
};
#endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */

View File

@@ -17,6 +17,7 @@ class MalRepRingShare : public MaliciousRep3Share<SignedZ2<K>>
{
typedef SignedZ2<K> T;
typedef MaliciousRep3Share<T> super;
typedef MalRepRingShare This;
public:
const static int BIT_LENGTH = K;
@@ -26,7 +27,8 @@ public:
typedef HashMaliciousRepMC<MalRepRingShare> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ReplicatedInput<MalRepRingShare> Input;
typedef ::PrivateOutput<MalRepRingShare> PrivateOutput;
typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef MalRepRingPrepWithBits<MalRepRingShare> LivePrep;
typedef MaliciousRep3Share<Z2<K + S>> prep_type;
typedef Z2<S> random_type;

View File

@@ -13,6 +13,7 @@ template<class T> class Beaver;
template<class T> class MaliciousRepPrepWithBits;
template<class T> class MaliciousRepPO;
template<class T> class MaliciousRepPrep;
template<class T> class SpecificPrivateOutput;
namespace GC
{
@@ -30,8 +31,8 @@ public:
typedef HashMaliciousRepMC<MaliciousRep3Share<T>> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ReplicatedInput<MaliciousRep3Share<T>> Input;
typedef ::PrivateOutput<MaliciousRep3Share<T>> PrivateOutput;
typedef MaliciousRepPO<MaliciousRep3Share> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef Rep3Share<T> Honest;
typedef MaliciousRepPrepWithBits<MaliciousRep3Share> LivePrep;
typedef MaliciousRepPrep<MaliciousRep3Share> TriplePrep;

View File

@@ -9,13 +9,14 @@
template<class T>
class MaliciousShamirPO
{
protected:
Player& P;
octetStream to_send;
vector<octetStream> to_receive;
vector<typename T::open_type> shares;
MaliciousShamirMC<T> MC;
typename T::Direct_MC MC;
public:
MaliciousShamirPO(Player& P);

View File

@@ -13,6 +13,7 @@
template<class T> class MaliciousRepPrepWithBits;
template<class T> class MaliciousRepPrep;
template<class T> class MaliciousShamirPO;
template<class T> class SpecificPrivateOutput;
namespace GC
{
@@ -23,14 +24,15 @@ template<class T>
class MaliciousShamirShare : public ShamirShare<T>
{
typedef ShamirShare<T> super;
typedef MaliciousShamirShare This;
public:
typedef Beaver<MaliciousShamirShare<T>> Protocol;
typedef MaliciousShamirMC<MaliciousShamirShare> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ShamirInput<MaliciousShamirShare> Input;
typedef ::PrivateOutput<MaliciousShamirShare> PrivateOutput;
typedef MaliciousShamirPO<MaliciousShamirShare> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef ShamirShare<T> Honest;
typedef MaliciousRepPrepWithBits<MaliciousShamirShare> LivePrep;
typedef MaliciousRepPrep<MaliciousShamirShare> TriplePrep;

View File

@@ -76,12 +76,6 @@ public:
return string(1, T::type_char());
}
static void read_or_generate_mac_key(string, Player&, mac_key_type& key)
{
SeededPRNG G;
key.randomize(G);
}
MamaShare()
{
}

View File

@@ -15,6 +15,7 @@ template<class T>
class PostSacriRepFieldShare : public MaliciousRep3Share<T>
{
typedef MaliciousRep3Share<T> super;
typedef PostSacriRepFieldShare This;
public:
typedef typename super::clear clear;
@@ -23,7 +24,8 @@ public:
typedef HashMaliciousRepMC<PostSacriRepFieldShare> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ReplicatedInput<PostSacriRepFieldShare> Input;
typedef ::PrivateOutput<PostSacriRepFieldShare> PrivateOutput;
typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef MaliciousRepPrepWithBits<PostSacriRepFieldShare> LivePrep;
PostSacriRepFieldShare()

View File

@@ -17,6 +17,7 @@ template<int K, int S>
class PostSacriRepRingShare : public Rep3Share2<K>
{
typedef Rep3Share2<K> super;
typedef PostSacriRepRingShare This;
public:
static const int BIT_LENGTH = K;
@@ -33,7 +34,8 @@ public:
typedef HashMaliciousRepMC<PostSacriRepRingShare> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ReplicatedInput<PostSacriRepRingShare> Input;
typedef ::PrivateOutput<PostSacriRepRingShare> PrivateOutput;
typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef MalRepRingPrepWithBits<PostSacriRepRingShare> LivePrep;
typedef GC::MaliciousRepSecret bit_type;

View File

@@ -42,8 +42,13 @@ public:
{
}
~ProtocolSet()
/**
* Run all protocol checks
*/
void check()
{
protocol.check();
output.Check(processor.P);
}
};
@@ -73,6 +78,15 @@ public:
*thread.protocol), input(output, prep, P)
{
}
/**
* Run all protocol checks
*/
void check()
{
protocol.check();
output.Check(protocol.P);
}
};
/**
@@ -102,6 +116,15 @@ public:
arithmetic.protocol), input(arithmetic.input)
{
}
/**
* Run all protocol checks
*/
void check()
{
arithmetic.check();
binary.check();
}
};
#endif /* PROTOCOLS_PROTOCOLSET_H_ */

View File

@@ -15,7 +15,8 @@
template<class T> class ReplicatedPrep;
template<class T> class ReplicatedRingPrep;
template<class T> class PrivateOutput;
template<class T> class ReplicatedPO;
template<class T> class SpecificPrivateOutput;
template<class T, int L>
class RepShare : public FixedVec<T, L>, public ShareInterface
@@ -99,6 +100,7 @@ template<class T>
class Rep3Share : public RepShare<T, 2>
{
typedef RepShare<T, 2> super;
typedef Rep3Share This;
public:
typedef T clear;
@@ -107,7 +109,8 @@ public:
typedef ReplicatedMC<Rep3Share> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ReplicatedInput<Rep3Share> Input;
typedef ::PrivateOutput<Rep3Share> PrivateOutput;
typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef ReplicatedPrep<Rep3Share> LivePrep;
typedef ReplicatedRingPrep<Rep3Share> TriplePrep;
typedef Rep3Share Honest;

View File

@@ -24,7 +24,8 @@ public:
typedef ReplicatedMC<Rep3Share2> MAC_Check;
typedef MAC_Check Direct_MC;
typedef ReplicatedInput<Rep3Share2> Input;
typedef ::PrivateOutput<Rep3Share2> PrivateOutput;
typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef ReplicatedPrep2k<Rep3Share2> LivePrep;
typedef Rep3Share2 Honest;
typedef SignedZ2<K> clear;

View File

@@ -31,7 +31,6 @@ public:
void add_mine(const typename T::open_type& input, int n_bits = -1);
void add_other(int player, int n_bits = -1);
void send_mine();
void exchange();
T finalize_mine();

View File

@@ -64,12 +64,6 @@ void Rep4Input<T>::add_other(int player, int)
results[player].push_back(res);
}
template<class T>
void Rep4Input<T>::send_mine()
{
throw not_implemented();
}
template<class T>
void Rep4Input<T>::exchange()
{

View File

@@ -19,10 +19,6 @@ using namespace std;
template<class T> class SubProcessor;
template<class T> class ReplicatedMC;
template<class T> class ReplicatedInput;
template<class T> class ReplicatedPrivateOutput;
template<class T> class Share;
template<class T> class Rep3Share;
template<class T> class MAC_Check_Base;
template<class T> class Preprocessing;
class Instruction;
@@ -141,9 +137,6 @@ class Replicated : public ReplicatedBase, public ProtocolBase<T>
void trunc_pr(const vector<int>& regs, int size, U& proc, false_type);
public:
typedef ReplicatedMC<T> MAC_Check;
typedef ReplicatedInput<T> Input;
static const bool uses_triples = false;
Replicated(Player& P);

View File

@@ -10,6 +10,7 @@
#include "Processor/Processor.h"
#include "Processor/TruncPrTuple.h"
#include "Tools/benchmarking.h"
#include "Tools/Bundle.h"
#include "ReplicatedInput.h"
#include "Rep3Share2k.h"
@@ -162,14 +163,13 @@ void Replicated<T>::prepare_mul(const T& x,
}
template<class T>
inline void Replicated<T>::prepare_reshare(const typename T::clear& share,
void Replicated<T>::prepare_reshare(const typename T::clear& share,
int n)
{
auto add_share = share;
typename T::value_type tmp[2];
for (int i = 0; i < 2; i++)
tmp[i].randomize(shared_prngs[i], n);
add_share += tmp[0] - tmp[1];
auto add_share = share + tmp[0] - tmp[1];
add_share.pack(os[0], n);
add_shares.push_back(add_share);
}

View File

@@ -56,16 +56,24 @@ BufferPrep<T>::~BufferPrep()
<< " bit generation" << endl;
#endif
auto field_type = T::clear::field_type();
auto& my_usage = this->usage.files.at(field_type);
this->print_left("triples", triples.size() * T::default_length, type_string,
this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE)
* T::default_length);
size_t used_bits = my_usage.at(DATA_BIT);
if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac)
// add dabits with computation modulo power of two but without MAC
used_bits += my_usage.at(DATA_DABIT);
this->print_left("bits", bits.size(), type_string, used_bits);
#define X(KIND, TYPE) \
this->print_left(#KIND, KIND.size(), type_string, \
this->usage.files.at(T::clear::field_type()).at(TYPE));
X(squares, DATA_SQUARE)
X(inverses, DATA_INVERSE)
X(bits, DATA_BIT)
X(dabits, DATA_DABIT)
#undef X
@@ -601,17 +609,6 @@ void buffer_bits_from_players(vector<vector<T>>& player_bits,
for (int i = 0; i < n_relevant_players; i++)
for (auto& x : player_bits[i])
x = input.finalize((base_player + i) % P.num_players(), n_bits);
#if !defined(__clang__) && (__GNUC__ == 6)
// mitigate compiler bug
Bundle<octetStream> bundle(P);
P.unchecked_broadcast(bundle);
#endif
#ifdef DEBUG_BIT_SACRIFICE
typename T::MAC_Check MC;
for (int i = 0; i < n_relevant_players; i++)
for (auto& x : player_bits[i])
assert((MC.open(x, P) == 0) or (MC.open(x, P) == 1));
#endif
}
template<class T>
@@ -1164,18 +1161,18 @@ void BufferPrep<T>::buffer_inputs_as_usual(int player, SubProcessor<T>* proc)
typename T::clear r;
r.randomize(G);
input.add_mine(r);
this->inputs[player].push_back({input.finalize_mine(), r});
this->inputs[player].push_back({input.finalize(player), r});
}
input.send_mine();
input.exchange();
}
else
{
octetStream os;
P.receive_player(player, os);
T share;
for (int i = 0; i < buffer_size; i++)
input.add_other(player);
input.exchange();
for (int i = 0; i < buffer_size; i++)
{
input.finalize_other(player, share, os);
auto share = input.finalize(player);
this->inputs[player].push_back({share, 0});
}
}

View File

@@ -1,26 +0,0 @@
/*
* ReplicatedPrivateOutput.h
*
*/
#ifndef PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_
#define PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_
template<class T>
class SubProcessor;
template<class T>
class Share;
template <class T>
class ReplicatedPrivateOutput
{
SubProcessor<T>& proc;
public:
ReplicatedPrivateOutput(SubProcessor<T>& proc);
void start(int player, int target, int source);
void stop(int player, int source);
};
#endif /* PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ */

View File

@@ -1,30 +0,0 @@
/*
* ReplicatedPrivateOutput.cpp
*
*/
#include "ReplicatedPrivateOutput.h"
#include "Processor/Processor.h"
#include "Math/FixedVec.h"
#include "Math/Integer.h"
template<class T>
inline ReplicatedPrivateOutput<T>::ReplicatedPrivateOutput(
SubProcessor<T>& proc) :
proc(proc)
{
}
template<class T>
void ReplicatedPrivateOutput<T>::start(int player, int target,
int source)
{
(void)player, (void)target, (void)source;
throw runtime_error("not implemented, use PrivateOutput");
}
template<class T>
void ReplicatedPrivateOutput<T>::stop(int player, int source)
{
(void)player, (void)source;
}

View File

@@ -71,6 +71,12 @@ public:
proc.get_S()[info.source_base + i] >> info.m;
}
}
void buffer_random()
{
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
this->random.push_back(G.get<T>());
}
};
#endif /* PROTOCOLS_SEMI_H_ */

View File

@@ -14,34 +14,33 @@ template<class T> class SemiMC;
* Additive secret sharing input protocol
*/
template<class T>
class SemiInput : public IndividualInput<T>
class SemiInput : public InputBase<T>
{
SeededPRNG secure_prng;
vector<SeededPRNG> send_prngs;
vector<PRNG> recv_prngs;
Player& P;
vector<PointerVector<T>> shares;
public:
SemiInput(SubProcessor<T>& proc, SemiMC<T>& MC) :
IndividualInput<T>(proc)
SemiInput(SubProcessor<T>& proc, SemiMC<T>&) :
SemiInput(&proc, proc.P)
{
(void) MC;
}
SemiInput(SubProcessor<T>* proc, Player& P) :
IndividualInput<T>(proc, P)
{
}
SemiInput(SubProcessor<T>* proc, Player& P);
SemiInput(typename T::MAC_Check& MC, Preprocessing<T>& prep, Player& P) :
SemiInput(P)
SemiInput(0, P)
{
(void) MC, (void) prep;
}
SemiInput(Player& P) :
IndividualInput<T>(0, P)
{
}
void reset(int player);
void add_mine(const typename T::clear& input, int n_bits = -1);
void add_other(int player, int n_bits = -1);
void exchange();
void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);
T finalize_mine();
};
#endif /* PROTOCOLS_SEMIINPUT_H_ */

View File

@@ -11,22 +11,64 @@
#include "ShamirInput.hpp"
template<class T>
void SemiInput<T>::add_mine(const typename T::clear& input, int n_bits)
SemiInput<T>::SemiInput(SubProcessor<T>* proc, Player& P) :
InputBase<T>(proc), P(P)
{
shares.resize(P.num_players());
vector<octetStream> to_send(P.num_players()), to_receive;
for (int i = 0; i < P.num_players(); i++)
{
send_prngs.push_back({});
to_send[i].append(send_prngs.back().get_seed(), SEED_SIZE);
}
P.send_receive_all(to_send, to_receive);
recv_prngs.resize(P.num_players());
for (int i = 0; i < P.num_players(); i++)
if (i != P.my_num())
recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE));
this->reset_all(P);
}
template<class T>
void SemiInput<T>::reset(int player)
{
shares[player].clear();
}
template<class T>
void SemiInput<T>::add_mine(const typename T::clear& input, int)
{
auto& P = this->P;
typename T::open_type sum, share;
for (int i = 0; i < P.num_players(); i++)
{
if (i < P.num_players() - 1)
share.randomize(secure_prng, n_bits);
else
share = input - sum;
sum += share;
if (i == P.my_num())
this->shares.push_back(share);
else
share.pack(this->os[i], n_bits);
if (i != P.my_num())
sum += send_prngs[i].template get<typename T::open_type>();
}
shares[P.my_num()].push_back(input - sum);
}
template<class T>
void SemiInput<T>::add_other(int, int)
{
}
template<class T>
void SemiInput<T>::exchange()
{
}
template<class T>
void SemiInput<T>::finalize_other(int player, T& target, octetStream&,
int)
{
target = recv_prngs[player].template get<T>();
}
template<class T>
T SemiInput<T>::finalize_mine()
{
return shares[P.my_num()].next();
}
#endif

View File

@@ -27,7 +27,6 @@ class Shamir : public ProtocolBase<T>
{
typedef typename T::open_type::Scalar U;
octetStreams os;
vector<U> reconstruction;
U rec_factor;
ShamirInput<T>* resharing;

View File

@@ -69,8 +69,6 @@ int Shamir<T>::get_n_relevant_players()
template<class T>
void Shamir<T>::reset()
{
os.reset(P);
if (resharing == 0)
{
resharing = new ShamirInput<T>(0, P);
@@ -78,6 +76,9 @@ void Shamir<T>::reset()
for (int i = 0; i < P.num_players(); i++)
resharing->reset(i);
for (int i = 0; i < n_mul_players; i++)
resharing->add_sender(i);
}
template<class T>
@@ -92,37 +93,27 @@ template<class T>
void Shamir<T>::prepare_mul(const T& x, const T& y, int n)
{
(void) n;
auto add_share = x * y * rec_factor;
if (P.my_num() < n_mul_players)
resharing->add_mine(add_share);
resharing->add_mine(x * y * rec_factor);
}
template<class T>
void Shamir<T>::exchange()
{
vector<bool> senders(P.num_players(), false);
for (int i = 0; i < n_mul_players; i++)
senders[i] = true;
P.send_receive_all(senders, resharing->os, os);
assert(resharing);
resharing->exchange();
}
template<class T>
void Shamir<T>::start_exchange()
{
if (P.my_num() < n_mul_players)
for (int offset = 1; offset < P.num_players(); offset++)
P.send_relative(offset, resharing->os[P.get_player(offset)]);
resharing->start_exchange();
}
template<class T>
void Shamir<T>::stop_exchange()
{
for (int offset = 1; offset < P.num_players(); offset++)
{
int receive_from = P.get_player(-offset);
if (receive_from < n_mul_players)
P.receive_player(receive_from, os[receive_from]);
}
resharing->stop_exchange();
}
template<class T>
@@ -136,15 +127,8 @@ template<class T>
T Shamir<T>::finalize(int n_relevant_players)
{
ShamirShare<U> res = U(0);
if (P.my_num() < n_relevant_players)
res = resharing->finalize_mine();
for (int i = 0; i < n_relevant_players; i++)
if (i != P.my_num())
{
T tmp;
resharing->finalize_other(i, tmp, os[i]);
res += tmp;
}
res += resharing->finalize(i);
return res;
}
@@ -259,7 +243,7 @@ vector<T> Shamir<T>::get_randoms(PRNG& G, int t)
input.reset_all(P);
int buffer_size = OnlineOptions::singleton.batch_size;
for (int i = 0; i < buffer_size; i += hyper.size())
input.add_mine(G.get<U>());
input.add_from_all(G.get<U>());
input.exchange();
vector<U> inputs;
vector<T> random;

View File

@@ -21,10 +21,11 @@ class IndividualInput : public PrepLessInput<T>
protected:
Player& P;
octetStreams os;
vector<bool> senders;
public:
IndividualInput(SubProcessor<T>* proc, Player& P) :
PrepLessInput<T>(proc), P(P)
PrepLessInput<T>(proc), P(P), senders(P.num_players())
{
this->reset_all(P);
}
@@ -34,10 +35,14 @@ public:
}
void reset(int player);
void add_sender(int player);
void add_other(int player, int n_bits = -1);
void send_mine();
void exchange();
void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);
void start_exchange();
void stop_exchange();
};
/**

View File

@@ -20,6 +20,8 @@ void IndividualInput<U>::reset(int player)
this->i_share = 0;
os.reset(P);
}
senders[player] = false;
}
template<class T>
@@ -68,12 +70,20 @@ void ShamirInput<T>::add_mine(const typename T::open_type& input, int n_bits)
else
x.pack(this->os[i]);
}
this->senders[P.my_num()] = true;
}
template<class U>
void IndividualInput<U>::add_sender(int player)
{
senders[player] = true;
}
template<class U>
void IndividualInput<U>::add_other(int player, int)
{
(void) player;
add_sender(player);
}
template<class U>
@@ -87,7 +97,26 @@ void IndividualInput<U>::send_mine()
template<class T>
void IndividualInput<T>::exchange()
{
P.send_receive_all(os, InputBase<T>::os);
P.send_receive_all(senders, os, InputBase<T>::os);
}
template<class T>
void IndividualInput<T>::start_exchange()
{
if (senders[P.my_num()])
for (int offset = 1; offset < P.num_players(); offset++)
P.send_relative(offset, os[P.get_player(offset)]);
}
template<class T>
void IndividualInput<T>::stop_exchange()
{
for (int offset = 1; offset < P.num_players(); offset++)
{
int receive_from = P.get_player(-offset);
if (senders[receive_from])
P.receive_player(receive_from, InputBase<T>::os[receive_from]);
}
}
template<class T>

Some files were not shown because too many files have changed in this diff Show More