Semi-honest computation based on threshold semi-homomorphic encryption.

This commit is contained in:
Marcel Keller
2022-02-17 13:21:19 +11:00
parent 61d40b7d83
commit 0f7020d791
129 changed files with 1973 additions and 539 deletions

View File

@@ -1,6 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.2.9 (Jan 11, 2021) ## 0.3.0 (Feb 17, 2022)
- Semi-honest computation based on threshold semi-homomorphic encryption
- Batch normalization backward propagation
- AlexNet for CIFAR-10
- Specific private output protocols
- Semi-honest additive secret sharing without communication
- Sending of personal values
- Allow overwriting of persistence files
- Protocol signature in persistence files
## 0.2.9 (Jan 11, 2022)
- Disassembler - Disassembler
- Run-time parameter for probabilistic truncation error - Run-time parameter for probabilistic truncation error

1
CONFIG
View File

@@ -42,6 +42,7 @@ else
AVX_OT = 1 AVX_OT = 1
endif endif
else else
ARCH =
AVX_OT = 0 AVX_OT = 0
endif endif

View File

@@ -497,7 +497,7 @@ class movsb(NonVectorInstruction):
code = opcodes['MOVSB'] code = opcodes['MOVSB']
arg_format = ['sbw','sb'] arg_format = ['sbw','sb']
class trans(base.VarArgsInstruction): class trans(base.VarArgsInstruction, base.DynFormatInstruction):
""" Secret bit register vector transpose. The first destination vector """ Secret bit register vector transpose. The first destination vector
will contain the least significant bits of all source vectors etc. will contain the least significant bits of all source vectors etc.
@@ -511,10 +511,22 @@ class trans(base.VarArgsInstruction):
code = opcodes['TRANS'] code = opcodes['TRANS']
is_vec = lambda self: True is_vec = lambda self: True
def __init__(self, *args): def __init__(self, *args):
self.arg_format = ['int'] + ['sbw'] * args[0] + \
['sb'] * (len(args) - 1 - args[0])
super(trans, self).__init__(*args) super(trans, self).__init__(*args)
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
n = next(args)
for i in range(n):
yield 'sbw'
next(args)
while True:
try:
yield 'sb'
next(args)
except StopIteration:
break
class bitb(NonVectorInstruction): class bitb(NonVectorInstruction):
""" Copy fresh secret random bit to secret bit register. """ Copy fresh secret random bit to secret bit register.
@@ -560,7 +572,7 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1]) req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1])
class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
base.Mergeable): base.Mergeable, base.DynFormatInstruction):
""" Copy private input to secret bit registers bit by bit. The input is """ Copy private input to secret bit registers bit by bit. The input is
read as floating-point number, multiplied by a power of two, rounded to an read as floating-point number, multiplied by a power of two, rounded to an
integer, and then decomposed into bits. integer, and then decomposed into bits.
@@ -577,11 +589,18 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
code = opcodes['INPUTBVEC'] code = opcodes['INPUTBVEC']
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.arg_format = []
for x in self.get_arg_tuples(args):
self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3)
super(inputbvec, self).__init__(*args, **kwargs) super(inputbvec, self).__init__(*args, **kwargs)
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
for i, n in cls.bases(args):
yield 'int'
yield 'p'
for j in range(n - 3):
yield 'sbw'
yield 'int'
@staticmethod @staticmethod
def get_arg_tuples(args): def get_arg_tuples(args):
i = 0 i = 0
@@ -590,10 +609,6 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
i += args[i] i += args[i]
assert i == len(args) assert i == len(args)
def merge(self, other):
self.args += other.args
self.arg_format += other.arg_format
def add_usage(self, req_node): def add_usage(self, req_node):
for x in self.get_arg_tuples(self.args): for x in self.get_arg_tuples(self.args):
req_node.increment(('bit', 'input', x[2]), x[0] - 3) req_node.increment(('bit', 'input', x[2]), x[0] - 3)

View File

@@ -41,7 +41,7 @@ class bits(Tape.Register, _structure, _bit):
return cls.types[length] return cls.types[length]
@classmethod @classmethod
def conv(cls, other): def conv(cls, other):
if isinstance(other, cls): if isinstance(other, cls) and cls.n == other.n:
return other return other
elif isinstance(other, MemValue): elif isinstance(other, MemValue):
return cls.conv(other.read()) return cls.conv(other.read())
@@ -246,14 +246,20 @@ class cbits(bits):
assert n == res.n assert n == res.n
assert n == other.size assert n == other.size
cls.conv_cint_vec(cint(other, size=other.size), res) cls.conv_cint_vec(cint(other, size=other.size), res)
@classmethod
def conv(cls, other):
if isinstance(other, cbits) and cls.n != None and \
cls.n // cls.unit == other.n // cls.unit:
return other
else:
return super(cbits, cls).conv(other)
types = {} types = {}
def load_int(self, value): def load_int(self, value):
if self.n <= 64: n_limbs = math.ceil(self.n / self.unit)
tmp = regint(value) tmp = regint(size=n_limbs)
elif value == self.long_one(): for i in range(n_limbs):
tmp = cint(1, size=self.n) tmp[i].load_int(value % 2 ** self.unit)
else: value >>= self.unit
raise CompilerError('loading long integers to cbits not supported')
self.load_other(tmp) self.load_other(tmp)
def store_in_dynamic_mem(self, address): def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address)) inst.stmsdci(self, cbits.conv(address))
@@ -1163,14 +1169,14 @@ class cbitfix(object):
@classmethod @classmethod
def _new(cls, value): def _new(cls, value):
res = cls() res = cls()
if cls.k < value.unit:
bits = value.bit_decompose(cls.k)
sign = bits[-1]
value += (sign << (cls.k)) * -1
res.v = value res.v = value
return res return res
def output(self): def output(self):
v = self.v v = self.v
if self.k < v.unit:
bits = self.v.bit_decompose(self.k)
sign = bits[-1]
v += (sign << (self.k)) * -1
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0)) cbits(0), cbits(0))

View File

@@ -403,6 +403,20 @@ class Merger:
add_edge(last_input[t][1], n) add_edge(last_input[t][1], n)
last_input[t][0] = n last_input[t][0] = n
def keep_text_order(inst, n):
if inst.get_players() is None:
# switch
for x in list(last_input.keys()):
if isinstance(x, int):
add_edge(last_input[x][0], n)
del last_input[x]
keep_merged_order(instr, n, None)
elif last_input[None][0] is not None:
keep_merged_order(instr, n, None)
else:
for player in inst.get_players():
keep_merged_order(instr, n, player)
for n,instr in enumerate(block.instructions): for n,instr in enumerate(block.instructions):
outputs,inputs = instr.get_def(), instr.get_used() outputs,inputs = instr.get_def(), instr.get_used()
@@ -427,7 +441,7 @@ class Merger:
# will be merged # will be merged
if isinstance(instr, TextInputInstruction): if isinstance(instr, TextInputInstruction):
keep_merged_order(instr, n, TextInputInstruction) keep_text_order(instr, n)
elif isinstance(instr, RawInputInstruction): elif isinstance(instr, RawInputInstruction):
keep_merged_order(instr, n, RawInputInstruction) keep_merged_order(instr, n, RawInputInstruction)
@@ -479,10 +493,6 @@ class Merger:
last_print_str = n last_print_str = n
elif isinstance(instr, PublicFileIOInstruction): elif isinstance(instr, PublicFileIOInstruction):
keep_order(instr, n, instr.__class__) keep_order(instr, n, instr.__class__)
elif isinstance(instr, startprivateoutput_class):
keep_order(instr, n, startprivateoutput_class, 2)
elif isinstance(instr, stopprivateoutput_class):
keep_order(instr, n, stopprivateoutput_class, 2)
elif isinstance(instr, prep_class): elif isinstance(instr, prep_class):
keep_order(instr, n, instr.args[0]) keep_order(instr, n, instr.args[0])
elif isinstance(instr, StackInstruction): elif isinstance(instr, StackInstruction):

View File

@@ -421,6 +421,10 @@ class use_matmul(base.Instruction):
code = base.opcodes['USE_MATMUL'] code = base.opcodes['USE_MATMUL']
arg_format = ['int','int','int','int'] arg_format = ['int','int','int','int']
@classmethod
def get_usage(cls, args):
return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i}
class run_tape(base.Instruction): class run_tape(base.Instruction):
""" Start tape/bytecode file in another thread. """ Start tape/bytecode file in another thread.
@@ -1229,15 +1233,20 @@ class inverse(base.DataInstruction):
@base.gf2n @base.gf2n
@base.vectorize @base.vectorize
class inputmask(base.Instruction): class inputmask(base.Instruction):
r""" Load secret $s_i$ with the next input mask for player $p$ and """ Store fresh random input mask(s) in secret register (vector) and clear
write the mask on player $p$'s private output. """ register (vector) of the relevant player.
:param: mask (sint)
:param: mask (cint, player only)
:param: player (int)
"""
__slots__ = [] __slots__ = []
code = base.opcodes['INPUTMASK'] code = base.opcodes['INPUTMASK']
arg_format = ['sw', 'p'] arg_format = ['sw', 'cw', 'p']
field_type = 'modp' field_type = 'modp'
def add_usage(self, req_node): def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', self.args[1]), \ req_node.increment((self.field_type, 'input', self.args[2]), \
self.get_size()) self.get_size())
@base.vectorize @base.vectorize
@@ -1293,10 +1302,8 @@ class asm_input(base.TextInputInstruction):
arg_format = tools.cycle(['sw', 'p']) arg_format = tools.cycle(['sw', 'p'])
field_type = 'modp' field_type = 'modp'
def add_usage(self, req_node): def get_players(self):
for player in self.args[1::2]: return self.args[1::2]
req_node.increment((self.field_type, 'input', player), \
self.get_size())
@base.vectorize @base.vectorize
class inputfix(base.TextInputInstruction): class inputfix(base.TextInputInstruction):
@@ -1305,10 +1312,8 @@ class inputfix(base.TextInputInstruction):
arg_format = tools.cycle(['sw', 'int', 'p']) arg_format = tools.cycle(['sw', 'int', 'p'])
field_type = 'modp' field_type = 'modp'
def add_usage(self, req_node): def get_players(self):
for player in self.args[2::3]: return self.args[2::3]
req_node.increment((self.field_type, 'input', player), \
self.get_size())
@base.vectorize @base.vectorize
class inputfloat(base.TextInputInstruction): class inputfloat(base.TextInputInstruction):
@@ -1322,7 +1327,7 @@ class inputfloat(base.TextInputInstruction):
req_node.increment((self.field_type, 'input', player), \ req_node.increment((self.field_type, 'input', player), \
4 * self.get_size()) 4 * self.get_size())
class inputmixed_base(base.TextInputInstruction): class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction):
__slots__ = [] __slots__ = []
field_type = 'modp' field_type = 'modp'
# the following has to match TYPE: (N_DEST, N_PARAM) # the following has to match TYPE: (N_DEST, N_PARAM)
@@ -1341,22 +1346,30 @@ class inputmixed_base(base.TextInputInstruction):
type_id = self.type_ids[name] type_id = self.type_ids[name]
super(inputmixed_base, self).__init__(type_id, *args) super(inputmixed_base, self).__init__(type_id, *args)
@property @classmethod
def arg_format(self): def dynamic_arg_format(self, args):
for i in self.bases():
t = self.args[i]
yield 'int' yield 'int'
for i, t in self.bases(iter(args)):
for j in range(self.types[t][0]): for j in range(self.types[t][0]):
yield 'sw' yield 'sw'
for j in range(self.types[t][1]): for j in range(self.types[t][1]):
yield 'int' yield 'int'
yield self.player_arg_type yield self.player_arg_type
yield 'int'
def bases(self): @classmethod
def bases(self, args):
i = 0 i = 0
while i < len(self.args): while True:
yield i try:
i += sum(self.types[self.args[i]]) + 2 t = next(args)
except StopIteration:
return
yield i, t
n = sum(self.types[t])
i += n + 2
for j in range(n + 1):
next(args)
@base.vectorize @base.vectorize
class inputmixed(inputmixed_base): class inputmixed(inputmixed_base):
@@ -1380,13 +1393,16 @@ class inputmixed(inputmixed_base):
player_arg_type = 'p' player_arg_type = 'p'
def add_usage(self, req_node): def add_usage(self, req_node):
for i in self.bases(): for i, t in self.bases(iter(self.args)):
t = self.args[i]
player = self.args[i + sum(self.types[t]) + 1] player = self.args[i + sum(self.types[t]) + 1]
n_dest = self.types[t][0] n_dest = self.types[t][0]
req_node.increment((self.field_type, 'input', player), \ req_node.increment((self.field_type, 'input', player), \
n_dest * self.get_size()) n_dest * self.get_size())
def get_players(self):
for i, t in self.bases(iter(self.args)):
yield self.args[i + sum(self.types[t]) + 1]
@base.vectorize @base.vectorize
class inputmixedreg(inputmixed_base): class inputmixedreg(inputmixed_base):
""" Store private input in secret registers (vectors). The input is """ Store private input in secret registers (vectors). The input is
@@ -1412,6 +1428,9 @@ class inputmixedreg(inputmixed_base):
# player 0 as proxy # player 0 as proxy
req_node.increment((self.field_type, 'input', 0), float('inf')) req_node.increment((self.field_type, 'input', 0), float('inf'))
def get_players(self):
pass
@base.gf2n @base.gf2n
@base.vectorize @base.vectorize
class rawinput(base.RawInputInstruction, base.Mergeable): class rawinput(base.RawInputInstruction, base.Mergeable):
@@ -1433,7 +1452,23 @@ class rawinput(base.RawInputInstruction, base.Mergeable):
req_node.increment((self.field_type, 'input', player), \ req_node.increment((self.field_type, 'input', player), \
self.get_size()) self.get_size())
class inputpersonal(base.Instruction, base.Mergeable): class personal_base(base.Instruction, base.Mergeable):
__slots__ = []
field_type = 'modp'
def __init__(self, *args):
super(personal_base, self).__init__(*args)
for i in range(0, len(args), 4):
assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i]
def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])
class inputpersonal(personal_base):
""" Private input from cint. """ Private input from cint.
:param: vector size (int) :param: vector size (int)
@@ -1445,19 +1480,39 @@ class inputpersonal(base.Instruction, base.Mergeable):
__slots__ = [] __slots__ = []
code = base.opcodes['INPUTPERSONAL'] code = base.opcodes['INPUTPERSONAL']
arg_format = tools.cycle(['int','p','sw','c']) arg_format = tools.cycle(['int','p','sw','c'])
field_type = 'modp'
class privateoutput(personal_base):
""" Private input from cint.
:param: vector size (int)
:param: player (int)
:param: destination (cint)
:param: source (sint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['PRIVATEOUTPUT']
arg_format = tools.cycle(['int','p','cw','s'])
class sendpersonal(base.Instruction, base.Mergeable):
""" Private input from cint.
:param: vector size (int)
:param: destination player (int)
:param: destination (cint)
:param: source player (int)
:param: source (cint)
:param: (repeat from vector size)...
"""
__slots__ = []
code = base.opcodes['SENDPERSONAL']
arg_format = tools.cycle(['int','p','cw','p','c'])
def __init__(self, *args): def __init__(self, *args):
super(inputpersonal, self).__init__(*args) super(sendpersonal, self).__init__(*args)
for i in range(0, len(args), 4): for i in range(0, len(args), 5):
assert args[i + 2].size == args[i] assert args[i + 2].size == args[i]
assert args[i + 3].size == args[i] assert args[i + 4].size == args[i]
def add_usage(self, req_node):
for i in range(0, len(self.args), 4):
player = self.args[i + 1]
req_node.increment((self.field_type, 'input', player), \
self.args[i])
@base.gf2n @base.gf2n
@base.vectorize @base.vectorize
@@ -1789,27 +1844,6 @@ class floatoutput(base.PublicFileIOInstruction):
code = base.opcodes['FLOATOUTPUT'] code = base.opcodes['FLOATOUTPUT']
arg_format = ['p','c','c','c','c'] arg_format = ['p','c','c','c','c']
@base.gf2n
@base.vectorize
class startprivateoutput(base.Instruction):
r""" Initiate private output to $n$ of $s_j$ via $s_i$. """
__slots__ = []
code = base.opcodes['STARTPRIVATEOUTPUT']
arg_format = ['sw','s','p']
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', self.args[2]), \
self.get_size())
@base.gf2n
@base.vectorize
class stopprivateoutput(base.Instruction):
r""" Previously iniated private output to $n$ via $c_i$. """
__slots__ = []
code = base.opcodes['STOPPRIVATEOUTPUT']
arg_format = ['cw','c','p']
@base.vectorize @base.vectorize
class rand(base.Instruction): class rand(base.Instruction):
""" Store insecure random value of specified length in clear integer """ Store insecure random value of specified length in clear integer
@@ -2210,7 +2244,8 @@ class mulrs(base.VarArgsInstruction, base.DataInstruction):
@base.gf2n @base.gf2n
@base.vectorize @base.vectorize
class dotprods(base.VarArgsInstruction, base.DataInstruction): class dotprods(base.VarArgsInstruction, base.DataInstruction,
base.DynFormatInstruction):
""" Dot product of secret registers (vectors). """ Dot product of secret registers (vectors).
Note that the vectorized version works element-wise. Note that the vectorized version works element-wise.
@@ -2238,31 +2273,29 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
flat_args += [x, y] flat_args += [x, y]
base.Instruction.__init__(self, *flat_args) base.Instruction.__init__(self, *flat_args)
@property @classmethod
def arg_format(self): def dynamic_arg_format(self, args):
field = 'g' if self.is_gf2n() else '' field = 'g' if self.is_gf2n() else ''
for i in self.bases():
yield 'int' yield 'int'
for i, n in self.bases(args):
yield 's' + field + 'w' yield 's' + field + 'w'
for j in range(self.args[i] - 2): for j in range(n - 2):
yield 's' + field yield 's' + field
yield 'int'
gf2n_arg_format = arg_format @property
def gf2n_arg_format(self):
def bases(self): return self.arg_format()
i = 0
while i < len(self.args):
yield i
i += self.args[i]
def get_repeat(self): def get_repeat(self):
return sum(self.args[i] // 2 for i in self.bases()) * self.get_size() return sum(self.args[i] // 2
for i, n in self.bases(iter(self.args))) * self.get_size()
def get_def(self): def get_def(self):
return [self.args[i + 1] for i in self.bases()] return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
def get_used(self): def get_used(self):
for i in self.bases(): for i, n in self.bases(iter(self.args)):
for reg in self.args[i + 2:i + self.args[i]]: for reg in self.args[i + 2:i + self.args[i]]:
yield reg yield reg

View File

@@ -105,6 +105,7 @@ opcodes = dict(
MATMULSM = 0xAB, MATMULSM = 0xAB,
CONV2DS = 0xAC, CONV2DS = 0xAC,
CHECK = 0xAF, CHECK = 0xAF,
PRIVATEOUTPUT = 0xAD,
# Data access # Data access
TRIPLE = 0x50, TRIPLE = 0x50,
BIT = 0x51, BIT = 0x51,
@@ -128,6 +129,7 @@ opcodes = dict(
INPUTMIXEDREG = 0xF3, INPUTMIXEDREG = 0xF3,
RAWINPUT = 0xF4, RAWINPUT = 0xF4,
INPUTPERSONAL = 0xF5, INPUTPERSONAL = 0xF5,
SENDPERSONAL = 0xF6,
STARTINPUT = 0x61, STARTINPUT = 0x61,
STOPINPUT = 0x62, STOPINPUT = 0x62,
READSOCKETC = 0x63, READSOCKETC = 0x63,
@@ -364,6 +366,7 @@ def gf2n(instruction):
arg_format = copy.deepcopy(instruction_cls.arg_format) arg_format = copy.deepcopy(instruction_cls.arg_format)
reformat(arg_format) reformat(arg_format)
@classmethod
def is_gf2n(self): def is_gf2n(self):
return True return True
@@ -505,8 +508,12 @@ def cisc(function):
for arg in self.args: for arg in self.args:
try: try:
new_regs.append(type(arg)(size=size)) new_regs.append(type(arg)(size=size))
except: except TypeError:
break break
except:
print([call[0][0].size for call in self.calls])
raise
assert len(new_regs) > 1
base = 0 base = 0
for call in self.calls: for call in self.calls:
for new_reg, reg in zip(new_regs[1:], call[0][1:]): for new_reg, reg in zip(new_regs[1:], call[0][1:]):
@@ -854,6 +861,7 @@ class Instruction(object):
def is_vec(self): def is_vec(self):
return False return False
@classmethod
def is_gf2n(self): def is_gf2n(self):
return False return False
@@ -902,6 +910,10 @@ class Instruction(object):
new_args.append(arg) new_args.append(arg)
return new_args return new_args
@staticmethod
def get_usage(args):
return {}
# String version of instruction attempting to replicate encoded version # String version of instruction attempting to replicate encoded version
def __str__(self): def __str__(self):
@@ -949,9 +961,18 @@ class ParsedInstruction:
if name == 'cisc': if name == 'cisc':
arg_format = itertools.chain(['str'], itertools.repeat('int')) arg_format = itertools.chain(['str'], itertools.repeat('int'))
else: else:
arg_format = itertools.repeat('int') def arg_iter():
self.args = [ArgFormats[next(arg_format)](f) i = 0
for i in range(n_args)] while True:
try:
yield self.args[i].i
except AttributeError:
yield None
i += 1
arg_format = t.dynamic_arg_format(arg_iter())
self.args = []
for i in range(n_args):
self.args.append(ArgFormats[next(arg_format)](f))
def __str__(self): def __str__(self):
name = self.type.__name__ name = self.type.__name__
@@ -963,6 +984,9 @@ class ParsedInstruction:
res += ', '.join(str(arg) for arg in self.args) res += ', '.join(str(arg) for arg in self.args)
return res return res
def get_usage(self):
return self.type.get_usage(self.args)
class VarArgsInstruction(Instruction): class VarArgsInstruction(Instruction):
def has_var_args(self): def has_var_args(self):
return True return True
@@ -974,6 +998,26 @@ class VectorInstruction(Instruction):
def get_code(self): def get_code(self):
return super(VectorInstruction, self).get_code(len(self.args[0])) return super(VectorInstruction, self).get_code(len(self.args[0]))
class DynFormatInstruction(Instruction):
__slots__ = []
@property
def arg_format(self):
return self.dynamic_arg_format(iter(self.args))
@classmethod
def bases(self, args):
i = 0
while True:
try:
n = next(args)
except StopIteration:
return
yield i, n
i += n
for j in range(n - 1):
next(args)
### ###
### Basic arithmetic ### Basic arithmetic
### ###
@@ -1072,6 +1116,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
""" Input from text file or stdin """ """ Input from text file or stdin """
__slots__ = [] __slots__ = []
def add_usage(self, req_node):
for player in self.get_players():
req_node.increment((self.field_type, 'input', player), \
self.get_size())
### ###
### Data access instructions ### Data access instructions
### ###

View File

@@ -223,7 +223,7 @@ def crash(condition=None):
if isinstance(condition, localint): if isinstance(condition, localint):
# allow crash on local values # allow crash on local values
condition = condition._v condition = condition._v
if condition == None: if condition is None:
condition = regint(1) condition = regint(1)
instructions.crash(regint.conv(condition)) instructions.crash(regint.conv(condition))
@@ -284,8 +284,8 @@ def get_arg():
def make_array(l): def make_array(l):
if isinstance(l, program.Tape.Register): if isinstance(l, program.Tape.Register):
res = Array(1, type(l)) res = Array(len(l), type(l))
res[0] = l res[:] = l
else: else:
l = list(l) l = list(l)
res = Array(len(l), type(l[0]) if l else cint) res = Array(len(l), type(l[0]) if l else cint)
@@ -1032,6 +1032,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
state = tuplify(initializer()) state = tuplify(initializer())
k = 0 k = 0
block = get_block() block = get_block()
assert not isinstance(n_loops, int) or n_loops > 0
pre = copy.copy(loop_body.__globals__) pre = copy.copy(loop_body.__globals__)
while (not util.is_constant(n_loops) or k < n_loops) \ while (not util.is_constant(n_loops) or k < n_loops) \
and (len(get_block()) < budget or k == 0) \ and (len(get_block()) < budget or k == 0) \
@@ -1211,7 +1212,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if t != regint: if t != regint:
raise CompilerError('Not implemented for other than regint') raise CompilerError('Not implemented for other than regint')
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci') args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
state = tuple(initializer()) state = initializer()
if len(state) == 0:
state_type = cint
elif isinstance(state, (tuple, list)):
state_type = type(state[0])
else:
state_type = type(state)
def f(inc): def f(inc):
base = args[get_arg()][0] base = args[get_arg()][0]
if not util.is_constant(thread_rounds): if not util.is_constant(thread_rounds):
@@ -1224,8 +1231,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
if thread_mem_req: if thread_mem_req:
thread_mem = Array(thread_mem_req[regint], regint, \ thread_mem = Array(thread_mem_req[regint], regint, \
args[get_arg()].address + 2) args[get_arg()].address + 2)
mem_state = Array(len(state), type(state[0]) \ mem_state = Array(len(state), state_type, args[get_arg()][1])
if state else cint, args[get_arg()][1])
@map_reduce_single(n_parallel, thread_rounds + inc, \ @map_reduce_single(n_parallel, thread_rounds + inc, \
initializer, reducer, mem_state) initializer, reducer, mem_state)
def f(i): def f(i):
@@ -1257,14 +1263,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
threads = prog.run_tapes(thread_args) threads = prog.run_tapes(thread_args)
for thread in threads: for thread in threads:
prog.join_tape(thread) prog.join_tape(thread)
if state: if len(state):
if thread_rounds: if thread_rounds:
for i in range(n_threads - remainder): for i in range(n_threads - remainder):
state = reducer(Array(len(state), type(state[0]), \ state = reducer(Array(len(state), state_type, \
args[remainder + i][1]), state) args[remainder + i][1]), state)
if remainder: if remainder:
for i in range(remainder): for i in range(remainder):
state = reducer(Array(len(state), type(state[0]).reg_type, \ state = reducer(Array(len(state), state_type, \
args[i][1]), state) args[i][1]), state)
def returner(): def returner():
return untuplify(state) return untuplify(state)
@@ -1300,6 +1306,39 @@ def map_sum_opt(n_threads, n_loops, types):
""" """
return map_sum(n_threads, None, n_loops, len(types), types) return map_sum(n_threads, None, n_loops, len(types), types)
def map_sum_simple(n_threads, n_loops, type, size):
""" Vectorized multi-threaded sum reduction. The following computes a
100 sums of ten squares in three threads::
@map_sum_simple(3, 10, sint, 100)
def summer(i):
return sint(regint.inc(100, i, 0)) ** 2
result = summer()
:param n_threads: number of threads (int)
:param n_loops: number of loop runs (regint/cint/int)
:param type: return type, must match the return statement
in the loop
:param size: vector size, must match the return statement
in the loop
"""
initializer = lambda: type(0, size=size)
def summer(*args):
assert len(args) == 2
args = list(args)
for i in (0, 1):
if isinstance(args[i], tuple):
assert len(args[i]) == 1
args[i] = args[i][0]
for i in (0, 1):
assert len(args[i]) == size
if isinstance(args[i], Array):
args[i] = args[i][:]
return args[0] + args[1]
return map_reduce(n_threads, 1, n_loops, initializer, summer)
def tree_reduce_multithread(n_threads, function, vector): def tree_reduce_multithread(n_threads, function, vector):
inputs = vector.Array(len(vector)) inputs = vector.Array(len(vector))
inputs.assign_vector(vector) inputs.assign_vector(vector)

View File

@@ -223,6 +223,7 @@ class Layer:
thetas = lambda self: () thetas = lambda self: ()
debug_output = False debug_output = False
back_batch_size = 128 back_batch_size = 128
print_random_update = False
@property @property
def shape(self): def shape(self):
@@ -254,6 +255,9 @@ class Layer:
def __str__(self): def __str__(self):
return type(self).__name__ + str(self._Y.sizes) return type(self).__name__ + str(self._Y.sizes)
def __repr__(self):
return '%s(%s)' % (type(self).__name__, self.Y.sizes)
class NoVariableLayer(Layer): class NoVariableLayer(Layer):
input_from = lambda *args, **kwargs: None input_from = lambda *args, **kwargs: None
output_weights = lambda *args: None output_weights = lambda *args: None
@@ -459,6 +463,10 @@ class MultiOutput(MultiOutputBase):
self.debug = debug self.debug = debug
self.true_X = sfix.Array(N) self.true_X = sfix.Array(N)
def __repr__(self):
return '%s(%s, %s, approx=%s)' % \
(type(self).__name__, self.N, self.d_out, self.approx)
def _forward(self, batch): def _forward(self, batch):
N = len(batch) N = len(batch)
d_out = self.X.sizes[1] d_out = self.X.sizes[1]
@@ -609,10 +617,11 @@ class DenseBase(Layer):
N = len(batch) N = len(batch)
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
@multithread(self.n_threads, self.d_in) @multithread(self.n_threads, self.d_in)
def _(base, size): def _(base, size):
A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
mp = B.direct_trans_mul(A, reduce=False, mp = B.direct_trans_mul(A, reduce=False,
indices=(regint.inc(size, base), indices=(regint.inc(size, base),
batch.get_vector(), batch.get_vector(),
@@ -622,16 +631,24 @@ class DenseBase(Layer):
progress('nabla W (matmul)') progress('nabla W (matmul)')
if self.d_in * self.d_out < 200000: @multithread(self.n_threads, self.d_in * self.d_out,
print('reduce at once') max_size=get_program().budget)
@multithread(self.n_threads, self.d_in * self.d_out)
def _(base, size): def _(base, size):
self.nabla_W.assign_vector( self.nabla_W.assign_vector(
tmp.get_vector(base, size).reduce_after_mul(), base=base) tmp.get_vector(base, size).reduce_after_mul(), base=base)
else:
@for_range_opt(self.d_in) if self.print_random_update:
def _(i): print_ln('backward %s', self)
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul() i = regint.get_random(64) % self.d_in
j = regint.get_random(64) % self.d_out
print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s',
str(self.nabla_W), i, j, tmp[i][j].v.reveal(),
self.nabla_W[i][j].reveal(),
A.get_column(j).reveal(),
B.get_column_by_row_indices(
batch.get_vector(), i).reveal())
print_ln('batch=%s B=%s', batch,
[self.X[bi][0][i].reveal() for bi in batch])
progress('nabla W') progress('nabla W')
@@ -699,6 +716,7 @@ class Dense(DenseBase):
self.d_in = d_in self.d_in = d_in
self.d_out = d_out self.d_out = d_out
self.d = d self.d = d
self.activation = activation
self.X = MultiArray([N, d, d_in], sfix) self.X = MultiArray([N, d, d_in], sfix)
self.Y = MultiArray([N, d, d_out], sfix) self.Y = MultiArray([N, d, d_out], sfix)
@@ -721,12 +739,17 @@ class Dense(DenseBase):
else: else:
self.f_input = self.Y self.f_input = self.Y
def __repr__(self):
return '%s(%s, %s, %s, activation=%s)' % \
(type(self).__name__, self.N, self.d_in,
self.d_out, repr(self.activation))
def reset(self): def reset(self):
d_in = self.d_in d_in = self.d_in
d_out = self.d_out d_out = self.d_out
r = math.sqrt(6.0 / (d_in + d_out)) r = math.sqrt(6.0 / (d_in + d_out))
print('Initializing dense weights in [%f,%f]' % (-r, r)) print('Initializing dense weights in [%f,%f]' % (-r, r))
self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size())) self.W.randomize(-r, r)
self.b.assign_all(0) self.b.assign_all(0)
def input_from(self, player, raw=False): def input_from(self, player, raw=False):
@@ -820,6 +843,12 @@ class Dense(DenseBase):
regint.inc(self.d_in))), regint.inc(self.d_in))),
base) base)
if self.print_random_update:
print_ln('backward %s', self)
index = regint.get_random(64) % self.nabla_X.total_size()
print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
index, self.nabla_X.to_array()[index].reveal())
progress('nabla X') progress('nabla X')
self.backward_params(f_schur_Y, batch=batch) self.backward_params(f_schur_Y, batch=batch)
@@ -890,6 +919,10 @@ class Dropout(NoVariableLayer):
self.alpha = alpha self.alpha = alpha
self.B = MultiArray([N, d1, d2], sint) self.B = MultiArray([N, d1, d2], sint)
def __repr__(self):
return '%s(%s, %s, alpha=%s)' % \
(type(self).__name__, self.N, self.d1, self.alpha)
def forward(self, batch, training=False): def forward(self, batch, training=False):
if training: if training:
n_bits = -math.log(self.alpha, 2) n_bits = -math.log(self.alpha, 2)
@@ -1022,6 +1055,7 @@ class MaxPool(NoVariableLayer):
def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1),
padding='VALID'): padding='VALID'):
assert len(shape) == 4 assert len(shape) == 4
assert min(shape) > 0, shape
for x in strides, ksize: for x in strides, ksize:
for i in 0, 3: for i in 0, 3:
assert x[i] == 1 assert x[i] == 1
@@ -1033,12 +1067,18 @@ class MaxPool(NoVariableLayer):
self.Y = Tensor(output_shape, sfix) self.Y = Tensor(output_shape, sfix)
self.strides = strides self.strides = strides
self.ksize = ksize self.ksize = ksize
self.padding = padding
self.nabla_X = Tensor(shape, sfix) self.nabla_X = Tensor(shape, sfix)
self.nabla_Y = Tensor(output_shape, sfix) self.nabla_Y = Tensor(output_shape, sfix)
self.N = shape[0] self.N = shape[0]
self.comparisons = MultiArray([self.N, self.X.sizes[3], self.comparisons = MultiArray([self.N, self.X.sizes[3],
ksize[1] * ksize[2]], sint) ksize[1] * ksize[2]], sint)
def __repr__(self):
return '%s(%s, strides=%s, ksize=%s, padding=%s)' % \
(type(self).__name__, self.X.sizes, self.strides,
self.ksize, self.padding)
def _forward(self, batch): def _forward(self, batch):
def process(pool, bi, k, i, j): def process(pool, bi, k, i, j):
def m(a, b): def m(a, b):
@@ -1165,7 +1205,7 @@ class Add(NoVariableLayer):
self.Y[batch[0]].assign_vector(tmp, base) self.Y[batch[0]].assign_vector(tmp, base)
class FusedBatchNorm(Layer): class FusedBatchNorm(Layer):
""" Fixed-point fused batch normalization layer. """ Fixed-point fused batch normalization layer (inference only).
:param shape: input/output shape (tuple/list of four int) :param shape: input/output shape (tuple/list of four int)
""" """
@@ -1192,6 +1232,153 @@ class FusedBatchNorm(Layer):
self.X[batch[0]][i][j].get_vector() * self.weights.get_vector() self.X[batch[0]][i][j].get_vector() * self.weights.get_vector()
+ self.bias.get_vector()) + self.bias.get_vector())
class BatchNorm(Layer):
""" Fixed-point batch normalization layer.
:param shape: input/output shape (tuple/list of four int)
:param approx: use approximate square root
"""
thetas = lambda self: (self.weights, self.bias)
nablas = lambda self: (self.nabla_weights, self.nabla_bias)
def __init__(self, shape, approx=True, args=None):
assert len(shape) in (2, 3, 4)
if len(shape) == 4:
shape = [shape[0], shape[1] * shape[2], shape[3]]
elif len(shape) == 2:
shape = [shape[0], 1, shape[1]]
tensors = (Tensor(shape, sfix) for i in range(4))
self.X, self.Y, self.nabla_X, self.nabla_Y = tensors
arrays = (sfix.Array(shape[2]) for i in range(4))
self.var, self.mu, self.weights, self.bias = arrays
arrays = (sfix.Array(shape[2]) for i in range(4))
self.mu_hat, self.var_hat, self.nabla_weights, self.nabla_bias = arrays
self.epsilon = 2 ** (-sfix.f + 1)
self.momentum = 0.1
if args != None:
approx = 'precisebn' not in args
self.approx = approx
if approx:
print('Approximate square root inverse in batch normalization')
self.InvertSqrt = mpc_math.InvertSqrt
else:
print('Precise square root inverse in batch normalization')
self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x)
def __repr__(self):
return '%s(%s, approx=%s)' % \
(type(self).__name__, self.X.sizes, self.approx)
def reset(self):
self.bias.assign_all(0)
self.weights.assign_all(1)
self.mu_hat.assign_all(0)
self.var_hat.assign_all(0)
def _output(self, batch, mu, var):
factor = sfix.Array(len(mu))
factor[:] = self.InvertSqrt(var[:] + self.epsilon)
@for_range_opt_multithread(self.n_threads,
[len(batch), self.X.sizes[1]])
def _(i, j):
tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:]
self.Y[i][j][:] = self.bias[:] + tmp
def forward(self, batch, training=False):
if training:
d = self.X.sizes[1]
d_in = self.X.sizes[2]
s = sfix.Array(d_in)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (self.X[batch[i]][j].get_vector())
s.assign(_())
@multithread(self.n_threads, d_in)
def _(base, size):
self.mu.assign_vector(
s.get_vector(base, size) / (len(batch) * d), base)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
item = self.X[batch[i]][j].get_vector()
return ((item - self.mu[:]) ** 2)
self.var.assign(_())
@multithread(self.n_threads, d_in)
def _(base, size):
self.var.assign_vector(
self.var.get_vector(base, size) / (len(batch) * d - 1),
base)
for x, y, in (self.mu_hat, self.mu), (self.var_hat, self.var):
x[:] = self.momentum * y[:] + (1 - self.momentum) * x[:]
self._output(batch, self.mu, self.var)
if self.print_random_update:
i = regint.get_random(64) % len(batch)
j = regint.get_random(64) % d
k = regint.get_random(64) % d_in
for x in self.mu, self.var:
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
print_ln('%s at (%s, %s, %s): in=%s out=%s',
str(self.Y), i, j, k, self.X[i][j][k].reveal(),
self.Y[i][j][k].reveal())
else:
self._output(batch, self.mu_hat, self.var_hat)
def backward(self, batch, compute_nabla_X=True):
factor = Array.create_from(
self.InvertSqrt(self.var[:] + self.epsilon))
mynYf = self.X.same_shape()
gamnY = self.X.same_shape()
gamnYd = self.X.same_shape()
nYdf = self.X.same_shape()
d = self.X.sizes[1]
d_in = self.X.sizes[2]
@for_range_opt_multithread(self.n_threads, [len(batch), d])
def _(i, j):
tmp = self.weights[:] * self.nabla_Y[i][j][:]
gamnY[i][j] = tmp
gamnYd[i][j] = tmp * (self.X[i][j][:] - self.mu[:])
mynYf[i][j] = tmp * factor[:]
nYdf[i][j] = self.nabla_Y[i][j][:] * \
(self.X[i][j][:] - self.mu[:]) * factor[:]
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (self.nabla_Y[i][j][:])
self.nabla_bias.assign(_())
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (nYdf[i][j])
self.nabla_weights.assign(_())
factor3 = Array.create_from(factor[:] ** 3)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (mynYf[i][j])
s1 = Array.create_from(_())
@multithread(self.n_threads, len(s1))
def _(base, size):
s1.assign_vector(s1.get_vector(base, size) / (len(batch) * d), base)
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
def _(i, j):
return (gamnYd[i][j][:] * factor3[:])
s2 = Array.create_from(_())
@multithread(self.n_threads, len(s2))
def _(base, size):
s2.assign_vector(
s2.get_vector(base, size) / (len(batch) * d - 1), base)
@for_range_opt_multithread(self.n_threads, [len(batch), d])
def _(i, j):
self.nabla_X[i][j][:] = mynYf[i][j][:] \
- s1[:] - (self.X[i][j][:] - self.mu[:]) * s2[:]
if self.print_random_update:
print_ln('backward %s', self)
i = regint.get_random(64) % len(batch)
j = regint.get_random(64) % d
k = regint.get_random(64) % d_in
for x in self.nabla_bias, self.nabla_weights:
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k,
self.nabla_Y[i][j][k].reveal(),
self.nabla_X[i][j][k].reveal())
class QuantBase(object): class QuantBase(object):
bias_before_reduction = True bias_before_reduction = True
@@ -1298,6 +1485,8 @@ class ConvBase(BaseLayer):
self.padding.append(pad_total // 2) self.padding.append(pad_total // 2)
elif padding == 'VALID': elif padding == 'VALID':
self.padding = [0, 0] self.padding = [0, 0]
elif isinstance(padding, int):
self.padding = [padding, padding]
else: else:
self.padding = padding self.padding = padding
@@ -1323,6 +1512,12 @@ class ConvBase(BaseLayer):
assert(len(output_shape) == 4) assert(len(output_shape) == 4)
assert(len(weight_shape) == 4) assert(len(weight_shape) == 4)
def __repr__(self):
return '%s(%s, %s, %s, %s, %s, padding=%s, tf_weight_format=%s)' % \
(type(self).__name__, self.X.sizes, self.weight_shape,
self.bias_shape, self.Y.sizes, self.stride, repr(self.padding),
self.tf_weight_format)
def input_from(self, player, raw=False): def input_from(self, player, raw=False):
self.input_params_from(player) self.input_params_from(player)
self.weights.input_from(player, budget=100000, raw=raw) self.weights.input_from(player, budget=100000, raw=raw)
@@ -1545,20 +1740,20 @@ class FixConv2d(Conv2d, FixBase):
self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i) self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i)
if compute_nabla_X: if compute_nabla_X:
assert tuple(self.padding) == (0, 0)
assert tuple(self.stride) == (1, 1) assert tuple(self.stride) == (1, 1)
reverse_weights = MultiArray( reverse_weights = MultiArray(
[n_channels_in, weights_h, weights_w, n_channels_out], sfix) [n_channels_in, weights_h, weights_w, n_channels_out], sfix)
@for_range(n_channels_out) @for_range_opt_multithread(self.n_threads, n_channels_in)
def _(i): def _(l):
@for_range(weights_h) @for_range(weights_h)
def _(j): def _(j):
@for_range(weights_w) @for_range(weights_w)
def _(k): def _(k):
@for_range(n_channels_in) addresses = regint.inc(n_channels_out,
def _(l): self.weights[0][j][weights_w-k-1].get_address(l),
reverse_weights[l][weights_h-j-1][k][i] = \ reduce(operator.mul, self.weights.sizes[1:]))
self.weights[i][j][weights_w-k-1][l] reverse_weights[l][weights_h-j-1][k].assign_vector(
self.weights.value_type.load_mem(addresses))
padded_w = inputs_w + 2 * padding_w padded_w = inputs_w + 2 * padding_w
padded_h = inputs_h + 2 * padding_h padded_h = inputs_h + 2 * padding_h
if padding_h or padding_w: if padding_h or padding_w:
@@ -1579,14 +1774,16 @@ class FixConv2d(Conv2d, FixBase):
unreduced_sfix._new(res).reduce_after_mul(), unreduced_sfix._new(res).reduce_after_mul(),
i, None, None, j) i, None, None, j)
if padding_h or padding_w: if padding_h or padding_w:
@for_range(N) @for_range_opt_multithread(self.n_threads, N)
def _(i): def _(i):
@for_range(inputs_h) @for_range(inputs_h)
def _(j): def _(j):
@for_range(inputs_w) @for_range(inputs_w)
def _(k): def _(k):
jj = j + padding_w
kk = k + padding_w
self.nabla_X[i][j][k].assign_vector( self.nabla_X[i][j][k].assign_vector(
output[i][j][k].get_vector()) output[i][jj][kk].get_vector())
if self.debug_output: if self.debug_output:
@for_range(len(batch)) @for_range(len(batch))
@@ -1806,6 +2003,7 @@ class Optimizer:
self.report_loss = report_loss self.report_loss = report_loss
self.X_by_label = None self.X_by_label = None
self.print_update_average = False self.print_update_average = False
self.print_random_update = False
self.print_losses = False self.print_losses = False
self.print_loss_reduction = False self.print_loss_reduction = False
self.i_epoch = MemValue(0) self.i_epoch = MemValue(0)
@@ -1846,6 +2044,7 @@ class Optimizer:
def batch_for(self, layer, batch): def batch_for(self, layer, batch):
if layer in (self.layers[0], self.layers[-1]): if layer in (self.layers[0], self.layers[-1]):
assert not isinstance(layer, BatchNorm)
return batch return batch
else: else:
batch = regint.Array(len(batch)) batch = regint.Array(len(batch))
@@ -1876,6 +2075,21 @@ class Optimizer:
if i != len(self.layers) - 1 or run_last: if i != len(self.layers) - 1 or run_last:
layer.forward(batch=self.batch_for(layer, batch), layer.forward(batch=self.batch_for(layer, batch),
training=training) training=training)
if self.print_random_update:
print_ln('forward layer %s', layer)
l = min(100, layer.Y[i].total_size())
i = regint.get_random(64) % len(batch)
if l < 100:
j = 0
else:
j = regint.get_random(64) % \
(layer.Y[i].total_size() - l)
print_ln('forward layer %s at (%s, %s): %s', layer, i, j,
layer.Y[i].to_array().get_vector(j, l).reveal())
i = regint.get_random(64) % layer.Y[0].total_size()
print_ln('forward layer %s vertical at %s: %s', layer, i,
[layer.Y[j].to_array()[i].reveal()
for j in range(len(batch))])
if self.time_layers: if self.time_layers:
stop_timer(100 + i) stop_timer(100 + i)
break_point() break_point()
@@ -1979,7 +2193,11 @@ class Optimizer:
label * n) label * n)
self.forward(batch=batch, training=True) self.forward(batch=batch, training=True)
self.backward(batch=batch) self.backward(batch=batch)
if self.time_layers:
start_timer(1000)
self.update(i, batch=batch) self.update(i, batch=batch)
if self.time_layers:
stop_timer(1000)
loss_sum.iadd(self.layers[-1].l) loss_sum.iadd(self.layers[-1].l)
if self.print_loss_reduction: if self.print_loss_reduction:
before = self.layers[-1].average_loss(N) before = self.layers[-1].average_loss(N)
@@ -2070,6 +2288,8 @@ class Optimizer:
if 'nomom' in program.args: if 'nomom' in program.args:
self.momentum = 0 self.momentum = 0
self.print_losses = 'print_losses' in program.args self.print_losses = 'print_losses' in program.args
self.print_random_update = 'print_random_update' in program.args
Layer.print_random_update = self.print_random_update
self.time_layers = 'time_layers' in program.args self.time_layers = 'time_layers' in program.args
self.revealing_correctness = not 'no_acc' in program.args self.revealing_correctness = not 'no_acc' in program.args
self.layers[-1].compute_loss = not 'no_loss' in program.args self.layers[-1].compute_loss = not 'no_loss' in program.args
@@ -2099,6 +2319,16 @@ class Optimizer:
print_ln('loss %s', self.layers[-1].l.reveal()) print_ln('loss %s', self.layers[-1].l.reveal())
self.output_weights() self.output_weights()
return return
if 'bench10' in program.args or 'bench1' in program.args:
n = 1 if 'bench1' in program.args else 10
print('benchmarking %s iterations' % n)
@for_range(n)
def _(i):
batch = Array.create_from(regint.inc(batch_size))
self.forward(batch=batch, training=True)
self.backward(batch=batch)
self.update(0, batch=batch)
return
@for_range(n_runs) @for_range(n_runs)
def _(i): def _(i):
if not acc_first: if not acc_first:
@@ -2115,6 +2345,7 @@ class Optimizer:
cfix(self.n_correct, k=63, f=31) / n_trained, cfix(self.n_correct, k=63, f=31) / n_trained,
self.n_correct, n_trained) self.n_correct, n_trained)
if test_X and test_Y: if test_X and test_Y:
print('use test set')
n_test = len(test_Y) n_test = len(test_Y)
n_correct, loss = self.reveal_correctness(test_X, test_Y, n_correct, loss = self.reveal_correctness(test_X, test_Y,
acc_batch_size) acc_batch_size)
@@ -2211,7 +2442,8 @@ class Adam(Optimizer):
util.max, abs_g.get_vector()) util.max, abs_g.get_vector())
scale = MemValue(sfix._new(library.AppRcr( scale = MemValue(sfix._new(library.AppRcr(
max_g.v, max_g.k, max_g.f, simplex_flag=True))) max_g.v, max_g.k, max_g.f, simplex_flag=True)))
@multithread(self.n_threads, m.total_size()) @multithread(self.n_threads, m.total_size(),
max_size=get_program().budget)
def _(base, size): def _(base, size):
m_part = m.get_vector(base, size) m_part = m.get_vector(base, size)
v_part = v.get_vector(base, size) v_part = v.get_vector(base, size)
@@ -2333,20 +2565,33 @@ class SGD(Optimizer):
print_ln_if((x > limit) + (x < -limit), print_ln_if((x > limit) + (x < -limit),
'theta epoch=%s %s index=%s %s', 'theta epoch=%s %s index=%s %s',
i_epoch.read(), str(theta), i, x) i_epoch.read(), str(theta), i, x)
index = regint.get_random(64) % len(a) if self.print_random_update:
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index, print_ln('update')
aa[1][index], aa[0][index], aa[2][index]) l = min(100, nabla.total_size())
if l < 100:
index = 0
else:
index = regint.get_random(64) % (nabla.total_size() - l)
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta),
index, nabla.to_array().get_vector(index, l).reveal(),
delta_theta.to_array().get_vector(index, l).reveal(),
theta.to_array().get_vector(index, l).reveal())
self.gamma.imul(1 - 10 ** - 6) self.gamma.imul(1 - 10 ** - 6)
def apply_padding(input_shape, kernel_size, strides, padding): def apply_padding(input_shape, kernel_size, strides, padding):
if isinstance(padding, int):
input_shape = [x + 2 * padding for x in input_shape]
padding = 'valid'
if padding == 'valid': if padding == 'valid':
return (input_shape[0] - kernel_size[0] + 1) // strides[0], \ res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \
(input_shape[1] - kernel_size[1] + 1) // strides[1], (input_shape[1] - kernel_size[1] + 1) // strides[1],
assert min(res) > 0, (input_shape, kernel_size, strides, padding)
return res
elif padding == 'same': elif padding == 'same':
return (input_shape[1]) // strides[0], \ return (input_shape[0]) // strides[0], \
(input_shape[2]) // strides[1], (input_shape[1]) // strides[1],
else: else:
raise Exception('invalid padding: ' + padding) raise Exception('invalid padding: %s' % padding)
class keras: class keras:
class layers: class layers:
@@ -2354,7 +2599,7 @@ class keras:
Dense = lambda *args, **kwargs: ('dense', args, kwargs) Dense = lambda *args, **kwargs: ('dense', args, kwargs)
def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid',
activation=None): activation=None, input_shape=None):
return 'conv2d', {'filters': filters, 'kernel_size': kernel_size, return 'conv2d', {'filters': filters, 'kernel_size': kernel_size,
'strides': strides, 'padding': padding, 'strides': strides, 'padding': padding,
'activation': activation} 'activation': activation}
@@ -2369,6 +2614,13 @@ class keras:
raise Exception('rate needs to be a power of two') raise Exception('rate needs to be a power of two')
return 'dropout', rate return 'dropout', rate
def Activation(activation):
assert(activation == 'relu')
return activation,
def BatchNormalization():
return 'batchnorm',
class optimizers: class optimizers:
SGD = lambda *args, **kwargs: ('sgd', args, kwargs) SGD = lambda *args, **kwargs: ('sgd', args, kwargs)
Adam = lambda *args, **kwargs: ('adam', args, kwargs) Adam = lambda *args, **kwargs: ('adam', args, kwargs)
@@ -2383,12 +2635,25 @@ class keras:
def compile(self, optimizer): def compile(self, optimizer):
self.optimizer = optimizer self.optimizer = optimizer
def compile_by_args(self, program):
if 'adam' in program.args:
self.optimizer = 'adam', [], {}
elif 'amsgrad' in program.args:
self.optimizer = 'adam', [], {'amsgrad': True}
else:
self.optimizer = 'sgd', [], {}
@property @property
def trainable_variables(self): def trainable_variables(self):
if self.opt == None: if self.opt == None:
raise Exception('need to run build() or fit() first') raise Exception('need to run build() or fit() first')
return list(self.opt.thetas) return list(self.opt.thetas)
def summary(self):
sizes = [var.total_size() for var in self.trainable_variables]
print(sizes)
print('Trainable params:', sum(sizes))
def build(self, input_shape, batch_size=128): def build(self, input_shape, batch_size=128):
data_input_shape = input_shape data_input_shape = input_shape
if self.opt != None and \ if self.opt != None and \
@@ -2415,12 +2680,11 @@ class keras:
if i == len(self.layers) - 1: if i == len(self.layers) - 1:
if layer[2].get('activation', 'softmax') in \ if layer[2].get('activation', 'softmax') in \
('softmax', 'sigmoid'): ('softmax', 'sigmoid'):
del layer[2]['activation'] layer[2].pop('activation', None)
layers.append(Dense(N, n_units, layer[1][0], layers.append(Dense(N, n_units, layer[1][0],
**layer[2])) **layer[2]))
elif name == 'conv2d':
if len(layers) != 0:
input_shape = layers[-1].Y.sizes input_shape = layers[-1].Y.sizes
elif name == 'conv2d':
input_shape = list(input_shape) + \ input_shape = list(input_shape) + \
[1] * (4 - len(input_shape)) [1] * (4 - len(input_shape))
print (layer[1]) print (layer[1])
@@ -2437,9 +2701,13 @@ class keras:
output_shape = [batch_size] + list( output_shape = [batch_size] + list(
apply_padding(input_shape[1:3], kernel_size, apply_padding(input_shape[1:3], kernel_size,
strides, padding)) + [filters] strides, padding)) + [filters]
padding = padding.upper() if isinstance(padding, str) \
else padding
layers.append(FixConv2d(input_shape, weight_shape, layers.append(FixConv2d(input_shape, weight_shape,
(filters,), output_shape, (filters,), output_shape,
strides, padding.upper())) strides, padding))
input_shape = output_shape
print('conv output shape', output_shape)
elif name == 'maxpool': elif name == 'maxpool':
pool_size = layer[1]['pool_size'] pool_size = layer[1]['pool_size']
strides = layer[1]['strides'] strides = layer[1]['strides']
@@ -2450,16 +2718,23 @@ class keras:
strides = (strides, strides) strides = (strides, strides)
if strides == None: if strides == None:
strides = pool_size strides = pool_size
layers.append(MaxPool(layers[-1].Y.sizes, layers.append(MaxPool(input_shape,
[1] + list(strides) + [1], [1] + list(strides) + [1],
[1] + list(pool_size) + [1], [1] + list(pool_size) + [1],
padding.upper())) padding))
input_shape = layers[-1].Y.sizes
elif name == 'dropout': elif name == 'dropout':
layers.append(Dropout(batch_size, reduce( layers.append(Dropout(batch_size, reduce(
operator.mul, layers[-1].Y.sizes[1:]), operator.mul, layers[-1].Y.sizes[1:]),
alpha=layer[1])) alpha=layer[1]))
input_shape = layers[-1].Y.sizes
elif name == 'flatten': elif name == 'flatten':
pass pass
elif name == 'relu':
layers.append(Relu(layers[-1].Y.sizes))
elif name == 'batchnorm':
input_shape = layers[-1].Y.sizes
layers.append(BatchNorm(layers[-1].Y.sizes))
else: else:
raise Exception(layer[0] + ' not supported') raise Exception(layer[0] + ' not supported')
if layers[-1].d_out == 1: if layers[-1].d_out == 1:

View File

@@ -1493,6 +1493,7 @@ class PackedIndexStructure(object):
self.l[i] = [0] * self.elements_per_block self.l[i] = [0] * self.elements_per_block
time() time()
print_ln('packed ORAM init %s/%s', i, real_init_rounds) print_ln('packed ORAM init %s/%s', i, real_init_rounds)
print_ln('packed ORAM init done')
print('index initialized, size', size) print('index initialized, size', size)
def translate_index(self, index): def translate_index(self, index):
""" Bit slicing *index* according parameters. Output is tuple """ Bit slicing *index* according parameters. Output is tuple

View File

@@ -580,10 +580,19 @@ class Program(object):
@staticmethod @staticmethod
def read_tapes(schedule): def read_tapes(schedule):
m = re.search(r'([^/]*)\.mpc', schedule)
if m:
schedule = m.group(1)
if not os.path.exists(schedule): if not os.path.exists(schedule):
schedule = 'Programs/Schedules/%s.sch' % schedule schedule = 'Programs/Schedules/%s.sch' % schedule
try:
lines = open(schedule).readlines() lines = open(schedule).readlines()
except FileNotFoundError:
print('%s not found, have you compiled the program?' % schedule,
file=sys.stderr)
sys.exit(1)
for tapename in lines[2].split(' '): for tapename in lines[2].split(' '):
yield tapename.strip() yield tapename.strip()

View File

@@ -1675,6 +1675,13 @@ class localint(Tape._no_truth):
__ne__ = lambda self, other: localint(self._v != other) __ne__ = lambda self, other: localint(self._v != other)
class personal(Tape._no_truth): class personal(Tape._no_truth):
""" Value known to one player. Supports operations with public
values and personal values known to the same player. Can be used
with :py:func:`~Compiler.library.print_ln_to`.
:param player: player (int)
:param value: cleartext value (cint, cfix, cfloat) or array thereof
"""
def __init__(self, player, value): def __init__(self, player, value):
assert value is not NotImplemented assert value is not NotImplemented
assert not isinstance(value, _secret) assert not isinstance(value, _secret)
@@ -1685,8 +1692,24 @@ class personal(Tape._no_truth):
self._v = value self._v = value
def binary_output(self): def binary_output(self):
""" Write binary output to
``Player-Data/Binary-Output-P<playerno>-<threadno>`` if
supported by underlying type. Player must be known at compile time."""
self._v.binary_output(self.player) self._v.binary_output(self.player)
def reveal_to(self, player):
""" Pass personal value to another player. """
if isinstance(self._v, Array):
source = self._v[:]
else:
source = self._v
source = cint.conv(source)
res = cint(size=source.size)
sendpersonal(source.size, player, res, self.player, source)
if isinstance(self._v, Array):
res = Array.create_from(res)
return personal(player, res)
def bit_decompose(self, length): def bit_decompose(self, length):
return [personal(self.player, x) for x in self._v.bit_decompose(length)] return [personal(self.player, x) for x in self._v.bit_decompose(length)]
@@ -1858,8 +1881,13 @@ class _secret(_register, _secret_structure):
@vectorized_classmethod @vectorized_classmethod
@set_instruction_type @set_instruction_type
def get_random_input_mask_for(cls, player): def get_random_input_mask_for(cls, player):
res = cls() """ Secret random input mask according to security model.
inputmask(res, player)
:return: mask (sint), mask (personal cint)
:param size: vector size (int, default 1)
"""
res = cls(), personal(player, cls.clear_type())
inputmask(res[0], res[1]._v, player)
return res return res
@classmethod @classmethod
@@ -2071,15 +2099,13 @@ class _secret(_register, _secret_structure):
@set_instruction_type @set_instruction_type
def reveal_to(self, player): def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`. """ Reveal secret value to :py:obj:`player`.
Result written to ``Player-Data/Private-Output-P<player>``
:param player: int :param player: int
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to` :returns: :py:class:`personal`
""" """
masked = self.__class__() mask = self.get_random_input_mask_for(player)
res = personal(player, self.clear_type()) masked = self + mask[0]
startprivateoutput(masked, self, player) res = personal(player, masked.reveal() - mask[1])
stopprivateoutput(res._v, masked.reveal(), player)
return res return res
@@ -2633,21 +2659,20 @@ class sint(_secret, _int):
@vectorize @vectorize
def reveal_to(self, player): def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`. """ Reveal secret value to :py:obj:`player`.
Result potentially written to
``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint): :param player: public integer (int/regint/cint)
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to` :returns: :py:class:`personal`
""" """
if not util.is_constant(player) or self.size > 1: if not util.is_constant(player):
secret_mask = sint() secret_mask = sint()
player_mask = cint() player_mask = cint()
inputmaskreg(secret_mask, player_mask, regint.conv(player)) inputmaskreg(secret_mask, player_mask, regint.conv(player))
return personal(player, return personal(player,
(self + secret_mask).reveal() - player_mask) (self + secret_mask).reveal() - player_mask)
else: else:
return super(sint, self).reveal_to(player) res = personal(player, self.clear_type())
privateoutput(self.size, player, res._v, self)
return res
def private_division(self, divisor, active=True, dividend_length=None, def private_division(self, divisor, active=True, dividend_length=None,
divisor_length=None): divisor_length=None):
@@ -4366,12 +4391,9 @@ class sfix(_fix):
def reveal_to(self, player): def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`. """ Reveal secret value to :py:obj:`player`.
Raw representation possibly written to
``Player-Data/Private-Output-P<player>``, but not if
:py:obj:`player` is a :py:class:`regint`.
:param player: public integer (int/regint/cint) :param player: public integer (int/regint/cint)
:returns: value to be used with :py:func:`~Compiler.library.print_ln_to` :returns: :py:class:`personal`
""" """
return personal(player, cfix._new(self.v.reveal_to(player)._v, return personal(player, cfix._new(self.v.reveal_to(player)._v,
self.k, self.f)) self.k, self.f))
@@ -5221,6 +5243,9 @@ class Array(_vectorizable):
return self.assign(value, addresses) return self.assign(value, addresses)
self._store(value, self.get_address(index)) self._store(value, self.get_address(index))
def to_array(self):
return self
def get_sub(self, start, stop=None): def get_sub(self, start, stop=None):
if stop is None: if stop is None:
stop = start stop = start
@@ -5471,6 +5496,10 @@ class Array(_vectorizable):
""" Insecure shuffle in place. """ """ Insecure shuffle in place. """
self.assign_vector(self.get(regint.inc(len(self)).shuffle())) self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
def randomize(self, *args):
""" Randomize according to data type. """
self.assign_vector(self.value_type.get_random(*args, size=len(self)))
def reveal(self): def reveal(self):
""" Reveal the whole array. """ Reveal the whole array.
@@ -5596,6 +5625,9 @@ class SubMultiArray(_vectorizable):
def __iter__(self): def __iter__(self):
return (self[i] for i in range(len(self))) return (self[i] for i in range(len(self)))
def to_array(self):
return Array(self.total_size(), self.value_type, address=self.address)
def assign_all(self, value): def assign_all(self, value):
""" Assign the same value to all entries. """ Assign the same value to all entries.
@@ -5958,6 +5990,7 @@ class SubMultiArray(_vectorizable):
""" """
assert len(self.sizes) == 2 assert len(self.sizes) == 2
assert len(other.sizes) == 2 assert len(other.sizes) == 2
assert other.address != None
if indices is None: if indices is None:
assert self.sizes[1] == other.sizes[1] assert self.sizes[1] == other.sizes[1]
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
@@ -6145,6 +6178,16 @@ class SubMultiArray(_vectorizable):
n = self.sizes[0] n = self.sizes[0]
return self.array.get(regint.inc(n, 0, n + 1)) return self.array.get(regint.inc(n, 0, n + 1))
def randomize(self, *args):
""" Randomize according to data type. """
if self.total_size() < program.options.budget:
self.assign_vector(
self.value_type.get_random(*args, size=self.total_size()))
else:
@library.for_range(self.sizes[0])
def _(i):
self[i].randomize(*args)
def reveal_list(self): def reveal_list(self):
""" Reveal as list. """ """ Reveal as list. """
return list(self.get_vector().reveal()) return list(self.get_vector().reveal())
@@ -6251,6 +6294,22 @@ class Matrix(MultiArray):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
address=address) address=address)
def get_column(self, index):
""" Get column as vector.
:param index: regint/cint/int
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.sizes[1])
return self.value_type.load_mem(addresses)
def get_column_by_row_indices(self, rows, column):
assert self.value_type.n_elements() == 1
addresses = rows * self.sizes[1] + \
regint.inc(len(rows), self.address + column, 0)
return self.value_type.load_mem(addresses)
def set_column(self, index, vector): def set_column(self, index, vector):
""" Change column. """ Change column.

View File

@@ -47,11 +47,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
} }
void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
{
Rq_Element a(*this);
a.randomize(G);
partial_key_gen(sk, a, G, noise_boost);
}
void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
int noise_boost)
{ {
FHE_PK& PK = *this; FHE_PK& PK = *this;
// Generate the main public key a0 = a;
PK.a0.randomize(G);
// b0=a0*s+p*e0 // b0=a0*s+p*e0
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation); Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
@@ -77,9 +84,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
mul(es,es,PK.pr); mul(es,es,PK.pr);
add(PK.Sw_b,PK.Sw_b,es); add(PK.Sw_b,PK.Sw_b,es);
// Lowering level as we only decrypt at level 0
sk.lower_level();
// bs=bs-p1*s^2 // bs=bs-p1*s^2
Rq_Element s2; Rq_Element s2;
mul(s2,sk,sk); // Mult at level 0 mul(s2,sk,sk); // Mult at level 0
@@ -334,7 +338,7 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
template<class FD> template<class FD>
void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) void FHE_SK::check(const FHE_PK& pk, const FD& FieldD)
{ {
check(*params, pk, pr); check(*params, pk, FieldD.get_prime());
pk.check_noise(*this); pk.check_noise(*this);
if (decrypt(pk.encrypt(Plaintext_<FD>(FieldD)), FieldD) != if (decrypt(pk.encrypt(Plaintext_<FD>(FieldD)), FieldD) !=
Plaintext_<FD>(FieldD)) Plaintext_<FD>(FieldD))

View File

@@ -150,6 +150,8 @@ class FHE_PK
Rq_Element sample_secret_key(PRNG& G); Rq_Element sample_secret_key(PRNG& G);
void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1); void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1);
void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
int noise_boost = 1);
void check_noise(const FHE_SK& sk) const; void check_noise(const FHE_SK& sk) const;
void check_noise(const Rq_Element& x, bool check_modulo = false) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const;

View File

@@ -3,6 +3,11 @@
#include "FHE/Ring_Element.h" #include "FHE/Ring_Element.h"
#include "Tools/Exceptions.h" #include "Tools/Exceptions.h"
FHE_Params::FHE_Params(int n_mults) :
FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1)
{
}
void FHE_Params::set(const Ring& R, void FHE_Params::set(const Ring& R,
const vector<bigint>& primes) const vector<bigint>& primes)
{ {
@@ -24,6 +29,14 @@ void FHE_Params::set_sec(int sec)
throw runtime_error("distributed decryption bound is zero"); throw runtime_error("distributed decryption bound is zero");
} }
void FHE_Params::set_matrix_dim(int matrix_dim)
{
assert(matrix_dim > 0);
if (FFTData[0].get_prime() != 0)
throw runtime_error("cannot change matrix dimension after parameter generation");
this->matrix_dim = matrix_dim;
}
bigint FHE_Params::Q() const bigint FHE_Params::Q() const
{ {
bigint res = FFTData[0].get_prime(); bigint res = FFTData[0].get_prime();
@@ -40,6 +53,7 @@ void FHE_Params::pack(octetStream& o) const
Chi.pack(o); Chi.pack(o);
Bval.pack(o); Bval.pack(o);
o.store(sec_p); o.store(sec_p);
o.store(matrix_dim);
} }
void FHE_Params::unpack(octetStream& o) void FHE_Params::unpack(octetStream& o)
@@ -52,6 +66,7 @@ void FHE_Params::unpack(octetStream& o)
Chi.unpack(o); Chi.unpack(o);
Bval.unpack(o); Bval.unpack(o);
o.get(sec_p); o.get(sec_p);
o.get(matrix_dim);
} }
bool FHE_Params::operator!=(const FHE_Params& other) const bool FHE_Params::operator!=(const FHE_Params& other) const

View File

@@ -26,10 +26,11 @@ class FHE_Params
// Data for distributed decryption // Data for distributed decryption
int sec_p; int sec_p;
bigint Bval; bigint Bval;
int matrix_dim;
public: public:
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {} FHE_Params(int n_mults = 1);
int n_mults() const { return FFTData.size() - 1; } int n_mults() const { return FFTData.size() - 1; }
@@ -37,6 +38,9 @@ class FHE_Params
void set(const vector<bigint>& primes); void set(const vector<bigint>& primes);
void set_sec(int sec); void set_sec(int sec);
void set_matrix_dim(int matrix_dim);
int get_matrix_dim() const { return matrix_dim; }
const vector<FFT_Data>& FFTD() const { return FFTData; } const vector<FFT_Data>& FFTD() const { return FFTData; }
const bigint& p0() const { return FFTData[0].get_prime(); } const bigint& p0() const { return FFTData[0].get_prime(); }

View File

@@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2)
template <> template <>
int generate_semi_setup(int plaintext_length, int sec, int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, FFT_Data& FTD, bool round_up) FHE_Params& params, FFT_Data& FTD, bool round_up, int n)
{ {
int m = 1024; int m = 1024;
int lgp = plaintext_length; int lgp = plaintext_length;
@@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec,
while (true) while (true)
{ {
tmp_params = params; tmp_params = params;
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec, SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params); numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
bigint p1 = 2 * p * m, p0 = p; bigint p1 = 2 * p * m, p0 = p;
while (nb.min_p0(params.n_mults() > 0, p1) > p0) while (nb.min_p0(params.n_mults() > 0, p1) > p0)
@@ -89,14 +89,14 @@ int generate_semi_setup(int plaintext_length, int sec,
template <> template <>
int generate_semi_setup(int plaintext_length, int sec, int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, P2Data& P2D, bool round_up) FHE_Params& params, P2Data& P2D, bool round_up, int n)
{ {
if (params.n_mults() > 0) if (params.n_mults() > 0)
throw runtime_error("only implemented for 0-level BGV"); throw runtime_error("only implemented for 0-level BGV");
gf2n_short::init_field(plaintext_length); gf2n_short::init_field(plaintext_length);
int m; int m;
char_2_dimension(m, plaintext_length); char_2_dimension(m, plaintext_length);
SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec, SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec,
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params);
int lgp0 = numBits(nb.min_p0(false, 0)); int lgp0 = numBits(nb.min_p0(false, 0));
int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up); int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up);
@@ -590,6 +590,9 @@ void char_2_dimension(int& m, int& lg2)
m=5797; m=5797;
lg2=40; lg2=40;
break; break;
case 16:
m = 13107;
break;
default: default:
throw runtime_error("field size not supported"); throw runtime_error("field size not supported");
break; break;

View File

@@ -52,7 +52,7 @@ void generate_setup(int nparties, int lgp, int lg2,
// semi-homomorphic, includes slack // semi-homomorphic, includes slack
template <class FD> template <class FD>
int generate_semi_setup(int plaintext_length, int sec, int generate_semi_setup(int plaintext_length, int sec,
FHE_Params& params, FD& FieldD, bool round_up); FHE_Params& params, FD& FieldD, bool round_up, int n = 1);
// field-independent semi-homomorphic setup // field-independent semi-homomorphic setup
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,

View File

@@ -39,6 +39,7 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.));
B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_clean = max(B_clean_not_top_gear, B_clean_top_gear);
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
int matrix_dim = params.get_matrix_dim();
#ifdef NOISY #ifdef NOISY
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
cout << "V_s: " << V_s << endl; cout << "V_s: " << V_s << endl;
@@ -48,9 +49,11 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
cout << "log(slack): " << slack << endl; cout << "log(slack): " << slack << endl;
cout << "B_clean: " << B_clean << endl; cout << "B_clean: " << B_clean << endl;
cout << "B_scale: " << B_scale << endl; cout << "B_scale: " << B_scale << endl;
cout << "matrix dimension: " << matrix_dim << endl;
#endif #endif
drown = 1 + n * (bigint(1) << sec); assert(matrix_dim > 0);
drown = 1 + matrix_dim * n * (bigint(1) << sec);
} }
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)

View File

@@ -50,6 +50,7 @@ void Ring_Element::prepare_push()
void Ring_Element::allocate() void Ring_Element::allocate()
{ {
assert(FFTD);
element.resize(FFTD->phi_m()); element.resize(FFTD->phi_m());
} }

View File

@@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
} }
} }
void Rq_Element::add(octetStream& os)
{
Rq_Element tmp(*this);
tmp.unpack(os);
*this += tmp;
}
void Rq_Element::randomize(PRNG& G,int l) void Rq_Element::randomize(PRNG& G,int l)
{ {
set_level(l); set_level(l);
@@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p)
// Now add delta back onto a0 // Now add delta back onto a0
Rq_Element bb(b0,b1); Rq_Element bb(b0,b1);
add(*this,*this,bb); ::add(*this,*this,bb);
// Now divide by p1 mod p0 // Now divide by p1 mod p0
modp p1_inv,pp; modp p1_inv,pp;

View File

@@ -93,12 +93,14 @@ protected:
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b); friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
void add(octetStream& os);
template<class S> template<class S>
Rq_Element& operator+=(const vector<S>& other); Rq_Element& operator+=(const vector<S>& other);
Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; } Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; }
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; } Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; }
Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; } Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; }
template <class T> template <class T>
Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; } Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; }
@@ -176,7 +178,7 @@ Rq_Element& Rq_Element::operator+=(const vector<S>& other)
{ {
Rq_Element tmp = *this; Rq_Element tmp = *this;
tmp.from(Iterator<S>(other), lev); tmp.from(Iterator<S>(other), lev);
add(*this, *this, tmp); ::add(*this, *this, tmp);
return *this; return *this;
} }

View File

@@ -203,7 +203,7 @@ template<class FD>
void PartSetup<FD>::secure_init(Player& P, MachineBase& machine, void PartSetup<FD>::secure_init(Player& P, MachineBase& machine,
int plaintext_length, int sec) int plaintext_length, int sec)
{ {
::secure_init(*this, P, machine, plaintext_length, sec); ::secure_init(*this, P, machine, plaintext_length, sec, params);
} }
template<class FD> template<class FD>

View File

@@ -130,6 +130,13 @@ void Multiplier<FD>::report_size(ReportType type, MemoryUsage& res)
res += memory_usage; res += memory_usage;
} }
template<class FD>
const vector<Ciphertext>& Multiplier<FD>::get_multiplicands(
const vector<vector<Ciphertext> >& others_ct, const FHE_PK&)
{
return others_ct[P.get_full_player().get_player(-P.get_offset())];
}
template class Multiplier<FFT_Data>; template class Multiplier<FFT_Data>;
template class Multiplier<P2Data>; template class Multiplier<P2Data>;

View File

@@ -55,6 +55,9 @@ public:
size_t report_size(ReportType type); size_t report_size(ReportType type);
void report_size(ReportType type, MemoryUsage& res); void report_size(ReportType type, MemoryUsage& res);
size_t report_volatile() { return volatile_capacity; } size_t report_volatile() { return volatile_capacity; }
const vector<Ciphertext>& get_multiplicands(
const vector<vector<Ciphertext>>& others_ct, const FHE_PK&);
}; };
#endif /* FHEOFFLINE_MULTIPLIER_H_ */ #endif /* FHEOFFLINE_MULTIPLIER_H_ */

View File

@@ -9,6 +9,7 @@
#include "Math/Setup.h" #include "Math/Setup.h"
#include "FHEOffline/Proof.h" #include "FHEOffline/Proof.h"
#include "FHEOffline/PairwiseMachine.h" #include "FHEOffline/PairwiseMachine.h"
#include "FHEOffline/TemiSetup.h"
#include "Tools/Commit.h" #include "Tools/Commit.h"
#include "Tools/Bundle.h" #include "Tools/Bundle.h"
#include "Processor/OnlineOptions.h" #include "Processor/OnlineOptions.h"
@@ -53,7 +54,7 @@ void PairwiseSetup<FD>::init(const Player& P, int sec, int plaintext_length,
template <class FD> template <class FD>
void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec) void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec)
{ {
::secure_init(*this, P, machine, plaintext_length, sec); ::secure_init(*this, P, machine, plaintext_length, sec, params);
alpha = FieldD; alpha = FieldD;
machine.sk = FHE_SK(params, FieldD.get_prime()); machine.sk = FHE_SK(params, FieldD.get_prime());
for (auto& pk : machine.other_pks) for (auto& pk : machine.other_pks)
@@ -62,13 +63,14 @@ void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int pla
template <class T, class U> template <class T, class U>
void secure_init(T& setup, Player& P, U& machine, void secure_init(T& setup, Player& P, U& machine,
int plaintext_length, int sec) int plaintext_length, int sec, FHE_Params& params)
{ {
machine.sec = sec; machine.sec = sec;
sec = max(sec, 40); sec = max(sec, 40);
machine.drown_sec = sec; machine.drown_sec = sec;
string filename = PREP_DIR + T::name() + "-" string filename = PREP_DIR + T::name() + "-"
+ to_string(plaintext_length) + "-" + to_string(sec) + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-"
+ to_string(params.get_matrix_dim()) + "-"
+ OnlineOptions::singleton.prime.get_str() + "-" + OnlineOptions::singleton.prime.get_str() + "-"
+ to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(CowGearOptions::singleton.top_gear()) + "-P"
+ to_string(P.my_num()) + "-" + to_string(P.num_players()); + to_string(P.my_num()) + "-" + to_string(P.num_players());
@@ -85,7 +87,6 @@ void secure_init(T& setup, Player& P, U& machine,
{ {
cout << "Finding parameters for security " << sec << " and field size ~2^" cout << "Finding parameters for security " << sec << " and field size ~2^"
<< plaintext_length << endl; << plaintext_length << endl;
setup.params = setup.params.n_mults();
setup.generate(P, machine, plaintext_length, sec); setup.generate(P, machine, plaintext_length, sec);
setup.check(P, machine); setup.check(P, machine);
octetStream os; octetStream os;
@@ -208,5 +209,8 @@ void PairwiseSetup<FD>::set_alphai(T alphai)
template class PairwiseSetup<FFT_Data>; template class PairwiseSetup<FFT_Data>;
template class PairwiseSetup<P2Data>; template class PairwiseSetup<P2Data>;
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int); template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int); template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(TemiSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
template void secure_init(TemiSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);

View File

@@ -15,7 +15,7 @@ class MachineBase;
template <class T, class U> template <class T, class U>
void secure_init(T& setup, Player& P, U& machine, void secure_init(T& setup, Player& P, U& machine,
int plaintext_length, int sec); int plaintext_length, int sec, FHE_Params& params);
template <class FD> template <class FD>
class PairwiseSetup class PairwiseSetup

View File

@@ -18,7 +18,12 @@ void SimpleDistDecrypt<FD>::reshare(Plaintext<typename FD::T, FD, typename FD::S
EncCommitBase<typename FD::T, FD, typename FD::S>& EC) EncCommitBase<typename FD::T, FD, typename FD::S>& EC)
{ {
(void)EC; (void)EC;
m = reshare(cm);
}
template <class FD>
Plaintext_<FD> SimpleDistDecrypt<FD>::reshare(const Ciphertext& cm)
{
PRNG G; PRNG G;
G.ReSeed(); G.ReSeed();
this->f.randomize(G, Full); this->f.randomize(G, Full);
@@ -27,10 +32,13 @@ void SimpleDistDecrypt<FD>::reshare(Plaintext<typename FD::T, FD, typename FD::S
this->run(cm); this->run(cm);
// Step 4 // Step 4
Plaintext_<FD> m(this->f.get_field());
if (this->P.my_num()==0) if (this->P.my_num()==0)
{ sub(m,this->mf,this->f); } { sub(m,this->mf,this->f); }
else else
{ m=this->f; m.negate(); } { m=this->f; m.negate(); }
return m;
} }

View File

@@ -20,6 +20,7 @@ public:
void reshare(Plaintext<typename FD::T, FD, typename FD::S>& m, void reshare(Plaintext<typename FD::T, FD, typename FD::S>& m,
const Ciphertext& cm, const Ciphertext& cm,
EncCommitBase<typename FD::T, FD, typename FD::S>& EC); EncCommitBase<typename FD::T, FD, typename FD::S>& EC);
Plaintext_<FD> reshare(const Ciphertext& cm);
}; };
#endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */ #endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */

59
FHEOffline/TemiSetup.cpp Normal file
View File

@@ -0,0 +1,59 @@
/*
* TemiSetup.cpp
*
*/
#include "TemiSetup.h"
#include "PairwiseSetup.h"
#include "FHE/NTL-Subs.h"
#include "Protocols/HemiOptions.h"
template<class FD>
TemiSetup<FD>::TemiSetup()
{
this->params = FHE_Params(0);
this->pk = {this->params, 0};
this->sk = {this->params, 0};
this->calpha = this->params;
this->params.set_matrix_dim(
HemiOptions::singleton.plain_matmul ?
1 : OnlineOptions::singleton.batch_size);
}
template<class FD>
void TemiSetup<FD>::secure_init(Player& P, int plaintext_length)
{
MachineBase machine;
::secure_init(*this, P, machine, plaintext_length, 0, this->params);
}
template<class FD>
void TemiSetup<FD>::generate(Player& P, MachineBase&,
int plaintext_length, int sec)
{
generate_semi_setup(plaintext_length, sec, this->params, this->FieldD,
false, P.num_players());
this->sk = {this->params, this->FieldD.get_prime()};
this->pk = {this->params, this->FieldD.get_prime()};
}
template<class FD>
void TemiSetup<FD>::key_and_mac_generation(Player& P, MachineBase&, int,
true_type)
{
Rq_Element a(this->params);
GlobalPRNG GG(P);
a.randomize(GG);
SeededPRNG G;
auto sk = this->pk.sample_secret_key(G);
this->sk.assign(sk);
this->pk.partial_key_gen(sk, a, G);
TreeSum<Rq_Element> ts;
vector<Rq_Element> pks;
pks.push_back(this->pk.b());
ts.run(pks, P);
this->pk.assign(this->pk.a(), pks[0]);
}
template class TemiSetup<FFT_Data>;
template class TemiSetup<P2Data>;

34
FHEOffline/TemiSetup.h Normal file
View File

@@ -0,0 +1,34 @@
/*
* TemiSetup.h
*
*/
#ifndef FHEOFFLINE_TEMISETUP_H_
#define FHEOFFLINE_TEMISETUP_H_
#include "FHE/FHE_Keys.h"
#include "FHEOffline/SimpleMachine.h"
template<class FD>
class TemiSetup : public PartSetup<FD>
{
public:
static string name()
{
return "TemiParams";
}
static string protocol_name(int)
{
return "Temi";
}
TemiSetup();
void secure_init(Player& P, int plaintext_length);
void generate(Player& P, MachineBase&, int plaintext_length, int sec);
void key_and_mac_generation(Player& P, MachineBase&, int, true_type);
};
#endif /* FHEOFFLINE_TEMISETUP_H_ */

View File

@@ -47,11 +47,11 @@ inline void Memory<T>::check_index(Integer index) const
ss << T::type_string() << " memory overflow: " << i << "/" << vector<T>::size(); ss << T::type_string() << " memory overflow: " << i << "/" << vector<T>::size();
throw Processor_Error(ss.str()); throw Processor_Error(ss.str());
} }
#endif
#ifdef DEBUG_MEMORY #ifdef DEBUG_MEMORY
cout << typeid(T).name() << " at " << this << " index " << i << ": " cout << typeid(T).name() << " at " << this << " index " << i << ": "
<< vector<T>::operator[](i) << endl; << vector<T>::operator[](i) << endl;
#endif #endif
#endif
} }
template <class T> template <class T>

View File

@@ -122,6 +122,7 @@ public:
static const bool dishonest_majority = false; static const bool dishonest_majority = false;
static const bool variable_players = false; static const bool variable_players = false;
static const bool needs_ot = false; static const bool needs_ot = false;
static const bool has_mac = false;
static string type_string() { return "replicated secret"; } static string type_string() { return "replicated secret"; }
static string phase_name() { return "Replicated computation"; } static string phase_name() { return "Replicated computation"; }

View File

@@ -49,6 +49,7 @@ public:
static const bool dishonest_majority = T::dishonest_majority; static const bool dishonest_majority = T::dishonest_majority;
static const bool variable_players = T::variable_players; static const bool variable_players = T::variable_players;
static const bool needs_ot = T::needs_ot; static const bool needs_ot = T::needs_ot;
static const bool has_mac = T::has_mac;
static const bool expensive_triples = false; static const bool expensive_triples = false;
static const int default_length = 64; static const int default_length = 64;

View File

@@ -55,7 +55,7 @@
X(BITDECC, PROC.bitdecc(EXTRA, C0)) \ X(BITDECC, PROC.bitdecc(EXTRA, C0)) \
X(SHRCBI, C0 = PC1 >> IMM) \ X(SHRCBI, C0 = PC1 >> IMM) \
X(SHLCBI, C0 = PC1 << IMM) \ X(SHLCBI, C0 = PC1 << IMM) \
X(LDBITS, S0.load_clear(REG1, IMM)) \ X(LDBITS, S0.load_clear(REG1, int(IMM))) \
X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \ X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \
X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \ X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \
X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \ X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \

View File

@@ -23,6 +23,7 @@
#include "Protocols/Shamir.hpp" #include "Protocols/Shamir.hpp"
#include "Protocols/ShamirMC.hpp" #include "Protocols/ShamirMC.hpp"
#include "Protocols/MaliciousShamirMC.hpp" #include "Protocols/MaliciousShamirMC.hpp"
#include "Protocols/MaliciousShamirPO.hpp"
#include "Protocols/MAC_Check_Base.hpp" #include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp" #include "Protocols/Beaver.hpp"
#include "Protocols/Spdz2kPrep.hpp" #include "Protocols/Spdz2kPrep.hpp"

37
Machines/temi-party.cpp Normal file
View File

@@ -0,0 +1,37 @@
/*
* temi-party.cpp
*
*/
#include "Protocols/TemiShare.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "FHE/P2Data.h"
#include "Tools/ezOptionParser.h"
#include "GC/SemiSecret.h"
#include "GC/SemiPrep.h"
#include "Processor/FieldMachine.hpp"
#include "Protocols/TemiPrep.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
#include "Protocols/SemiPrep.hpp"
#include "Protocols/SemiInput.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/Hemi.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/SemiHonestRepPrep.h"
#include "Math/gfp.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
HemiOptions::singleton = {opt, argc, argv};
DishonestMajorityFieldMachine<TemiShare, TemiShare, gf2n_short>(argc, argv,
opt);
}

View File

@@ -61,7 +61,7 @@ arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy
binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr
all: overdrive she-offline all: overdrive she-offline
arithmetic: hemi-party.x soho-party.x gear arithmetic: semi-he gear
-include $(DEPS) -include $(DEPS)
include $(wildcard *.d static/*.d) include $(wildcard *.d static/*.d)
@@ -87,6 +87,7 @@ she-offline: Check-Offline.x spdz2-offline.x
overdrive: simple-offline.x pairwise-offline.x cnc-offline.x gear overdrive: simple-offline.x pairwise-offline.x cnc-offline.x gear
gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x
semi-he: hemi-party.x soho-party.x temi-party.x
rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x
@@ -210,6 +211,7 @@ static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
@@ -217,6 +219,7 @@ lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/Lo
highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
atlas-party.x: GC/AtlasSecret.o atlas-party.x: GC/AtlasSecret.o
static/hemi-party.x: $(FHEOBJS) static/hemi-party.x: $(FHEOBJS)
static/temi-party.x: $(FHEOBJS)
static/soho-party.x: $(FHEOBJS) static/soho-party.x: $(FHEOBJS)
static/cowgear-party.x: $(FHEOBJS) static/cowgear-party.x: $(FHEOBJS)
static/chaigear-party.x: $(FHEOBJS) static/chaigear-party.x: $(FHEOBJS)

View File

@@ -14,11 +14,6 @@ using namespace std;
#include "Tools/random.h" #include "Tools/random.h"
#include "field_types.h" #include "field_types.h"
template<class T> class ReplicatedMC;
template<class T> class ReplicatedInput;
template<class T> class ReplicatedPrivateOutput;
template<class T> class Replicated;
template <class T, int L> template <class T, int L>
class FixedVec class FixedVec
{ {

View File

@@ -233,7 +233,7 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
if (mpn_cmp(ans+T,prA,T+1)>=0) if (mpn_cmp(ans+T,prA,T+1)>=0)
{ mpn_sub_fixed_n<T>(z,ans+T,prA); } { mpn_sub_fixed_n<T>(z,ans+T,prA); }
else else
{ inline_mpn_copyi(z,ans+T,T); } { inline_mpn_copyi<T>(z,ans+T); }
#else #else
Mont_Mult(z, x, y, t); Mont_Mult(z, x, y, t);
#endif #endif

View File

@@ -18,15 +18,21 @@ bool gf2n_<U>::useC;
word gf2n_short_table[256][256]; word gf2n_short_table[256][256];
#define num_2_fields 6 #define num_2_fields 7
/* Require /* Require
* 2*(n-1)-64+t1<64 * 2*(n-1)-64+t1<64
*/ */
int fields_2[num_2_fields][4] = { int fields_2[num_2_fields][4] =
{4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10},{63,1,0,0},{128,7,2,1}, {
}; { 4, 1, 0, 0 },
{ 8, 4, 3, 1 },
{ 16, 5, 3, 1 },
{ 28, 1, 0, 0 },
{ 40, 20, 15, 10 },
{ 63, 1, 0, 0 },
{ 128, 7, 2, 1 },
};
template<class U> template<class U>
void gf2n_<U>::init_tables() void gf2n_<U>::init_tables()

View File

@@ -24,6 +24,12 @@ inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t si
avx_memcpy(dest, src, size * sizeof(mp_limb_t)); avx_memcpy(dest, src, size * sizeof(mp_limb_t));
} }
template<int N>
inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src)
{
avx_memcpy<N * sizeof(mp_limb_t)>(dest, src);
}
inline void debug_print(const char* name, const mp_limb_t* x, int n) inline void debug_print(const char* name, const mp_limb_t* x, int n)
{ {
(void)name, (void)x, (void)n; (void)name, (void)x, (void)n;

View File

@@ -542,6 +542,7 @@ public:
int other_player_num() const { return P.get_player(offset); } int other_player_num() const { return P.get_player(offset); }
int num_players() const { return 2; } int num_players() const { return 2; }
int get_offset() const { return offset; } int get_offset() const { return offset; }
Player& get_full_player() const { return P; }
void send(octetStream& o) const { P.send_to(P.get_player(offset), o); } void send(octetStream& o) const { P.send_to(P.get_player(offset), o); }
void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o); } void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o); }

View File

@@ -206,6 +206,18 @@ void BaseOT::exec_base(bool new_receiver_inputs)
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]); receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
} }
} }
#ifdef BASE_OT_DEBUG
for (j = 0; j < 4; j++)
for (k = 0; k < AES_BLK_SIZE; k++)
{
printf("%4d-th receiver key:", i+j);
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
printf("\n");
}
printf("\n");
#endif
} }
} }
@@ -244,12 +256,6 @@ void BaseOT::exec_base(bool new_receiver_inputs)
for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]); for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]);
printf("\n"); printf("\n");
} }
if (ot_role & RECEIVER)
{
printf("%4d-th receiver key:", i+j);
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
printf("\n");
}
} }
printf("\n"); printf("\n");

View File

@@ -25,7 +25,7 @@ void Binary_File_IO::write_to_file(const string filename,
if (start_pos != -1) if (start_pos != -1)
{ {
long write_pos = start_pos * T::size(); long write_pos = file_signature<T>().get_total_length() + start_pos * T::size();
// fill with zeros if needed // fill with zeros if needed
for (long i = outf.tellp(); i < write_pos; i++) for (long i = outf.tellp(); i < write_pos; i++)
outf.put(0); outf.put(0);
@@ -50,10 +50,13 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
inf.open(filename, ios::in | ios::binary); inf.open(filename, ios::in | ios::binary);
if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); } if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); }
check_file_signature<T>(inf, filename).get_length();
auto data_start = inf.tellg();
int size_in_bytes = T::size() * buffer.size(); int size_in_bytes = T::size() * buffer.size();
int n_read = 0; int n_read = 0;
char read_buffer[size_in_bytes]; char read_buffer[size_in_bytes];
inf.seekg(start_posn * T::size()); inf.seekg(start_posn * T::size(), iostream::cur);
do do
{ {
inf.read(read_buffer + n_read, size_in_bytes - n_read); inf.read(read_buffer + n_read, size_in_bytes - n_read);
@@ -62,7 +65,9 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
if (inf.eof()) if (inf.eof())
{ {
stringstream ss; stringstream ss;
ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes)."; ss << "Got to EOF when reading from disk (expecting " << size_in_bytes
<< " bytes from " << (long(data_start) + start_posn * T::size())
<< ").";
throw file_error(ss.str()); throw file_error(ss.str());
} }
if (inf.fail()) if (inf.fail())
@@ -74,7 +79,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
} }
while (n_read < size_in_bytes); while (n_read < size_in_bytes);
end_posn = inf.tellg() / T::size(); end_posn = (inf.tellg() - data_start) / T::size();
assert (end_posn == start_posn + int(buffer.size())); assert (end_posn == start_posn + int(buffer.size()));
//Check if at end of file by getting 1 more char. //Check if at end of file by getting 1 more char.

View File

@@ -32,6 +32,15 @@ protected:
Buffer<typename T::clear, typename T::clear> buffer; Buffer<typename T::clear, typename T::clear> buffer;
Timer timer; Timer timer;
// Send my inputs (not generally available)
virtual void send_mine() { throw not_implemented(); }
// Get share for next input of mine (not generally available)
virtual T finalize_mine() { throw not_implemented(); }
// Store share for next input from ``player`` from buffer ``o``
// in ``target`` (not generally available)
virtual void finalize_other(int, T&, octetStream&, int = -1)
{ throw not_implemented(); }
public: public:
vector<octetStream> os; vector<octetStream> os;
int values_input; int values_input;
@@ -61,18 +70,12 @@ public:
/// Schedule input from other player /// Schedule input from other player
virtual void add_other(int player, int n_bits = -1) = 0; virtual void add_other(int player, int n_bits = -1) = 0;
/// Schedule input from all players /// Schedule input from all players
void add_from_all(const clear& input, int n_bits = -1); void add_from_all(const typename T::open_type& input, int n_bits = -1);
/// Send my inputs
virtual void send_mine() = 0;
/// Run input protocol for all players /// Run input protocol for all players
virtual void exchange(); virtual void exchange();
/// Get share for next input of mine /// Get share for next input from ``player``
virtual T finalize_mine() = 0;
/// Store share for next input from ``player`` from buffer ``o`` in ``target``
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
/// Get share for next input from ``player`
virtual T finalize(int player, int n_bits = -1); virtual T finalize(int player, int n_bits = -1);
void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size); void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size);

View File

@@ -113,7 +113,7 @@ void Input<T>::add_other(int player, int)
} }
template<class T> template<class T>
void InputBase<T>::add_from_all(const clear& input, int n_bits) void InputBase<T>::add_from_all(const typename T::open_type& input, int n_bits)
{ {
for (int i = 0; i < P->num_players(); i++) for (int i = 0; i < P->num_players(); i++)
if (i == P->my_num()) if (i == P->my_num())

View File

@@ -106,6 +106,7 @@ enum
MATMULSM = 0xAB, MATMULSM = 0xAB,
CONV2DS = 0xAC, CONV2DS = 0xAC,
CHECK = 0xAF, CHECK = 0xAF,
PRIVATEOUTPUT = 0xAD,
// Data access // Data access
TRIPLE = 0x50, TRIPLE = 0x50,
BIT = 0x51, BIT = 0x51,
@@ -127,6 +128,7 @@ enum
INPUTMIXEDREG = 0xF3, INPUTMIXEDREG = 0xF3,
RAWINPUT = 0xF4, RAWINPUT = 0xF4,
INPUTPERSONAL = 0xF5, INPUTPERSONAL = 0xF5,
SENDPERSONAL = 0xF6,
STARTINPUT = 0x61, STARTINPUT = 0x61,
STOPINPUT = 0x62, STOPINPUT = 0x62,
READSOCKETC = 0x63, READSOCKETC = 0x63,

View File

@@ -200,14 +200,17 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case USE: case USE:
case USE_INP: case USE_INP:
case USE_EDABIT: case USE_EDABIT:
case DIGESTC:
case INPUTMASK:
case GINPUTMASK:
get_ints(r, s, 2);
n = get_int(s);
break;
case STARTPRIVATEOUTPUT: case STARTPRIVATEOUTPUT:
case GSTARTPRIVATEOUTPUT: case GSTARTPRIVATEOUTPUT:
case STOPPRIVATEOUTPUT: case STOPPRIVATEOUTPUT:
case GSTOPPRIVATEOUTPUT: case GSTOPPRIVATEOUTPUT:
case DIGESTC: throw runtime_error("two-stage private output not supported any more");
get_ints(r, s, 2);
n = get_int(s);
break;
case USE_MATMUL: case USE_MATMUL:
get_ints(r, s, 3); get_ints(r, s, 3);
n = get_int(s); n = get_int(s);
@@ -237,8 +240,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case PRINTREGB: case PRINTREGB:
case GPRINTREG: case GPRINTREG:
case LDINT: case LDINT:
case INPUTMASK:
case GINPUTMASK:
case INV2M: case INV2M:
case CONDPRINTSTR: case CONDPRINTSTR:
case CONDPRINTSTRB: case CONDPRINTSTRB:
@@ -290,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case RAWINPUT: case RAWINPUT:
case GRAWINPUT: case GRAWINPUT:
case INPUTPERSONAL: case INPUTPERSONAL:
case SENDPERSONAL:
case PRIVATEOUTPUT:
case TRUNC_PR: case TRUNC_PR:
case RUN_TAPE: case RUN_TAPE:
num_var_args = get_int(s); num_var_args = get_int(s);
@@ -599,6 +602,7 @@ int BaseInstruction::get_reg_type() const
case PUBINPUT: case PUBINPUT:
case FLOATOUTPUT: case FLOATOUTPUT:
case READSOCKETC: case READSOCKETC:
case PRIVATEOUTPUT:
return CINT; return CINT;
default: default:
if (is_gf2n_instruction()) if (is_gf2n_instruction())
@@ -738,10 +742,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
skip = 1; skip = 1;
break; break;
case INPUTPERSONAL: case INPUTPERSONAL:
case PRIVATEOUTPUT:
size_offset = -2; size_offset = -2;
offset = 2; offset = 2;
skip = 4; skip = 4;
break; break;
case SENDPERSONAL:
size_offset = -2;
offset = 2;
skip = 5;
break;
case READSOCKETS: case READSOCKETS:
case READSOCKETC: case READSOCKETC:
case READSOCKETINT: case READSOCKETINT:
@@ -939,13 +949,11 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break; break;
case INPUTMASK: case INPUTMASK:
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n); Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n);
if (n == Proc.P.my_num()) Proc.write_Cp(r[1], Proc.temp.rrp);
Proc.temp.rrp.output(Proc.private_output, false);
break; break;
case GINPUTMASK: case GINPUTMASK:
Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n); Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n);
if (n == Proc.P.my_num()) Proc.write_C2(r[1], Proc.temp.ans2);
Proc.temp.ans2.output(Proc.private_output, false);
break; break;
case INPUT: case INPUT:
sint::Input::template input<IntInput<typename sint::clear>>(Proc.Procp, start, size); sint::Input::template input<IntInput<typename sint::clear>>(Proc.Procp, start, size);
@@ -974,6 +982,12 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case INPUTPERSONAL: case INPUTPERSONAL:
Proc.Procp.input_personal(start); Proc.Procp.input_personal(start);
return; return;
case SENDPERSONAL:
Proc.Procp.send_personal(start);
return;
case PRIVATEOUTPUT:
Proc.Procp.private_output(start);
return;
// Note: Fp version has different semantics for NOTC than GNOTC // Note: Fp version has different semantics for NOTC than GNOTC
case NOTC: case NOTC:
to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); to_bigint(Proc.temp.aa, Proc.read_Cp(r[1]));
@@ -1202,18 +1216,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.binary_output.write((char*) &tmp, sizeof(double)); Proc.binary_output.write((char*) &tmp, sizeof(double));
} }
break; break;
case STARTPRIVATEOUTPUT:
Proc.privateOutputp.start(n,r[0],r[1]);
break;
case GSTARTPRIVATEOUTPUT:
Proc.privateOutput2.start(n,r[0],r[1]);
break;
case STOPPRIVATEOUTPUT:
Proc.privateOutputp.stop(n,r[0],r[1]);
break;
case GSTOPPRIVATEOUTPUT:
Proc.privateOutput2.stop(n,r[0],r[1]);
break;
case PREP: case PREP:
Procp.DataF.get(Proc.Procp.get_S(), r, start, size); Procp.DataF.get(Proc.Procp.get_S(), r, start, size);
return; return;

View File

@@ -97,12 +97,19 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
// initialize persistence if necessary // initialize persistence if necessary
for (auto& prog : progs) for (auto& prog : progs)
{ {
if (prog.writes_persistance) if (prog.writes_persistence)
{ {
string filename = Binary_File_IO::filename(my_number); string filename = Binary_File_IO::filename(my_number);
ifstream pers(filename); ifstream pers(filename);
if (pers.fail()) try
{
check_file_signature<sint>(pers, filename);
}
catch (signature_mismatch&)
{
ofstream pers(filename, ios::binary); ofstream pers(filename, ios::binary);
file_signature<sint>().output(pers);
}
break; break;
} }
} }
@@ -418,12 +425,14 @@ void Machine<sint, sgf2n>::run()
cerr << "Full broadcast" << endl; cerr << "Full broadcast" << endl;
#endif #endif
#ifdef CHOP_MEMORY
// Reduce memory size to speed up // Reduce memory size to speed up
unsigned max_size = 1 << 20; unsigned max_size = 1 << 20;
if (M2.size_s() > max_size) if (M2.size_s() > max_size)
M2.resize_s(max_size); M2.resize_s(max_size);
if (Mp.size_s() > max_size) if (Mp.size_s() > max_size)
Mp.resize_s(max_size); Mp.resize_s(max_size);
#endif
// Write out the memory to use next time // Write out the memory to use next time
ofstream outf(memory_filename(), ios::out | ios::binary); ofstream outf(memory_filename(), ios::out | ios::binary);

View File

@@ -44,9 +44,9 @@ class Memory
static void check_index(const vector<U>& M, size_t i) static void check_index(const vector<U>& M, size_t i)
{ {
(void) M, (void) i; (void) M, (void) i;
#ifdef NO_CHECK_INDEX #ifndef NO_CHECK_INDEX
if (i >= M.size()) if (i >= M.size())
throw overflow("memory", i, M.size()); throw overflow(U::type_string() + " memory", i, M.size());
#endif #endif
} }

View File

@@ -19,6 +19,9 @@ void MemoryPart<T>::minimum_size(size_t size)
{ {
if (size > this->size()) if (size > this->size())
this->resize(size); this->resize(size);
#ifdef DEBUG_MEMORY_SIZE
cerr << T::type_string() << " memory has now size " << this->size() << endl;
#endif
} }
catch (bad_alloc&) catch (bad_alloc&)
{ {
@@ -58,9 +61,9 @@ istream& operator>>(istream& s,Memory<T>& M)
int len; int len;
s >> len; s >> len;
M.resize_s(len); M.MS.minimum_size(len);
s >> len; s >> len;
M.resize_c(len); M.MC.minimum_size(len);
s.seekg(1, istream::cur); s.seekg(1, istream::cur);
for (unsigned int i=0; i<M.MS.size(); i++) for (unsigned int i=0; i<M.MS.size(); i++)

View File

@@ -17,16 +17,16 @@ class PrivateOutput
typedef typename T::open_type open_type; typedef typename T::open_type open_type;
SubProcessor<T>& proc; SubProcessor<T>& proc;
typename T::MAC_Check MC;
deque<open_type> masks; deque<open_type> masks;
public: public:
PrivateOutput(SubProcessor<T>& proc) : proc(proc) { }; PrivateOutput(SubProcessor<T>& proc);
~PrivateOutput();
void start(int player, int target, int source); void prepare_sending(const T& source, int player);
void stop(int player, int dest, int source); void exchange();
typename T::clear finalize(int player);
T start(int player, const T& source);
typename T::clear stop(int player, const typename T::clear& masked);
}; };
#endif /* PROCESSOR_PRIVATEOUTPUT_H_ */ #endif /* PROCESSOR_PRIVATEOUTPUT_H_ */

View File

@@ -7,13 +7,21 @@
#include "Processor.h" #include "Processor.h"
template<class T> template<class T>
void PrivateOutput<T>::start(int player, int target, int source) PrivateOutput<T>::PrivateOutput(SubProcessor<T>& proc) :
proc(proc), MC(proc.MC.get_alphai())
{ {
proc.get_S_ref(target) = start(player, proc.get_S_ref(source)); MC.init_open(proc.P);
MC.set_prep(proc.DataF);
} }
template<class T> template<class T>
T PrivateOutput<T>::start(int player, const T& source) PrivateOutput<T>::~PrivateOutput()
{
MC.Check(proc.P);
}
template<class T>
void PrivateOutput<T>::prepare_sending(const T& source, int player)
{ {
assert (player < proc.P.num_players()); assert (player < proc.P.num_players());
open_type mask; open_type mask;
@@ -24,26 +32,25 @@ T PrivateOutput<T>::start(int player, const T& source)
if (player == proc.P.my_num()) if (player == proc.P.my_num())
masks.push_back(mask); masks.push_back(mask);
return res; MC.prepare_open(res);
} }
template<class T> template<class T>
void PrivateOutput<T>::stop(int player, int dest, int source) void PrivateOutput<T>::exchange()
{ {
auto& value = proc.get_C_ref(dest); MC.exchange(proc.P);
value = stop(player, proc.get_C_ref(source));
if (proc.Proc)
value.output(proc.Proc->private_output, false);
} }
template<class T> template<class T>
typename T::clear PrivateOutput<T>::stop(int player, const typename T::clear& source) typename T::clear PrivateOutput<T>::finalize(int player)
{ {
typename T::clear value; auto res = MC.finalize_open();
if (player == proc.P.my_num()) if (player == proc.P.my_num())
{ {
value = source - masks.front(); res -= masks.front();
masks.pop_front(); masks.pop_front();
} }
return value;
return res;
} }

View File

@@ -71,6 +71,8 @@ public:
void conv2ds(const Instruction& instruction); void conv2ds(const Instruction& instruction);
void input_personal(const vector<int>& args); void input_personal(const vector<int>& args);
void send_personal(const vector<int>& args);
void private_output(const vector<int>& args);
CheckVector<T>& get_S() CheckVector<T>& get_S()
{ {
@@ -110,7 +112,6 @@ public:
ifstream private_input; ifstream private_input;
ifstream public_input; ifstream public_input;
ofstream public_output; ofstream public_output;
ofstream private_output;
ofstream binary_output; ofstream binary_output;
int sent, rounds; int sent, rounds;
@@ -172,9 +173,6 @@ class Processor : public ArithmeticProcessor
SubProcessor<sgf2n> Proc2; SubProcessor<sgf2n> Proc2;
SubProcessor<sint> Procp; SubProcessor<sint> Procp;
typename sgf2n::PrivateOutput privateOutput2;
typename sint::PrivateOutput privateOutputp;
unsigned int PC; unsigned int PC;
TempVars<sint, sgf2n> temp; TempVars<sint, sgf2n> temp;

View File

@@ -4,9 +4,8 @@
#include "Processor/Processor.h" #include "Processor/Processor.h"
#include "Processor/Program.h" #include "Processor/Program.h"
#include "GC/square64.h" #include "GC/square64.h"
#include "SpecificPrivateOutput.h"
#include "Protocols/ReplicatedInput.hpp"
#include "Protocols/ReplicatedPrivateOutput.hpp"
#include "Processor/ProcessorBase.hpp" #include "Processor/ProcessorBase.hpp"
#include "GC/Processor.hpp" #include "GC/Processor.hpp"
#include "GC/ShareThread.hpp" #include "GC/ShareThread.hpp"
@@ -63,7 +62,6 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
share_thread(DataF.DataFb, P, machine.get_bit_mac_key()), share_thread(DataF.DataFb, P, machine.get_bit_mac_key()),
Procb(machine.bit_memories), Procb(machine.bit_memories),
Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P), Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P),
privateOutput2(Proc2),privateOutputp(Procp),
external_clients(P.my_num()), external_clients(P.my_num()),
binary_file_io(Binary_File_IO()) binary_file_io(Binary_File_IO())
{ {
@@ -74,7 +72,6 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
private_input_filename = (get_filename(PREP_DIR "Private-Input-",true)); private_input_filename = (get_filename(PREP_DIR "Private-Input-",true));
private_input.open(private_input_filename.c_str()); private_input.open(private_input_filename.c_str());
public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out); public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out);
private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out);
binary_output.open( binary_output.open(
get_parameterized_filename(P.my_num(), thread_num, get_parameterized_filename(P.my_num(), thread_num,
PREP_DIR "Binary-Output"), ios_base::out); PREP_DIR "Binary-Output"), ios_base::out);
@@ -654,6 +651,37 @@ void SubProcessor<T>::input_personal(const vector<int>& args)
S[args[i + 2] + j] = input.finalize(args[i + 1]); S[args[i + 2] + j] = input.finalize(args[i + 1]);
} }
template<class T>
void SubProcessor<T>::private_output(const vector<int>& args)
{
typename T::PrivateOutput output(*this);
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
{
int player = args[i + 1];
output.prepare_sending(S.at(args[i + 3] + j), player);
}
output.exchange();
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
C.at(args[i + 2] + j) = output.finalize(args[i + 1]);
}
template<class T>
void SubProcessor<T>::send_personal(const vector<int>& args)
{
octetStreams to_send(P), to_receive(P);
for (size_t i = 0; i < args.size(); i += 5)
if (args[i + 3] == P.my_num())
for (int j = 0; j < args[i]; j++)
C[args[i + 4] + j].pack(to_send[args[i + 1]]);
P.send_receive_all(to_send, to_receive);
for (size_t i = 0; i < args.size(); i += 5)
if (args[i + 1] == P.my_num())
for (int j = 0; j < args[i]; j++)
C[args[i + 2] + j].unpack(to_receive[args[i + 3]]);
}
template<class sint, class sgf2n> template<class sint, class sgf2n>
typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m) typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m)
{ {

View File

@@ -23,7 +23,7 @@ void Program::compute_constants()
max_mem[reg_type] = max(max_mem[reg_type], max_mem[reg_type] = max(max_mem[reg_type],
p[i].get_mem(RegType(reg_type))); p[i].get_mem(RegType(reg_type)));
} }
writes_persistance |= p[i].opcode == WRITEFILESHARE; writes_persistence |= p[i].opcode == WRITEFILESHARE;
} }
} }

View File

@@ -30,10 +30,10 @@ class Program
public: public:
bool writes_persistance; bool writes_persistence;
Program(int nplayers) : offline_data_used(nplayers), Program(int nplayers) : offline_data_used(nplayers),
unknown_usage(false), writes_persistance(false) unknown_usage(false), writes_persistence(false)
{ compute_constants(); } { compute_constants(); }
// Read in a program // Read in a program

View File

@@ -0,0 +1,65 @@
/*
* SpecificPrivateOutput.h
*
*/
#ifndef PROCESSOR_SPECIFICPRIVATEOUTPUT_H_
#define PROCESSOR_SPECIFICPRIVATEOUTPUT_H_
template<class T>
class SpecificPrivateOutput
{
deque<T> secrets;
vector<typename T::PO*> pos;
Player& P;
vector<bool> active;
public:
SpecificPrivateOutput(SubProcessor<T>& proc) :
P(proc.P)
{
for (int i = 0; i < P.num_players(); i++)
pos.push_back(new typename T::PO(proc.P));
active.resize(P.num_players());
}
~SpecificPrivateOutput()
{
for (auto& x : pos)
delete x;
}
void prepare_sending(const T& secret, int player)
{
pos[player]->prepare_sending(secret, player);
if (P.my_num() == player)
secrets.push_back(secret);
active[player] = true;
}
void exchange()
{
for (int i = 0; i < this->P.num_players(); i++)
if (active[i])
{
if (i == this->P.my_num())
pos[i]->receive();
else
pos[i]->send(i);
}
}
typename T::clear finalize(int player)
{
if (player == this->P.my_num())
{
T secret = secrets.front();
secrets.pop_front();
return pos[player]->finalize(secret);
}
else
return {};
}
};
#endif /* PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ */

View File

@@ -0,0 +1,100 @@
from Compiler.ml import keras
import Compiler.ml as tf
try:
n_epochs = int(program.args[1])
except (ValueError, IndexError):
n_epochs = 10
try:
batch_size = int(program.args[2])
except (ValueError, IndexError):
batch_size = 128
try:
n_threads = int(program.args[3])
except (ValueError, IndexError):
n_threads = 36
#Instantiation
AlexNet = []
padding = 'same'
batchnorm = 'batchnorm' in program.args
#1st Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=96, input_shape=(32,32,3), kernel_size=(11,11), strides=(4,4), padding=9))
AlexNet.append(keras.layers.Activation('relu'))
AlexNet.append(keras.layers.MaxPooling2D(pool_size=3, strides=(2,2)))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
#2nd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
if batchnorm:
AlexNet.append(keras.layers.BatchNormalization())
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=1))
#3rd Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
#4th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
#5th Convolutional Layer
AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding=1))
AlexNet.append(keras.layers.Activation('relu'))
#Passing it to a Fully Connected layer
# 1st Fully Connected Layer
AlexNet.append(keras.layers.Dense(256))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#2nd Fully Connected Layer
AlexNet.append(keras.layers.Dense(256))
AlexNet.append(keras.layers.Activation('relu'))
if 'dropout' in program.args:
AlexNet.append(keras.layers.Dropout(0.5))
#Output Layer
AlexNet.append(keras.layers.Dense(10))
tf.set_n_threads(n_threads)
program.options_from_args()
sfix.set_precision_from_args(program, adapt_ring=True)
training_samples = MultiArray([50000, 32, 32, 3], sfix)
training_labels = MultiArray([50000, 10], sint)
test_samples = MultiArray([10000, 32, 32, 3], sfix)
test_labels = MultiArray([10000, 10], sint)
if 'no_acc' not in program.args:
training_labels.input_from(0)
training_samples.input_from(0)
test_labels.input_from(0)
test_samples.input_from(0)
model = tf.keras.models.Sequential(AlexNet)
model.compile_by_args(program)
model.build(training_samples.sizes)
model.summary()
opt = model.fit(
training_samples,
training_labels,
epochs=n_epochs,
batch_size=batch_size,
validation_data=(test_samples, test_labels)
)

View File

@@ -0,0 +1,45 @@
# this trains LeNet on MNIST with a dropout layer
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
program.options_from_args()
training_samples = MultiArray([50000, 32, 32, 3], sfix)
training_labels = MultiArray([50000, 10], sint)
test_samples = MultiArray([10000, 32, 32, 3], sfix)
test_labels = MultiArray([10000, 10], sint)
training_labels.input_from(0)
training_samples.input_from(0)
test_labels.input_from(0)
test_samples.input_from(0)
from Compiler import ml
tf = ml
ml.set_n_threads(36)
layers = [
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(500, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
]
model = tf.keras.models.Sequential(layers)
optim = tf.keras.optimizers.Adam(amsgrad=True)
model.compile(optimizer=optim)
opt = model.fit(
training_samples,
training_labels,
epochs=10,
batch_size=128,
validation_data=(test_samples, test_labels)
)

View File

@@ -21,7 +21,8 @@ tf = ml
layers = [ layers = [
tf.keras.layers.Flatten(), tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(128),
tf.keras.layers.Activation('relu'),
tf.keras.layers.Dense(10, activation='softmax') tf.keras.layers.Dense(10, activation='softmax')
] ]

View File

@@ -20,8 +20,21 @@ tf = ml
layers = [ layers = [
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.MaxPooling2D(2), tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
]
if 'batchnorm' in program.args:
layers += [tf.keras.layers.BatchNormalization()]
layers += [
tf.keras.layers.MaxPooling2D(2), tf.keras.layers.MaxPooling2D(2),
tf.keras.layers.Flatten(), tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.5), tf.keras.layers.Dropout(0.5),

View File

@@ -21,6 +21,8 @@ elif 'debug' in program.args:
n_test = 100 n_test = 100
elif 'debug5000' in program.args: elif 'debug5000' in program.args:
N = n_test = 5000 N = n_test = 5000
elif 'mini' in program.args:
N = n_test = 10
else: else:
N = 60000 N = 60000
n_test = 10000 n_test = 10000
@@ -39,6 +41,7 @@ except:
batch_size = N batch_size = N
N = min(N, 10000) N = min(N, 10000)
batch_size = min(batch_size, N)
ml.Layer.back_batch_size = batch_size ml.Layer.back_batch_size = batch_size
try: try:
@@ -71,6 +74,9 @@ else:
ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml), ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml),
ml.Dense(N, n_inner, 10, debug=debug_ml)] ml.Dense(N, n_inner, 10, debug=debug_ml)]
if 'batchnorm' in program.args:
layers.insert(1, ml.BatchNorm([N, n_inner]))
if 'dropout' in program.args: if 'dropout' in program.args:
for i in range(len(layers) - 1, 0, -1): for i in range(len(layers) - 1, 0, -1):
layers.insert(i, ml.Dropout(N, n_inner)) layers.insert(i, ml.Dropout(N, n_inner))

View File

@@ -53,7 +53,7 @@ except:
ml.Layer.back_batch_size = batch_size ml.Layer.back_batch_size = batch_size
layers = [ layers = [
ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [n_examples, 24, 24, 20], (1, 1), 'VALID'), ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [N, 24, 24, 20], (1, 1), 'VALID'),
ml.MaxPool([N, 24, 24, 20]), ml.MaxPool([N, 24, 24, 20]),
ml.Relu([N, 12, 12, 20]), ml.Relu([N, 12, 12, 20]),
ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'), ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'),
@@ -66,6 +66,12 @@ layers = [
layers += [ml.MultiOutput.from_args(program, n_examples, 10)] layers += [ml.MultiOutput.from_args(program, n_examples, 10)]
if 'batchnorm' in program.args:
for arg in program.args:
assert not arg.startswith('dropout')
layers.insert(4, ml.BatchNorm([N, 8, 8, 50], args=program.args))
layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args))
if 'dropout' in program.args or 'dropout2' in program.args: if 'dropout' in program.args or 'dropout2' in program.args:
layers.insert(8, ml.Dropout(N, 500)) layers.insert(8, ml.Dropout(N, 500))
elif 'dropout.25' in program.args: elif 'dropout.25' in program.args:

View File

@@ -85,6 +85,12 @@ void Atlas<T>::exchange()
resharing.add_mine(e); resharing.add_mine(e);
} }
for (size_t i = 0; i < min(masks.size(), size_t(P.num_players())); i++)
{
int j = (base_king + i) % P.num_players();
resharing.add_sender(j);
}
resharing.exchange(); resharing.exchange();
} }

View File

@@ -27,7 +27,7 @@ HemiMatrixPrep<T>& Hemi<T>::get_matrix_prep(const array<int, 3>& dims,
if (matrix_preps.find(dims) == matrix_preps.end()) if (matrix_preps.find(dims) == matrix_preps.end())
matrix_preps.insert({dims, matrix_preps.insert({dims,
new HemiMatrixPrep<T>(dims[0], dims[1], dims[2], new HemiMatrixPrep<T>(dims[0], dims[1], dims[2],
dynamic_cast<HemiPrep<T>&>(processor.DataF))}); dynamic_cast<typename T::LivePrep&>(processor.DataF))});
return *matrix_preps.at(dims); return *matrix_preps.at(dims);
} }

View File

@@ -18,17 +18,18 @@ template<class T>
class HemiMatrixPrep : public BufferPrep<ShareMatrix<T>> class HemiMatrixPrep : public BufferPrep<ShareMatrix<T>>
{ {
typedef BufferPrep<ShareMatrix<T>> super; typedef BufferPrep<ShareMatrix<T>> super;
typedef typename T::LivePrep LivePrep;
int n_rows, n_inner, n_cols; int n_rows, n_inner, n_cols;
bool swapped; bool swapped;
DataPositions* usage; DataPositions* usage;
HemiPrep<T>* prep; LivePrep* prep;
HemiMatrixPrep(const HemiMatrixPrep&) = delete; HemiMatrixPrep(const HemiMatrixPrep&) = delete;
public: public:
HemiMatrixPrep(int n_rows, int n_inner, int n_cols, HemiPrep<T>& prep) : HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) :
super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner), super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner),
n_cols(n_cols), prep(&prep) n_cols(n_cols), prep(&prep)
{ {

View File

@@ -87,11 +87,10 @@ void HemiMatrixPrep<T>::buffer_triples()
assert(prep); assert(prep);
auto& multipliers = prep->get_multipliers(); auto& multipliers = prep->get_multipliers();
assert(prep->pairwise_machine); auto& FTD = prep->get_FTD();
auto& FTD = prep->pairwise_machine->setup_p.FieldD; auto& pk = prep->get_pk();
auto& pk = prep->pairwise_machine->pk;
int n_matrices = FTD.num_slots() / n_rows; int n_matrices = FTD.num_slots() / n_rows;
#ifdef VERBOSE #ifdef VERBOSE_HE
fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner, fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner,
n_inner, n_cols); n_inner, n_cols);
fflush(stderr); fflush(stderr);
@@ -103,6 +102,8 @@ void HemiMatrixPrep<T>::buffer_triples()
AddableVector<ValueMatrix<gfpvar>> C(n_matrices); AddableVector<ValueMatrix<gfpvar>> C(n_matrices);
MatrixRandMultJob job(C, A, B); MatrixRandMultJob job(C, A, B);
if (T::local_mul)
{
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
{ {
auto& queues = BaseMachine::s().queues; auto& queues = BaseMachine::s().queues;
@@ -118,6 +119,7 @@ void HemiMatrixPrep<T>::buffer_triples()
job.end = n_matrices; job.end = n_matrices;
matrix_rand_mult(job); matrix_rand_mult(job);
} }
}
#ifdef VERBOSE_HE #ifdef VERBOSE_HE
fprintf(stderr, "encrypt at %f\n", timer.elapsed()); fprintf(stderr, "encrypt at %f\n", timer.elapsed());
@@ -130,26 +132,35 @@ void HemiMatrixPrep<T>::buffer_triples()
assert(prep->proc); assert(prep->proc);
auto& P = prep->proc->P; auto& P = prep->proc->P;
vector<vector<Ciphertext>> others_ct;
if (T::local_mul or OnlineOptions::singleton.direct)
{
Bundle<octetStream> bundle(P); Bundle<octetStream> bundle(P);
bundle.mine.store(diag.ciphertexts); bundle.mine.store(diag.ciphertexts);
P.unchecked_broadcast(bundle); P.unchecked_broadcast(bundle);
vector<vector<Ciphertext>> others_ct;
for (auto& os : bundle) for (auto& os : bundle)
{ {
others_ct.push_back({}); others_ct.push_back({});
os.get(others_ct.back(), Ciphertext(pk)); os.get(others_ct.back(), Ciphertext(pk));
} }
}
else
{
others_ct.push_back(diag.ciphertexts);
TreeSum<Ciphertext>().run(others_ct[0], P);
}
for (int j = 0; j < n_cols; j++) for (int j = 0; j < n_cols; j++)
for (auto m : multipliers) for (auto m : multipliers)
{ {
#ifdef VERBOSE #ifdef VERBOSE_HE
fprintf(stderr, "column %d with party offset %d at %f\n", j, fprintf(stderr, "column %d with party offset %d at %f\n", j,
m->get_offset(), timer.elapsed()); m->get_offset(), timer.elapsed());
fflush(stderr); fflush(stderr);
#endif #endif
Ciphertext C(pk); Ciphertext C(pk);
auto& multiplicands = others_ct[P.get_player(-m->get_offset())]; auto& multiplicands = m->get_multiplicands(others_ct, pk);
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
{ {
auto& queues = BaseMachine::s().queues; auto& queues = BaseMachine::s().queues;
@@ -160,7 +171,7 @@ void HemiMatrixPrep<T>::buffer_triples()
CipherPlainMultJob job(products, multiplicands, multiplicands2, true); CipherPlainMultJob job(products, multiplicands, multiplicands2, true);
int start = queues.distribute(job, n_inner); int start = queues.distribute(job, n_inner);
#ifdef VERBOSE_HE #ifdef VERBOSE_HE
fprintf(stderr, "from %d in central thread\n", start); fprintf(stderr, "from %d in central thread at %f\n", start, timer.elapsed());
fflush(stderr); fflush(stderr);
#endif #endif
for (int i = start; i < n_inner; i++) for (int i = start; i < n_inner; i++)
@@ -185,7 +196,10 @@ void HemiMatrixPrep<T>::buffer_triples()
m->add(products[j], C, BOTH, n_inner); m->add(products[j], C, BOTH, n_inner);
} }
if (T::local_mul)
C += diag.dediag(products, n_matrices); C += diag.dediag(products, n_matrices);
else
C = diag.dediag(products, n_matrices);
for (int i = 0; i < n_matrices; i++) for (int i = 0; i < n_matrices; i++)
if (swapped) if (swapped)

View File

@@ -34,6 +34,9 @@ public:
static void basic_setup(Player& P); static void basic_setup(Player& P);
static void teardown(); static void teardown();
static const FHE_PK& get_pk();
static const FD& get_FTD();
HemiPrep(SubProcessor<T>* proc, DataPositions& usage) : HemiPrep(SubProcessor<T>* proc, DataPositions& usage) :
BufferPrep<T>(usage), BufferPrep<T>(usage),
BitPrep<T>(proc, usage), RingPrep<T>(proc, usage), BitPrep<T>(proc, usage), RingPrep<T>(proc, usage),

View File

@@ -34,6 +34,20 @@ void HemiPrep<T>::basic_setup(Player& P)
T::clear::template init<typename FD::T>(); T::clear::template init<typename FD::T>();
} }
template<class T>
const FHE_PK& HemiPrep<T>::get_pk()
{
assert(pairwise_machine);
return pairwise_machine->pk;
}
template<class T>
const typename T::clear::FD& HemiPrep<T>::get_FTD()
{
assert(pairwise_machine);
return pairwise_machine->setup<FD>().FieldD;
}
template<class T> template<class T>
HemiPrep<T>::~HemiPrep() HemiPrep<T>::~HemiPrep()

View File

@@ -27,6 +27,7 @@ public:
typedef HemiPrep<This> LivePrep; typedef HemiPrep<This> LivePrep;
static const bool needs_ot = false; static const bool needs_ot = false;
static const bool local_mul = true;
static true_type triple_matmul; static true_type triple_matmul;
HemiShare() HemiShare()

View File

@@ -140,12 +140,12 @@ void KeyGenProtocol<X, L>::output_to(int player, vector<open_type>& opened,
vector<share_type>& shares) vector<share_type>& shares)
{ {
PrivateOutput<share_type> po(*proc); PrivateOutput<share_type> po(*proc);
vector<share_type> masked;
for (auto& share : shares) for (auto& share : shares)
masked.push_back(po.start(player, share)); po.prepare_sending(share, player);
MC->POpen(opened, masked, P); po.exchange();
opened.resize(shares.size());
for (auto& x : opened) for (auto& x : opened)
x = po.stop(player, x); x = po.finalize(player);
} }
template<int L> template<int L>

View File

@@ -52,6 +52,7 @@ public:
virtual ~TreeSum(); virtual ~TreeSum();
void run(vector<T>& values, const Player& P); void run(vector<T>& values, const Player& P);
T run(const T& value, const Player& P);
octetStream& get_buffer() { return os; } octetStream& get_buffer() { return os; }
@@ -210,6 +211,14 @@ void TreeSum<T>::run(vector<T>& values, const Player& P)
finish(values, P); finish(values, P);
} }
template<class T>
T TreeSum<T>::run(const T& value, const Player& P)
{
vector<T> values = {value};
run(values, P);
return values[0];
}
template<class T> template<class T>
size_t TreeSum<T>::report_size(ReportType type) size_t TreeSum<T>::report_size(ReportType type)
{ {
@@ -244,14 +253,6 @@ void add_openings(vector<T>& values, const Player& P, int sum_players, int last_
MC.player_timers[sender].start(); MC.player_timers[sender].start();
P.wait_receive(sender, oss[j]); P.wait_receive(sender, oss[j]);
MC.player_timers[sender].stop(); MC.player_timers[sender].stop();
if ((unsigned)oss[j].get_length() < values.size() * T::size())
{
stringstream ss;
ss << "Not enough information received, expected "
<< values.size() * T::size() << " bytes, got "
<< oss[j].get_length();
throw Processor_Error(ss.str());
}
MC.timers[SUM].start(); MC.timers[SUM].start();
for (unsigned int i=0; i<values.size(); i++) for (unsigned int i=0; i<values.size(); i++)
{ {

View File

@@ -127,6 +127,7 @@ void MAC_Check_<U>::Check(const Player& P)
auto& vals = this->vals; auto& vals = this->vals;
auto& macs = this->macs; auto& macs = this->macs;
auto& popen_cnt = this->popen_cnt; auto& popen_cnt = this->popen_cnt;
assert(int(macs.size()) <= popen_cnt);
if (popen_cnt < 10) if (popen_cnt < 10)
{ {

View File

@@ -12,6 +12,8 @@ using namespace std;
#include "Networking/Player.h" #include "Networking/Player.h"
#include "Tools/PointerVector.h" #include "Tools/PointerVector.h"
template<class T> class Preprocessing;
/** /**
* Abstract base class for opening protocols * Abstract base class for opening protocols
*/ */
@@ -61,6 +63,8 @@ public:
virtual void CheckFor(const typename T::open_type& value, const vector<T>& shares, const Player& P); virtual void CheckFor(const typename T::open_type& value, const vector<T>& shares, const Player& P);
virtual const Player& get_check_player(const Player& P) const { return P; } virtual const Player& get_check_player(const Player& P) const { return P; }
virtual void set_prep(Preprocessing<T>&) {}
}; };
#endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */ #endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */

View File

@@ -17,6 +17,7 @@ class MalRepRingShare : public MaliciousRep3Share<SignedZ2<K>>
{ {
typedef SignedZ2<K> T; typedef SignedZ2<K> T;
typedef MaliciousRep3Share<T> super; typedef MaliciousRep3Share<T> super;
typedef MalRepRingShare This;
public: public:
const static int BIT_LENGTH = K; const static int BIT_LENGTH = K;
@@ -26,7 +27,8 @@ public:
typedef HashMaliciousRepMC<MalRepRingShare> MAC_Check; typedef HashMaliciousRepMC<MalRepRingShare> MAC_Check;
typedef MAC_Check Direct_MC; typedef MAC_Check Direct_MC;
typedef ReplicatedInput<MalRepRingShare> Input; typedef ReplicatedInput<MalRepRingShare> Input;
typedef ::PrivateOutput<MalRepRingShare> PrivateOutput; typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef MalRepRingPrepWithBits<MalRepRingShare> LivePrep; typedef MalRepRingPrepWithBits<MalRepRingShare> LivePrep;
typedef MaliciousRep3Share<Z2<K + S>> prep_type; typedef MaliciousRep3Share<Z2<K + S>> prep_type;
typedef Z2<S> random_type; typedef Z2<S> random_type;

View File

@@ -13,6 +13,7 @@ template<class T> class Beaver;
template<class T> class MaliciousRepPrepWithBits; template<class T> class MaliciousRepPrepWithBits;
template<class T> class MaliciousRepPO; template<class T> class MaliciousRepPO;
template<class T> class MaliciousRepPrep; template<class T> class MaliciousRepPrep;
template<class T> class SpecificPrivateOutput;
namespace GC namespace GC
{ {
@@ -30,8 +31,8 @@ public:
typedef HashMaliciousRepMC<MaliciousRep3Share<T>> MAC_Check; typedef HashMaliciousRepMC<MaliciousRep3Share<T>> MAC_Check;
typedef MAC_Check Direct_MC; typedef MAC_Check Direct_MC;
typedef ReplicatedInput<MaliciousRep3Share<T>> Input; typedef ReplicatedInput<MaliciousRep3Share<T>> Input;
typedef ::PrivateOutput<MaliciousRep3Share<T>> PrivateOutput;
typedef MaliciousRepPO<MaliciousRep3Share> PO; typedef MaliciousRepPO<MaliciousRep3Share> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef Rep3Share<T> Honest; typedef Rep3Share<T> Honest;
typedef MaliciousRepPrepWithBits<MaliciousRep3Share> LivePrep; typedef MaliciousRepPrepWithBits<MaliciousRep3Share> LivePrep;
typedef MaliciousRepPrep<MaliciousRep3Share> TriplePrep; typedef MaliciousRepPrep<MaliciousRep3Share> TriplePrep;

View File

@@ -9,13 +9,14 @@
template<class T> template<class T>
class MaliciousShamirPO class MaliciousShamirPO
{ {
protected:
Player& P; Player& P;
octetStream to_send; octetStream to_send;
vector<octetStream> to_receive; vector<octetStream> to_receive;
vector<typename T::open_type> shares; vector<typename T::open_type> shares;
MaliciousShamirMC<T> MC; typename T::Direct_MC MC;
public: public:
MaliciousShamirPO(Player& P); MaliciousShamirPO(Player& P);

View File

@@ -13,6 +13,7 @@
template<class T> class MaliciousRepPrepWithBits; template<class T> class MaliciousRepPrepWithBits;
template<class T> class MaliciousRepPrep; template<class T> class MaliciousRepPrep;
template<class T> class MaliciousShamirPO; template<class T> class MaliciousShamirPO;
template<class T> class SpecificPrivateOutput;
namespace GC namespace GC
{ {
@@ -23,14 +24,15 @@ template<class T>
class MaliciousShamirShare : public ShamirShare<T> class MaliciousShamirShare : public ShamirShare<T>
{ {
typedef ShamirShare<T> super; typedef ShamirShare<T> super;
typedef MaliciousShamirShare This;
public: public:
typedef Beaver<MaliciousShamirShare<T>> Protocol; typedef Beaver<MaliciousShamirShare<T>> Protocol;
typedef MaliciousShamirMC<MaliciousShamirShare> MAC_Check; typedef MaliciousShamirMC<MaliciousShamirShare> MAC_Check;
typedef MAC_Check Direct_MC; typedef MAC_Check Direct_MC;
typedef ShamirInput<MaliciousShamirShare> Input; typedef ShamirInput<MaliciousShamirShare> Input;
typedef ::PrivateOutput<MaliciousShamirShare> PrivateOutput;
typedef MaliciousShamirPO<MaliciousShamirShare> PO; typedef MaliciousShamirPO<MaliciousShamirShare> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef ShamirShare<T> Honest; typedef ShamirShare<T> Honest;
typedef MaliciousRepPrepWithBits<MaliciousShamirShare> LivePrep; typedef MaliciousRepPrepWithBits<MaliciousShamirShare> LivePrep;
typedef MaliciousRepPrep<MaliciousShamirShare> TriplePrep; typedef MaliciousRepPrep<MaliciousShamirShare> TriplePrep;

View File

@@ -76,12 +76,6 @@ public:
return string(1, T::type_char()); return string(1, T::type_char());
} }
static void read_or_generate_mac_key(string, Player&, mac_key_type& key)
{
SeededPRNG G;
key.randomize(G);
}
MamaShare() MamaShare()
{ {
} }

View File

@@ -15,6 +15,7 @@ template<class T>
class PostSacriRepFieldShare : public MaliciousRep3Share<T> class PostSacriRepFieldShare : public MaliciousRep3Share<T>
{ {
typedef MaliciousRep3Share<T> super; typedef MaliciousRep3Share<T> super;
typedef PostSacriRepFieldShare This;
public: public:
typedef typename super::clear clear; typedef typename super::clear clear;
@@ -23,7 +24,8 @@ public:
typedef HashMaliciousRepMC<PostSacriRepFieldShare> MAC_Check; typedef HashMaliciousRepMC<PostSacriRepFieldShare> MAC_Check;
typedef MAC_Check Direct_MC; typedef MAC_Check Direct_MC;
typedef ReplicatedInput<PostSacriRepFieldShare> Input; typedef ReplicatedInput<PostSacriRepFieldShare> Input;
typedef ::PrivateOutput<PostSacriRepFieldShare> PrivateOutput; typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef MaliciousRepPrepWithBits<PostSacriRepFieldShare> LivePrep; typedef MaliciousRepPrepWithBits<PostSacriRepFieldShare> LivePrep;
PostSacriRepFieldShare() PostSacriRepFieldShare()

View File

@@ -17,6 +17,7 @@ template<int K, int S>
class PostSacriRepRingShare : public Rep3Share2<K> class PostSacriRepRingShare : public Rep3Share2<K>
{ {
typedef Rep3Share2<K> super; typedef Rep3Share2<K> super;
typedef PostSacriRepRingShare This;
public: public:
static const int BIT_LENGTH = K; static const int BIT_LENGTH = K;
@@ -33,7 +34,8 @@ public:
typedef HashMaliciousRepMC<PostSacriRepRingShare> MAC_Check; typedef HashMaliciousRepMC<PostSacriRepRingShare> MAC_Check;
typedef MAC_Check Direct_MC; typedef MAC_Check Direct_MC;
typedef ReplicatedInput<PostSacriRepRingShare> Input; typedef ReplicatedInput<PostSacriRepRingShare> Input;
typedef ::PrivateOutput<PostSacriRepRingShare> PrivateOutput; typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef MalRepRingPrepWithBits<PostSacriRepRingShare> LivePrep; typedef MalRepRingPrepWithBits<PostSacriRepRingShare> LivePrep;
typedef GC::MaliciousRepSecret bit_type; typedef GC::MaliciousRepSecret bit_type;

View File

@@ -42,8 +42,13 @@ public:
{ {
} }
~ProtocolSet() /**
* Run all protocol checks
*/
void check()
{ {
protocol.check();
output.Check(processor.P);
} }
}; };
@@ -73,6 +78,15 @@ public:
*thread.protocol), input(output, prep, P) *thread.protocol), input(output, prep, P)
{ {
} }
/**
* Run all protocol checks
*/
void check()
{
protocol.check();
output.Check(protocol.P);
}
}; };
/** /**
@@ -102,6 +116,15 @@ public:
arithmetic.protocol), input(arithmetic.input) arithmetic.protocol), input(arithmetic.input)
{ {
} }
/**
* Run all protocol checks
*/
void check()
{
arithmetic.check();
binary.check();
}
}; };
#endif /* PROTOCOLS_PROTOCOLSET_H_ */ #endif /* PROTOCOLS_PROTOCOLSET_H_ */

View File

@@ -15,7 +15,8 @@
template<class T> class ReplicatedPrep; template<class T> class ReplicatedPrep;
template<class T> class ReplicatedRingPrep; template<class T> class ReplicatedRingPrep;
template<class T> class PrivateOutput; template<class T> class ReplicatedPO;
template<class T> class SpecificPrivateOutput;
template<class T, int L> template<class T, int L>
class RepShare : public FixedVec<T, L>, public ShareInterface class RepShare : public FixedVec<T, L>, public ShareInterface
@@ -99,6 +100,7 @@ template<class T>
class Rep3Share : public RepShare<T, 2> class Rep3Share : public RepShare<T, 2>
{ {
typedef RepShare<T, 2> super; typedef RepShare<T, 2> super;
typedef Rep3Share This;
public: public:
typedef T clear; typedef T clear;
@@ -107,7 +109,8 @@ public:
typedef ReplicatedMC<Rep3Share> MAC_Check; typedef ReplicatedMC<Rep3Share> MAC_Check;
typedef MAC_Check Direct_MC; typedef MAC_Check Direct_MC;
typedef ReplicatedInput<Rep3Share> Input; typedef ReplicatedInput<Rep3Share> Input;
typedef ::PrivateOutput<Rep3Share> PrivateOutput; typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef ReplicatedPrep<Rep3Share> LivePrep; typedef ReplicatedPrep<Rep3Share> LivePrep;
typedef ReplicatedRingPrep<Rep3Share> TriplePrep; typedef ReplicatedRingPrep<Rep3Share> TriplePrep;
typedef Rep3Share Honest; typedef Rep3Share Honest;

View File

@@ -24,7 +24,8 @@ public:
typedef ReplicatedMC<Rep3Share2> MAC_Check; typedef ReplicatedMC<Rep3Share2> MAC_Check;
typedef MAC_Check Direct_MC; typedef MAC_Check Direct_MC;
typedef ReplicatedInput<Rep3Share2> Input; typedef ReplicatedInput<Rep3Share2> Input;
typedef ::PrivateOutput<Rep3Share2> PrivateOutput; typedef ReplicatedPO<This> PO;
typedef SpecificPrivateOutput<This> PrivateOutput;
typedef ReplicatedPrep2k<Rep3Share2> LivePrep; typedef ReplicatedPrep2k<Rep3Share2> LivePrep;
typedef Rep3Share2 Honest; typedef Rep3Share2 Honest;
typedef SignedZ2<K> clear; typedef SignedZ2<K> clear;

View File

@@ -31,7 +31,6 @@ public:
void add_mine(const typename T::open_type& input, int n_bits = -1); void add_mine(const typename T::open_type& input, int n_bits = -1);
void add_other(int player, int n_bits = -1); void add_other(int player, int n_bits = -1);
void send_mine();
void exchange(); void exchange();
T finalize_mine(); T finalize_mine();

View File

@@ -64,12 +64,6 @@ void Rep4Input<T>::add_other(int player, int)
results[player].push_back(res); results[player].push_back(res);
} }
template<class T>
void Rep4Input<T>::send_mine()
{
throw not_implemented();
}
template<class T> template<class T>
void Rep4Input<T>::exchange() void Rep4Input<T>::exchange()
{ {

View File

@@ -19,10 +19,6 @@ using namespace std;
template<class T> class SubProcessor; template<class T> class SubProcessor;
template<class T> class ReplicatedMC; template<class T> class ReplicatedMC;
template<class T> class ReplicatedInput; template<class T> class ReplicatedInput;
template<class T> class ReplicatedPrivateOutput;
template<class T> class Share;
template<class T> class Rep3Share;
template<class T> class MAC_Check_Base;
template<class T> class Preprocessing; template<class T> class Preprocessing;
class Instruction; class Instruction;
@@ -141,9 +137,6 @@ class Replicated : public ReplicatedBase, public ProtocolBase<T>
void trunc_pr(const vector<int>& regs, int size, U& proc, false_type); void trunc_pr(const vector<int>& regs, int size, U& proc, false_type);
public: public:
typedef ReplicatedMC<T> MAC_Check;
typedef ReplicatedInput<T> Input;
static const bool uses_triples = false; static const bool uses_triples = false;
Replicated(Player& P); Replicated(Player& P);

View File

@@ -10,6 +10,7 @@
#include "Processor/Processor.h" #include "Processor/Processor.h"
#include "Processor/TruncPrTuple.h" #include "Processor/TruncPrTuple.h"
#include "Tools/benchmarking.h" #include "Tools/benchmarking.h"
#include "Tools/Bundle.h"
#include "ReplicatedInput.h" #include "ReplicatedInput.h"
#include "Rep3Share2k.h" #include "Rep3Share2k.h"
@@ -162,14 +163,13 @@ void Replicated<T>::prepare_mul(const T& x,
} }
template<class T> template<class T>
inline void Replicated<T>::prepare_reshare(const typename T::clear& share, void Replicated<T>::prepare_reshare(const typename T::clear& share,
int n) int n)
{ {
auto add_share = share;
typename T::value_type tmp[2]; typename T::value_type tmp[2];
for (int i = 0; i < 2; i++) for (int i = 0; i < 2; i++)
tmp[i].randomize(shared_prngs[i], n); tmp[i].randomize(shared_prngs[i], n);
add_share += tmp[0] - tmp[1]; auto add_share = share + tmp[0] - tmp[1];
add_share.pack(os[0], n); add_share.pack(os[0], n);
add_shares.push_back(add_share); add_shares.push_back(add_share);
} }

View File

@@ -56,16 +56,24 @@ BufferPrep<T>::~BufferPrep()
<< " bit generation" << endl; << " bit generation" << endl;
#endif #endif
auto field_type = T::clear::field_type();
auto& my_usage = this->usage.files.at(field_type);
this->print_left("triples", triples.size() * T::default_length, type_string, this->print_left("triples", triples.size() * T::default_length, type_string,
this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE)
* T::default_length); * T::default_length);
size_t used_bits = my_usage.at(DATA_BIT);
if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac)
// add dabits with computation modulo power of two but without MAC
used_bits += my_usage.at(DATA_DABIT);
this->print_left("bits", bits.size(), type_string, used_bits);
#define X(KIND, TYPE) \ #define X(KIND, TYPE) \
this->print_left(#KIND, KIND.size(), type_string, \ this->print_left(#KIND, KIND.size(), type_string, \
this->usage.files.at(T::clear::field_type()).at(TYPE)); this->usage.files.at(T::clear::field_type()).at(TYPE));
X(squares, DATA_SQUARE) X(squares, DATA_SQUARE)
X(inverses, DATA_INVERSE) X(inverses, DATA_INVERSE)
X(bits, DATA_BIT)
X(dabits, DATA_DABIT) X(dabits, DATA_DABIT)
#undef X #undef X
@@ -601,17 +609,6 @@ void buffer_bits_from_players(vector<vector<T>>& player_bits,
for (int i = 0; i < n_relevant_players; i++) for (int i = 0; i < n_relevant_players; i++)
for (auto& x : player_bits[i]) for (auto& x : player_bits[i])
x = input.finalize((base_player + i) % P.num_players(), n_bits); x = input.finalize((base_player + i) % P.num_players(), n_bits);
#if !defined(__clang__) && (__GNUC__ == 6)
// mitigate compiler bug
Bundle<octetStream> bundle(P);
P.unchecked_broadcast(bundle);
#endif
#ifdef DEBUG_BIT_SACRIFICE
typename T::MAC_Check MC;
for (int i = 0; i < n_relevant_players; i++)
for (auto& x : player_bits[i])
assert((MC.open(x, P) == 0) or (MC.open(x, P) == 1));
#endif
} }
template<class T> template<class T>
@@ -1164,18 +1161,18 @@ void BufferPrep<T>::buffer_inputs_as_usual(int player, SubProcessor<T>* proc)
typename T::clear r; typename T::clear r;
r.randomize(G); r.randomize(G);
input.add_mine(r); input.add_mine(r);
this->inputs[player].push_back({input.finalize_mine(), r}); this->inputs[player].push_back({input.finalize(player), r});
} }
input.send_mine(); input.exchange();
} }
else else
{ {
octetStream os; for (int i = 0; i < buffer_size; i++)
P.receive_player(player, os); input.add_other(player);
T share; input.exchange();
for (int i = 0; i < buffer_size; i++) for (int i = 0; i < buffer_size; i++)
{ {
input.finalize_other(player, share, os); auto share = input.finalize(player);
this->inputs[player].push_back({share, 0}); this->inputs[player].push_back({share, 0});
} }
} }

View File

@@ -1,26 +0,0 @@
/*
* ReplicatedPrivateOutput.h
*
*/
#ifndef PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_
#define PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_
template<class T>
class SubProcessor;
template<class T>
class Share;
template <class T>
class ReplicatedPrivateOutput
{
SubProcessor<T>& proc;
public:
ReplicatedPrivateOutput(SubProcessor<T>& proc);
void start(int player, int target, int source);
void stop(int player, int source);
};
#endif /* PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ */

View File

@@ -1,30 +0,0 @@
/*
* ReplicatedPrivateOutput.cpp
*
*/
#include "ReplicatedPrivateOutput.h"
#include "Processor/Processor.h"
#include "Math/FixedVec.h"
#include "Math/Integer.h"
template<class T>
inline ReplicatedPrivateOutput<T>::ReplicatedPrivateOutput(
SubProcessor<T>& proc) :
proc(proc)
{
}
template<class T>
void ReplicatedPrivateOutput<T>::start(int player, int target,
int source)
{
(void)player, (void)target, (void)source;
throw runtime_error("not implemented, use PrivateOutput");
}
template<class T>
void ReplicatedPrivateOutput<T>::stop(int player, int source)
{
(void)player, (void)source;
}

View File

@@ -71,6 +71,12 @@ public:
proc.get_S()[info.source_base + i] >> info.m; proc.get_S()[info.source_base + i] >> info.m;
} }
} }
void buffer_random()
{
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
this->random.push_back(G.get<T>());
}
}; };
#endif /* PROTOCOLS_SEMI_H_ */ #endif /* PROTOCOLS_SEMI_H_ */

View File

@@ -14,34 +14,33 @@ template<class T> class SemiMC;
* Additive secret sharing input protocol * Additive secret sharing input protocol
*/ */
template<class T> template<class T>
class SemiInput : public IndividualInput<T> class SemiInput : public InputBase<T>
{ {
SeededPRNG secure_prng; vector<SeededPRNG> send_prngs;
vector<PRNG> recv_prngs;
Player& P;
vector<PointerVector<T>> shares;
public: public:
SemiInput(SubProcessor<T>& proc, SemiMC<T>& MC) : SemiInput(SubProcessor<T>& proc, SemiMC<T>&) :
IndividualInput<T>(proc) SemiInput(&proc, proc.P)
{ {
(void) MC;
} }
SemiInput(SubProcessor<T>* proc, Player& P) : SemiInput(SubProcessor<T>* proc, Player& P);
IndividualInput<T>(proc, P)
{
}
SemiInput(typename T::MAC_Check& MC, Preprocessing<T>& prep, Player& P) : SemiInput(typename T::MAC_Check& MC, Preprocessing<T>& prep, Player& P) :
SemiInput(P) SemiInput(0, P)
{ {
(void) MC, (void) prep; (void) MC, (void) prep;
} }
SemiInput(Player& P) : void reset(int player);
IndividualInput<T>(0, P)
{
}
void add_mine(const typename T::clear& input, int n_bits = -1); void add_mine(const typename T::clear& input, int n_bits = -1);
void add_other(int player, int n_bits = -1);
void exchange();
void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);
T finalize_mine();
}; };
#endif /* PROTOCOLS_SEMIINPUT_H_ */ #endif /* PROTOCOLS_SEMIINPUT_H_ */

View File

@@ -11,22 +11,64 @@
#include "ShamirInput.hpp" #include "ShamirInput.hpp"
template<class T> template<class T>
void SemiInput<T>::add_mine(const typename T::clear& input, int n_bits) SemiInput<T>::SemiInput(SubProcessor<T>* proc, Player& P) :
InputBase<T>(proc), P(P)
{
shares.resize(P.num_players());
vector<octetStream> to_send(P.num_players()), to_receive;
for (int i = 0; i < P.num_players(); i++)
{
send_prngs.push_back({});
to_send[i].append(send_prngs.back().get_seed(), SEED_SIZE);
}
P.send_receive_all(to_send, to_receive);
recv_prngs.resize(P.num_players());
for (int i = 0; i < P.num_players(); i++)
if (i != P.my_num())
recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE));
this->reset_all(P);
}
template<class T>
void SemiInput<T>::reset(int player)
{
shares[player].clear();
}
template<class T>
void SemiInput<T>::add_mine(const typename T::clear& input, int)
{ {
auto& P = this->P; auto& P = this->P;
typename T::open_type sum, share; typename T::open_type sum, share;
for (int i = 0; i < P.num_players(); i++) for (int i = 0; i < P.num_players(); i++)
{ {
if (i < P.num_players() - 1) if (i != P.my_num())
share.randomize(secure_prng, n_bits); sum += send_prngs[i].template get<typename T::open_type>();
else
share = input - sum;
sum += share;
if (i == P.my_num())
this->shares.push_back(share);
else
share.pack(this->os[i], n_bits);
} }
shares[P.my_num()].push_back(input - sum);
}
template<class T>
void SemiInput<T>::add_other(int, int)
{
}
template<class T>
void SemiInput<T>::exchange()
{
}
template<class T>
void SemiInput<T>::finalize_other(int player, T& target, octetStream&,
int)
{
target = recv_prngs[player].template get<T>();
}
template<class T>
T SemiInput<T>::finalize_mine()
{
return shares[P.my_num()].next();
} }
#endif #endif

View File

@@ -27,7 +27,6 @@ class Shamir : public ProtocolBase<T>
{ {
typedef typename T::open_type::Scalar U; typedef typename T::open_type::Scalar U;
octetStreams os;
vector<U> reconstruction; vector<U> reconstruction;
U rec_factor; U rec_factor;
ShamirInput<T>* resharing; ShamirInput<T>* resharing;

View File

@@ -69,8 +69,6 @@ int Shamir<T>::get_n_relevant_players()
template<class T> template<class T>
void Shamir<T>::reset() void Shamir<T>::reset()
{ {
os.reset(P);
if (resharing == 0) if (resharing == 0)
{ {
resharing = new ShamirInput<T>(0, P); resharing = new ShamirInput<T>(0, P);
@@ -78,6 +76,9 @@ void Shamir<T>::reset()
for (int i = 0; i < P.num_players(); i++) for (int i = 0; i < P.num_players(); i++)
resharing->reset(i); resharing->reset(i);
for (int i = 0; i < n_mul_players; i++)
resharing->add_sender(i);
} }
template<class T> template<class T>
@@ -92,37 +93,27 @@ template<class T>
void Shamir<T>::prepare_mul(const T& x, const T& y, int n) void Shamir<T>::prepare_mul(const T& x, const T& y, int n)
{ {
(void) n; (void) n;
auto add_share = x * y * rec_factor;
if (P.my_num() < n_mul_players) if (P.my_num() < n_mul_players)
resharing->add_mine(add_share); resharing->add_mine(x * y * rec_factor);
} }
template<class T> template<class T>
void Shamir<T>::exchange() void Shamir<T>::exchange()
{ {
vector<bool> senders(P.num_players(), false); assert(resharing);
for (int i = 0; i < n_mul_players; i++) resharing->exchange();
senders[i] = true;
P.send_receive_all(senders, resharing->os, os);
} }
template<class T> template<class T>
void Shamir<T>::start_exchange() void Shamir<T>::start_exchange()
{ {
if (P.my_num() < n_mul_players) resharing->start_exchange();
for (int offset = 1; offset < P.num_players(); offset++)
P.send_relative(offset, resharing->os[P.get_player(offset)]);
} }
template<class T> template<class T>
void Shamir<T>::stop_exchange() void Shamir<T>::stop_exchange()
{ {
for (int offset = 1; offset < P.num_players(); offset++) resharing->stop_exchange();
{
int receive_from = P.get_player(-offset);
if (receive_from < n_mul_players)
P.receive_player(receive_from, os[receive_from]);
}
} }
template<class T> template<class T>
@@ -136,15 +127,8 @@ template<class T>
T Shamir<T>::finalize(int n_relevant_players) T Shamir<T>::finalize(int n_relevant_players)
{ {
ShamirShare<U> res = U(0); ShamirShare<U> res = U(0);
if (P.my_num() < n_relevant_players)
res = resharing->finalize_mine();
for (int i = 0; i < n_relevant_players; i++) for (int i = 0; i < n_relevant_players; i++)
if (i != P.my_num()) res += resharing->finalize(i);
{
T tmp;
resharing->finalize_other(i, tmp, os[i]);
res += tmp;
}
return res; return res;
} }
@@ -259,7 +243,7 @@ vector<T> Shamir<T>::get_randoms(PRNG& G, int t)
input.reset_all(P); input.reset_all(P);
int buffer_size = OnlineOptions::singleton.batch_size; int buffer_size = OnlineOptions::singleton.batch_size;
for (int i = 0; i < buffer_size; i += hyper.size()) for (int i = 0; i < buffer_size; i += hyper.size())
input.add_mine(G.get<U>()); input.add_from_all(G.get<U>());
input.exchange(); input.exchange();
vector<U> inputs; vector<U> inputs;
vector<T> random; vector<T> random;

View File

@@ -21,10 +21,11 @@ class IndividualInput : public PrepLessInput<T>
protected: protected:
Player& P; Player& P;
octetStreams os; octetStreams os;
vector<bool> senders;
public: public:
IndividualInput(SubProcessor<T>* proc, Player& P) : IndividualInput(SubProcessor<T>* proc, Player& P) :
PrepLessInput<T>(proc), P(P) PrepLessInput<T>(proc), P(P), senders(P.num_players())
{ {
this->reset_all(P); this->reset_all(P);
} }
@@ -34,10 +35,14 @@ public:
} }
void reset(int player); void reset(int player);
void add_sender(int player);
void add_other(int player, int n_bits = -1); void add_other(int player, int n_bits = -1);
void send_mine(); void send_mine();
void exchange(); void exchange();
void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);
void start_exchange();
void stop_exchange();
}; };
/** /**

View File

@@ -20,6 +20,8 @@ void IndividualInput<U>::reset(int player)
this->i_share = 0; this->i_share = 0;
os.reset(P); os.reset(P);
} }
senders[player] = false;
} }
template<class T> template<class T>
@@ -68,12 +70,20 @@ void ShamirInput<T>::add_mine(const typename T::open_type& input, int n_bits)
else else
x.pack(this->os[i]); x.pack(this->os[i]);
} }
this->senders[P.my_num()] = true;
}
template<class U>
void IndividualInput<U>::add_sender(int player)
{
senders[player] = true;
} }
template<class U> template<class U>
void IndividualInput<U>::add_other(int player, int) void IndividualInput<U>::add_other(int player, int)
{ {
(void) player; add_sender(player);
} }
template<class U> template<class U>
@@ -87,7 +97,26 @@ void IndividualInput<U>::send_mine()
template<class T> template<class T>
void IndividualInput<T>::exchange() void IndividualInput<T>::exchange()
{ {
P.send_receive_all(os, InputBase<T>::os); P.send_receive_all(senders, os, InputBase<T>::os);
}
template<class T>
void IndividualInput<T>::start_exchange()
{
if (senders[P.my_num()])
for (int offset = 1; offset < P.num_players(); offset++)
P.send_relative(offset, os[P.get_player(offset)]);
}
template<class T>
void IndividualInput<T>::stop_exchange()
{
for (int offset = 1; offset < P.num_players(); offset++)
{
int receive_from = P.get_player(-offset);
if (senders[receive_from])
P.receive_player(receive_from, InputBase<T>::os[receive_from]);
}
} }
template<class T> template<class T>

Some files were not shown because too many files have changed in this diff Show More