Mirror of https://github.com/data61/MP-SPDZ.git, synced 2026-01-08 05:03:59 -05:00

Commit: Semi-honest computation based on threshold semi-homomorphic encryption.
CHANGELOG.md
@@ -1,6 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.2.9 (Jan 11, 2021)
## 0.3.0 (Feb 17, 2022)

- Semi-honest computation based on threshold semi-homomorphic encryption
- Batch normalization backward propagation
- AlexNet for CIFAR-10
- Specific private output protocols
- Semi-honest additive secret sharing without communication
- Sending of personal values
- Allow overwriting of persistence files
- Protocol signature in persistence files

## 0.2.9 (Jan 11, 2022)

- Disassembler
- Run-time parameter for probabilistic truncation error
@@ -497,7 +497,7 @@ class movsb(NonVectorInstruction):
    code = opcodes['MOVSB']
    arg_format = ['sbw','sb']

class trans(base.VarArgsInstruction):
class trans(base.VarArgsInstruction, base.DynFormatInstruction):
    """ Secret bit register vector transpose. The first destination vector
    will contain the least significant bits of all source vectors etc.

@@ -511,10 +511,22 @@ class trans(base.VarArgsInstruction):
    code = opcodes['TRANS']
    is_vec = lambda self: True
    def __init__(self, *args):
        self.arg_format = ['int'] + ['sbw'] * args[0] + \
                          ['sb'] * (len(args) - 1 - args[0])
        super(trans, self).__init__(*args)

    @classmethod
    def dynamic_arg_format(cls, args):
        yield 'int'
        n = next(args)
        for i in range(n):
            yield 'sbw'
            next(args)
        while True:
            try:
                yield 'sb'
                next(args)
            except StopIteration:
                break
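For illustration, the transpose semantics described in the docstring above amount to the following plain-Python sketch (not part of this commit; bit_transpose is a made-up helper name): destination j packs bit j of every source value.

def bit_transpose(sources, n_bits):
    # destination j packs bit j of sources[0], sources[1], ...
    dests = []
    for j in range(n_bits):
        d = 0
        for k, x in enumerate(sources):
            d |= ((x >> j) & 1) << k
        dests.append(d)
    return dests

assert bit_transpose([0b01, 0b10], 2) == [0b01, 0b10]
assert bit_transpose([0b11, 0b00, 0b01], 2) == [0b101, 0b001]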
class bitb(NonVectorInstruction):
    """ Copy fresh secret random bit to secret bit register.
@@ -560,7 +572,7 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
            req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1])

class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
                base.Mergeable):
                base.Mergeable, base.DynFormatInstruction):
    """ Copy private input to secret bit registers bit by bit. The input is
    read as floating-point number, multiplied by a power of two, rounded to an
    integer, and then decomposed into bits.

@@ -577,11 +589,18 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
    code = opcodes['INPUTBVEC']

    def __init__(self, *args, **kwargs):
        self.arg_format = []
        for x in self.get_arg_tuples(args):
            self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3)
        super(inputbvec, self).__init__(*args, **kwargs)

    @classmethod
    def dynamic_arg_format(cls, args):
        yield 'int'
        for i, n in cls.bases(args):
            yield 'int'
            yield 'p'
            for j in range(n - 3):
                yield 'sbw'
            yield 'int'

    @staticmethod
    def get_arg_tuples(args):
        i = 0
@@ -590,10 +609,6 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
            i += args[i]
        assert i == len(args)

    def merge(self, other):
        self.args += other.args
        self.arg_format += other.arg_format

    def add_usage(self, req_node):
        for x in self.get_arg_tuples(self.args):
            req_node.increment(('bit', 'input', x[2]), x[0] - 3)
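The new base.DynFormatInstruction hook used by trans and inputbvec lets an instruction derive its argument formats lazily from argument values that have already been decoded (typically a per-group count), which is what makes variable-length instructions parseable by the new disassembler. A minimal standalone sketch of the idea, with made-up format names and a single count-prefixed group layout (not the actual MP-SPDZ classes):

def dynamic_arg_format(args):
    # args: iterator over the instruction's argument values
    while True:
        yield 'int'                  # group size n
        try:
            n = next(args)
        except StopIteration:
            return
        for _ in range(n - 1):       # n - 1 register arguments follow
            yield 'reg'
            next(args)

args = [3, 10, 11, 2, 12]            # two groups of sizes 3 and 2
fmt = dynamic_arg_format(iter(args))
assert [next(fmt) for _ in args] == ['int', 'reg', 'reg', 'int', 'reg']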
@@ -41,7 +41,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return cls.types[length]
|
||||
@classmethod
|
||||
def conv(cls, other):
|
||||
if isinstance(other, cls):
|
||||
if isinstance(other, cls) and cls.n == other.n:
|
||||
return other
|
||||
elif isinstance(other, MemValue):
|
||||
return cls.conv(other.read())
|
||||
@@ -246,14 +246,20 @@ class cbits(bits):
|
||||
assert n == res.n
|
||||
assert n == other.size
|
||||
cls.conv_cint_vec(cint(other, size=other.size), res)
|
||||
@classmethod
|
||||
def conv(cls, other):
|
||||
if isinstance(other, cbits) and cls.n != None and \
|
||||
cls.n // cls.unit == other.n // cls.unit:
|
||||
return other
|
||||
else:
|
||||
return super(cbits, cls).conv(other)
|
||||
types = {}
|
||||
def load_int(self, value):
|
||||
if self.n <= 64:
|
||||
tmp = regint(value)
|
||||
elif value == self.long_one():
|
||||
tmp = cint(1, size=self.n)
|
||||
else:
|
||||
raise CompilerError('loading long integers to cbits not supported')
|
||||
n_limbs = math.ceil(self.n / self.unit)
|
||||
tmp = regint(size=n_limbs)
|
||||
for i in range(n_limbs):
|
||||
tmp[i].load_int(value % 2 ** self.unit)
|
||||
value >>= self.unit
|
||||
self.load_other(tmp)
|
||||
def store_in_dynamic_mem(self, address):
|
||||
inst.stmsdci(self, cbits.conv(address))
|
||||
@@ -1163,14 +1169,14 @@ class cbitfix(object):
|
||||
@classmethod
|
||||
def _new(cls, value):
|
||||
res = cls()
|
||||
if cls.k < value.unit:
|
||||
bits = value.bit_decompose(cls.k)
|
||||
sign = bits[-1]
|
||||
value += (sign << (cls.k)) * -1
|
||||
res.v = value
|
||||
return res
|
||||
def output(self):
|
||||
v = self.v
|
||||
if self.k < v.unit:
|
||||
bits = self.v.bit_decompose(self.k)
|
||||
sign = bits[-1]
|
||||
v += (sign << (self.k)) * -1
|
||||
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
|
||||
cbits(0), cbits(0))
|
||||
|
||||
|
||||
@@ -403,6 +403,20 @@ class Merger:
|
||||
add_edge(last_input[t][1], n)
|
||||
last_input[t][0] = n
|
||||
|
||||
def keep_text_order(inst, n):
|
||||
if inst.get_players() is None:
|
||||
# switch
|
||||
for x in list(last_input.keys()):
|
||||
if isinstance(x, int):
|
||||
add_edge(last_input[x][0], n)
|
||||
del last_input[x]
|
||||
keep_merged_order(instr, n, None)
|
||||
elif last_input[None][0] is not None:
|
||||
keep_merged_order(instr, n, None)
|
||||
else:
|
||||
for player in inst.get_players():
|
||||
keep_merged_order(instr, n, player)
|
||||
|
||||
for n,instr in enumerate(block.instructions):
|
||||
outputs,inputs = instr.get_def(), instr.get_used()
|
||||
|
||||
@@ -427,7 +441,7 @@ class Merger:
|
||||
|
||||
# will be merged
|
||||
if isinstance(instr, TextInputInstruction):
|
||||
keep_merged_order(instr, n, TextInputInstruction)
|
||||
keep_text_order(instr, n)
|
||||
elif isinstance(instr, RawInputInstruction):
|
||||
keep_merged_order(instr, n, RawInputInstruction)
|
||||
|
||||
@@ -479,10 +493,6 @@ class Merger:
|
||||
last_print_str = n
|
||||
elif isinstance(instr, PublicFileIOInstruction):
|
||||
keep_order(instr, n, instr.__class__)
|
||||
elif isinstance(instr, startprivateoutput_class):
|
||||
keep_order(instr, n, startprivateoutput_class, 2)
|
||||
elif isinstance(instr, stopprivateoutput_class):
|
||||
keep_order(instr, n, stopprivateoutput_class, 2)
|
||||
elif isinstance(instr, prep_class):
|
||||
keep_order(instr, n, instr.args[0])
|
||||
elif isinstance(instr, StackInstruction):
|
||||
|
||||
@@ -421,6 +421,10 @@ class use_matmul(base.Instruction):
|
||||
code = base.opcodes['USE_MATMUL']
|
||||
arg_format = ['int','int','int','int']
|
||||
|
||||
@classmethod
|
||||
def get_usage(cls, args):
|
||||
return {('matmul', tuple(arg.i for arg in args[:3])): args[3].i}
|
||||
|
||||
class run_tape(base.Instruction):
|
||||
""" Start tape/bytecode file in another thread.
|
||||
|
||||
@@ -1229,15 +1233,20 @@ class inverse(base.DataInstruction):
@base.gf2n
@base.vectorize
class inputmask(base.Instruction):
    r""" Load secret $s_i$ with the next input mask for player $p$ and
    write the mask on player $p$'s private output. """
    """ Store fresh random input mask(s) in secret register (vector) and clear
    register (vector) of the relevant player.

    :param: mask (sint)
    :param: mask (cint, player only)
    :param: player (int)
    """
    __slots__ = []
    code = base.opcodes['INPUTMASK']
    arg_format = ['sw', 'p']
    arg_format = ['sw', 'cw', 'p']
    field_type = 'modp'

    def add_usage(self, req_node):
        req_node.increment((self.field_type, 'input', self.args[1]), \
        req_node.increment((self.field_type, 'input', self.args[2]), \
            self.get_size())

@base.vectorize
@@ -1293,10 +1302,8 @@ class asm_input(base.TextInputInstruction):
|
||||
arg_format = tools.cycle(['sw', 'p'])
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for player in self.args[1::2]:
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
def get_players(self):
|
||||
return self.args[1::2]
|
||||
|
||||
@base.vectorize
|
||||
class inputfix(base.TextInputInstruction):
|
||||
@@ -1305,10 +1312,8 @@ class inputfix(base.TextInputInstruction):
|
||||
arg_format = tools.cycle(['sw', 'int', 'p'])
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for player in self.args[2::3]:
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
def get_players(self):
|
||||
return self.args[2::3]
|
||||
|
||||
@base.vectorize
|
||||
class inputfloat(base.TextInputInstruction):
|
||||
@@ -1322,7 +1327,7 @@ class inputfloat(base.TextInputInstruction):
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
4 * self.get_size())
|
||||
|
||||
class inputmixed_base(base.TextInputInstruction):
|
||||
class inputmixed_base(base.TextInputInstruction, base.DynFormatInstruction):
|
||||
__slots__ = []
|
||||
field_type = 'modp'
|
||||
# the following has to match TYPE: (N_DEST, N_PARAM)
|
||||
@@ -1341,22 +1346,30 @@ class inputmixed_base(base.TextInputInstruction):
|
||||
type_id = self.type_ids[name]
|
||||
super(inputmixed_base, self).__init__(type_id, *args)
|
||||
|
||||
@property
|
||||
def arg_format(self):
|
||||
for i in self.bases():
|
||||
t = self.args[i]
|
||||
yield 'int'
|
||||
@classmethod
|
||||
def dynamic_arg_format(self, args):
|
||||
yield 'int'
|
||||
for i, t in self.bases(iter(args)):
|
||||
for j in range(self.types[t][0]):
|
||||
yield 'sw'
|
||||
for j in range(self.types[t][1]):
|
||||
yield 'int'
|
||||
yield self.player_arg_type
|
||||
yield 'int'
|
||||
|
||||
def bases(self):
|
||||
@classmethod
|
||||
def bases(self, args):
|
||||
i = 0
|
||||
while i < len(self.args):
|
||||
yield i
|
||||
i += sum(self.types[self.args[i]]) + 2
|
||||
while True:
|
||||
try:
|
||||
t = next(args)
|
||||
except StopIteration:
|
||||
return
|
||||
yield i, t
|
||||
n = sum(self.types[t])
|
||||
i += n + 2
|
||||
for j in range(n + 1):
|
||||
next(args)
|
||||
|
||||
@base.vectorize
|
||||
class inputmixed(inputmixed_base):
|
||||
@@ -1380,13 +1393,16 @@ class inputmixed(inputmixed_base):
|
||||
player_arg_type = 'p'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for i in self.bases():
|
||||
t = self.args[i]
|
||||
for i, t in self.bases(iter(self.args)):
|
||||
player = self.args[i + sum(self.types[t]) + 1]
|
||||
n_dest = self.types[t][0]
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
n_dest * self.get_size())
|
||||
|
||||
def get_players(self):
|
||||
for i, t in self.bases(iter(self.args)):
|
||||
yield self.args[i + sum(self.types[t]) + 1]
|
||||
|
||||
@base.vectorize
|
||||
class inputmixedreg(inputmixed_base):
|
||||
""" Store private input in secret registers (vectors). The input is
|
||||
@@ -1412,6 +1428,9 @@ class inputmixedreg(inputmixed_base):
|
||||
# player 0 as proxy
|
||||
req_node.increment((self.field_type, 'input', 0), float('inf'))
|
||||
|
||||
def get_players(self):
|
||||
pass
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class rawinput(base.RawInputInstruction, base.Mergeable):
|
||||
@@ -1433,7 +1452,23 @@ class rawinput(base.RawInputInstruction, base.Mergeable):
            req_node.increment((self.field_type, 'input', player), \
                               self.get_size())

class inputpersonal(base.Instruction, base.Mergeable):
class personal_base(base.Instruction, base.Mergeable):
    __slots__ = []
    field_type = 'modp'

    def __init__(self, *args):
        super(personal_base, self).__init__(*args)
        for i in range(0, len(args), 4):
            assert args[i + 2].size == args[i]
            assert args[i + 3].size == args[i]

    def add_usage(self, req_node):
        for i in range(0, len(self.args), 4):
            player = self.args[i + 1]
            req_node.increment((self.field_type, 'input', player), \
                               self.args[i])

class inputpersonal(personal_base):
    """ Private input from cint.

    :param: vector size (int)

@@ -1445,19 +1480,39 @@ class inputpersonal(base.Instruction, base.Mergeable):
    __slots__ = []
    code = base.opcodes['INPUTPERSONAL']
    arg_format = tools.cycle(['int','p','sw','c'])
    field_type = 'modp'

class privateoutput(personal_base):
    """ Private input from cint.

    :param: vector size (int)
    :param: player (int)
    :param: destination (cint)
    :param: source (sint)
    :param: (repeat from vector size)...
    """
    __slots__ = []
    code = base.opcodes['PRIVATEOUTPUT']
    arg_format = tools.cycle(['int','p','cw','s'])

class sendpersonal(base.Instruction, base.Mergeable):
    """ Private input from cint.

    :param: vector size (int)
    :param: destination player (int)
    :param: destination (cint)
    :param: source player (int)
    :param: source (cint)
    :param: (repeat from vector size)...
    """
    __slots__ = []
    code = base.opcodes['SENDPERSONAL']
    arg_format = tools.cycle(['int','p','cw','p','c'])

    def __init__(self, *args):
        super(inputpersonal, self).__init__(*args)
        for i in range(0, len(args), 4):
        super(sendpersonal, self).__init__(*args)
        for i in range(0, len(args), 5):
            assert args[i + 2].size == args[i]
            assert args[i + 3].size == args[i]

    def add_usage(self, req_node):
        for i in range(0, len(self.args), 4):
            player = self.args[i + 1]
            req_node.increment((self.field_type, 'input', player), \
                               self.args[i])
            assert args[i + 4].size == args[i]

@base.gf2n
@base.vectorize
@@ -1789,27 +1844,6 @@ class floatoutput(base.PublicFileIOInstruction):
|
||||
code = base.opcodes['FLOATOUTPUT']
|
||||
arg_format = ['p','c','c','c','c']
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class startprivateoutput(base.Instruction):
|
||||
r""" Initiate private output to $n$ of $s_j$ via $s_i$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['STARTPRIVATEOUTPUT']
|
||||
arg_format = ['sw','s','p']
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, 'input', self.args[2]), \
|
||||
self.get_size())
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class stopprivateoutput(base.Instruction):
|
||||
r""" Previously iniated private output to $n$ via $c_i$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['STOPPRIVATEOUTPUT']
|
||||
arg_format = ['cw','c','p']
|
||||
|
||||
@base.vectorize
|
||||
class rand(base.Instruction):
|
||||
""" Store insecure random value of specified length in clear integer
|
||||
@@ -2210,7 +2244,8 @@ class mulrs(base.VarArgsInstruction, base.DataInstruction):
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
class dotprods(base.VarArgsInstruction, base.DataInstruction,
|
||||
base.DynFormatInstruction):
|
||||
""" Dot product of secret registers (vectors).
|
||||
Note that the vectorized version works element-wise.
|
||||
|
||||
@@ -2238,31 +2273,29 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
flat_args += [x, y]
|
||||
base.Instruction.__init__(self, *flat_args)
|
||||
|
||||
@property
|
||||
def arg_format(self):
|
||||
@classmethod
|
||||
def dynamic_arg_format(self, args):
|
||||
field = 'g' if self.is_gf2n() else ''
|
||||
for i in self.bases():
|
||||
yield 'int'
|
||||
yield 'int'
|
||||
for i, n in self.bases(args):
|
||||
yield 's' + field + 'w'
|
||||
for j in range(self.args[i] - 2):
|
||||
for j in range(n - 2):
|
||||
yield 's' + field
|
||||
yield 'int'
|
||||
|
||||
gf2n_arg_format = arg_format
|
||||
|
||||
def bases(self):
|
||||
i = 0
|
||||
while i < len(self.args):
|
||||
yield i
|
||||
i += self.args[i]
|
||||
@property
|
||||
def gf2n_arg_format(self):
|
||||
return self.arg_format()
|
||||
|
||||
def get_repeat(self):
|
||||
return sum(self.args[i] // 2 for i in self.bases()) * self.get_size()
|
||||
return sum(self.args[i] // 2
|
||||
for i, n in self.bases(iter(self.args))) * self.get_size()
|
||||
|
||||
def get_def(self):
|
||||
return [self.args[i + 1] for i in self.bases()]
|
||||
return [self.args[i + 1] for i, n in self.bases(iter(self.args))]
|
||||
|
||||
def get_used(self):
|
||||
for i in self.bases():
|
||||
for i, n in self.bases(iter(self.args)):
|
||||
for reg in self.args[i + 2:i + self.args[i]]:
|
||||
yield reg
|
||||
|
||||
|
||||
@@ -105,6 +105,7 @@ opcodes = dict(
|
||||
MATMULSM = 0xAB,
|
||||
CONV2DS = 0xAC,
|
||||
CHECK = 0xAF,
|
||||
PRIVATEOUTPUT = 0xAD,
|
||||
# Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
@@ -128,6 +129,7 @@ opcodes = dict(
|
||||
INPUTMIXEDREG = 0xF3,
|
||||
RAWINPUT = 0xF4,
|
||||
INPUTPERSONAL = 0xF5,
|
||||
SENDPERSONAL = 0xF6,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
@@ -364,6 +366,7 @@ def gf2n(instruction):
|
||||
arg_format = copy.deepcopy(instruction_cls.arg_format)
|
||||
reformat(arg_format)
|
||||
|
||||
@classmethod
|
||||
def is_gf2n(self):
|
||||
return True
|
||||
|
||||
@@ -505,8 +508,12 @@ def cisc(function):
|
||||
for arg in self.args:
|
||||
try:
|
||||
new_regs.append(type(arg)(size=size))
|
||||
except:
|
||||
except TypeError:
|
||||
break
|
||||
except:
|
||||
print([call[0][0].size for call in self.calls])
|
||||
raise
|
||||
assert len(new_regs) > 1
|
||||
base = 0
|
||||
for call in self.calls:
|
||||
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
|
||||
@@ -854,6 +861,7 @@ class Instruction(object):
|
||||
def is_vec(self):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def is_gf2n(self):
|
||||
return False
|
||||
|
||||
@@ -902,6 +910,10 @@ class Instruction(object):
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
@staticmethod
|
||||
def get_usage(args):
|
||||
return {}
|
||||
|
||||
# String version of instruction attempting to replicate encoded version
|
||||
def __str__(self):
|
||||
|
||||
@@ -949,9 +961,18 @@ class ParsedInstruction:
|
||||
if name == 'cisc':
|
||||
arg_format = itertools.chain(['str'], itertools.repeat('int'))
|
||||
else:
|
||||
arg_format = itertools.repeat('int')
|
||||
self.args = [ArgFormats[next(arg_format)](f)
|
||||
for i in range(n_args)]
|
||||
def arg_iter():
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
yield self.args[i].i
|
||||
except AttributeError:
|
||||
yield None
|
||||
i += 1
|
||||
arg_format = t.dynamic_arg_format(arg_iter())
|
||||
self.args = []
|
||||
for i in range(n_args):
|
||||
self.args.append(ArgFormats[next(arg_format)](f))
|
||||
|
||||
def __str__(self):
|
||||
name = self.type.__name__
|
||||
@@ -963,6 +984,9 @@ class ParsedInstruction:
|
||||
res += ', '.join(str(arg) for arg in self.args)
|
||||
return res
|
||||
|
||||
def get_usage(self):
|
||||
return self.type.get_usage(self.args)
|
||||
|
||||
class VarArgsInstruction(Instruction):
|
||||
def has_var_args(self):
|
||||
return True
|
||||
@@ -974,6 +998,26 @@ class VectorInstruction(Instruction):
|
||||
def get_code(self):
|
||||
return super(VectorInstruction, self).get_code(len(self.args[0]))
|
||||
|
||||
class DynFormatInstruction(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
@property
|
||||
def arg_format(self):
|
||||
return self.dynamic_arg_format(iter(self.args))
|
||||
|
||||
@classmethod
|
||||
def bases(self, args):
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
n = next(args)
|
||||
except StopIteration:
|
||||
return
|
||||
yield i, n
|
||||
i += n
|
||||
for j in range(n - 1):
|
||||
next(args)
|
||||
|
||||
###
|
||||
### Basic arithmetic
|
||||
###
|
||||
@@ -1072,6 +1116,11 @@ class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
|
||||
""" Input from text file or stdin """
|
||||
__slots__ = []
|
||||
|
||||
def add_usage(self, req_node):
|
||||
for player in self.get_players():
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
|
||||
###
|
||||
### Data access instructions
|
||||
###
|
||||
|
||||
@@ -223,7 +223,7 @@ def crash(condition=None):
|
||||
if isinstance(condition, localint):
|
||||
# allow crash on local values
|
||||
condition = condition._v
|
||||
if condition == None:
|
||||
if condition is None:
|
||||
condition = regint(1)
|
||||
instructions.crash(regint.conv(condition))
|
||||
|
||||
@@ -284,8 +284,8 @@ def get_arg():
|
||||
|
||||
def make_array(l):
|
||||
if isinstance(l, program.Tape.Register):
|
||||
res = Array(1, type(l))
|
||||
res[0] = l
|
||||
res = Array(len(l), type(l))
|
||||
res[:] = l
|
||||
else:
|
||||
l = list(l)
|
||||
res = Array(len(l), type(l[0]) if l else cint)
|
||||
@@ -1032,6 +1032,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
state = tuplify(initializer())
|
||||
k = 0
|
||||
block = get_block()
|
||||
assert not isinstance(n_loops, int) or n_loops > 0
|
||||
pre = copy.copy(loop_body.__globals__)
|
||||
while (not util.is_constant(n_loops) or k < n_loops) \
|
||||
and (len(get_block()) < budget or k == 0) \
|
||||
@@ -1211,7 +1212,13 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
if t != regint:
|
||||
raise CompilerError('Not implemented for other than regint')
|
||||
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
|
||||
state = tuple(initializer())
|
||||
state = initializer()
|
||||
if len(state) == 0:
|
||||
state_type = cint
|
||||
elif isinstance(state, (tuple, list)):
|
||||
state_type = type(state[0])
|
||||
else:
|
||||
state_type = type(state)
|
||||
def f(inc):
|
||||
base = args[get_arg()][0]
|
||||
if not util.is_constant(thread_rounds):
|
||||
@@ -1224,8 +1231,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
if thread_mem_req:
|
||||
thread_mem = Array(thread_mem_req[regint], regint, \
|
||||
args[get_arg()].address + 2)
|
||||
mem_state = Array(len(state), type(state[0]) \
|
||||
if state else cint, args[get_arg()][1])
|
||||
mem_state = Array(len(state), state_type, args[get_arg()][1])
|
||||
@map_reduce_single(n_parallel, thread_rounds + inc, \
|
||||
initializer, reducer, mem_state)
|
||||
def f(i):
|
||||
@@ -1257,14 +1263,14 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
threads = prog.run_tapes(thread_args)
|
||||
for thread in threads:
|
||||
prog.join_tape(thread)
|
||||
if state:
|
||||
if len(state):
|
||||
if thread_rounds:
|
||||
for i in range(n_threads - remainder):
|
||||
state = reducer(Array(len(state), type(state[0]), \
|
||||
state = reducer(Array(len(state), state_type, \
|
||||
args[remainder + i][1]), state)
|
||||
if remainder:
|
||||
for i in range(remainder):
|
||||
state = reducer(Array(len(state), type(state[0]).reg_type, \
|
||||
state = reducer(Array(len(state), state_type, \
|
||||
args[i][1]), state)
|
||||
def returner():
|
||||
return untuplify(state)
|
||||
@@ -1300,6 +1306,39 @@ def map_sum_opt(n_threads, n_loops, types):
|
||||
"""
|
||||
return map_sum(n_threads, None, n_loops, len(types), types)
|
||||
|
||||
def map_sum_simple(n_threads, n_loops, type, size):
|
||||
""" Vectorized multi-threaded sum reduction. The following computes a
|
||||
100 sums of ten squares in three threads::
|
||||
|
||||
@map_sum_simple(3, 10, sint, 100)
|
||||
def summer(i):
|
||||
return sint(regint.inc(100, i, 0)) ** 2
|
||||
|
||||
result = summer()
|
||||
|
||||
:param n_threads: number of threads (int)
|
||||
:param n_loops: number of loop runs (regint/cint/int)
|
||||
:param type: return type, must match the return statement
|
||||
in the loop
|
||||
:param size: vector size, must match the return statement
|
||||
in the loop
|
||||
|
||||
"""
|
||||
initializer = lambda: type(0, size=size)
|
||||
def summer(*args):
|
||||
assert len(args) == 2
|
||||
args = list(args)
|
||||
for i in (0, 1):
|
||||
if isinstance(args[i], tuple):
|
||||
assert len(args[i]) == 1
|
||||
args[i] = args[i][0]
|
||||
for i in (0, 1):
|
||||
assert len(args[i]) == size
|
||||
if isinstance(args[i], Array):
|
||||
args[i] = args[i][:]
|
||||
return args[0] + args[1]
|
||||
return map_reduce(n_threads, 1, n_loops, initializer, summer)
|
||||
|
||||
def tree_reduce_multithread(n_threads, function, vector):
|
||||
inputs = vector.Array(len(vector))
|
||||
inputs.assign_vector(vector)
|
||||
|
||||
Compiler/ml.py
@@ -223,6 +223,7 @@ class Layer:
|
||||
thetas = lambda self: ()
|
||||
debug_output = False
|
||||
back_batch_size = 128
|
||||
print_random_update = False
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
@@ -254,6 +255,9 @@ class Layer:
|
||||
def __str__(self):
|
||||
return type(self).__name__ + str(self._Y.sizes)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s)' % (type(self).__name__, self.Y.sizes)
|
||||
|
||||
class NoVariableLayer(Layer):
|
||||
input_from = lambda *args, **kwargs: None
|
||||
output_weights = lambda *args: None
|
||||
@@ -459,6 +463,10 @@ class MultiOutput(MultiOutputBase):
|
||||
self.debug = debug
|
||||
self.true_X = sfix.Array(N)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, %s, approx=%s)' % \
|
||||
(type(self).__name__, self.N, self.d_out, self.approx)
|
||||
|
||||
def _forward(self, batch):
|
||||
N = len(batch)
|
||||
d_out = self.X.sizes[1]
|
||||
@@ -609,10 +617,11 @@ class DenseBase(Layer):
|
||||
N = len(batch)
|
||||
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
|
||||
|
||||
A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address)
|
||||
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
|
||||
@multithread(self.n_threads, self.d_in)
|
||||
def _(base, size):
|
||||
A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address)
|
||||
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
mp = B.direct_trans_mul(A, reduce=False,
|
||||
indices=(regint.inc(size, base),
|
||||
batch.get_vector(),
|
||||
@@ -622,16 +631,24 @@ class DenseBase(Layer):
|
||||
|
||||
progress('nabla W (matmul)')
|
||||
|
||||
if self.d_in * self.d_out < 200000:
|
||||
print('reduce at once')
|
||||
@multithread(self.n_threads, self.d_in * self.d_out)
|
||||
def _(base, size):
|
||||
self.nabla_W.assign_vector(
|
||||
tmp.get_vector(base, size).reduce_after_mul(), base=base)
|
||||
else:
|
||||
@for_range_opt(self.d_in)
|
||||
def _(i):
|
||||
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()
|
||||
@multithread(self.n_threads, self.d_in * self.d_out,
|
||||
max_size=get_program().budget)
|
||||
def _(base, size):
|
||||
self.nabla_W.assign_vector(
|
||||
tmp.get_vector(base, size).reduce_after_mul(), base=base)
|
||||
|
||||
if self.print_random_update:
|
||||
print_ln('backward %s', self)
|
||||
i = regint.get_random(64) % self.d_in
|
||||
j = regint.get_random(64) % self.d_out
|
||||
print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s',
|
||||
str(self.nabla_W), i, j, tmp[i][j].v.reveal(),
|
||||
self.nabla_W[i][j].reveal(),
|
||||
A.get_column(j).reveal(),
|
||||
B.get_column_by_row_indices(
|
||||
batch.get_vector(), i).reveal())
|
||||
print_ln('batch=%s B=%s', batch,
|
||||
[self.X[bi][0][i].reveal() for bi in batch])
|
||||
|
||||
progress('nabla W')
|
||||
|
||||
@@ -699,6 +716,7 @@ class Dense(DenseBase):
|
||||
self.d_in = d_in
|
||||
self.d_out = d_out
|
||||
self.d = d
|
||||
self.activation = activation
|
||||
|
||||
self.X = MultiArray([N, d, d_in], sfix)
|
||||
self.Y = MultiArray([N, d, d_out], sfix)
|
||||
@@ -721,12 +739,17 @@ class Dense(DenseBase):
|
||||
else:
|
||||
self.f_input = self.Y
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, %s, %s, activation=%s)' % \
|
||||
(type(self).__name__, self.N, self.d_in,
|
||||
self.d_out, repr(self.activation))
|
||||
|
||||
def reset(self):
|
||||
d_in = self.d_in
|
||||
d_out = self.d_out
|
||||
r = math.sqrt(6.0 / (d_in + d_out))
|
||||
print('Initializing dense weights in [%f,%f]' % (-r, r))
|
||||
self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size()))
|
||||
self.W.randomize(-r, r)
|
||||
self.b.assign_all(0)
|
||||
|
||||
def input_from(self, player, raw=False):
|
||||
@@ -820,6 +843,12 @@ class Dense(DenseBase):
|
||||
regint.inc(self.d_in))),
|
||||
base)
|
||||
|
||||
if self.print_random_update:
|
||||
print_ln('backward %s', self)
|
||||
index = regint.get_random(64) % self.nabla_X.total_size()
|
||||
print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
|
||||
index, self.nabla_X.to_array()[index].reveal())
|
||||
|
||||
progress('nabla X')
|
||||
|
||||
self.backward_params(f_schur_Y, batch=batch)
|
||||
@@ -890,6 +919,10 @@ class Dropout(NoVariableLayer):
|
||||
self.alpha = alpha
|
||||
self.B = MultiArray([N, d1, d2], sint)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, %s, alpha=%s)' % \
|
||||
(type(self).__name__, self.N, self.d1, self.alpha)
|
||||
|
||||
def forward(self, batch, training=False):
|
||||
if training:
|
||||
n_bits = -math.log(self.alpha, 2)
|
||||
@@ -1022,6 +1055,7 @@ class MaxPool(NoVariableLayer):
|
||||
def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1),
|
||||
padding='VALID'):
|
||||
assert len(shape) == 4
|
||||
assert min(shape) > 0, shape
|
||||
for x in strides, ksize:
|
||||
for i in 0, 3:
|
||||
assert x[i] == 1
|
||||
@@ -1033,12 +1067,18 @@ class MaxPool(NoVariableLayer):
|
||||
self.Y = Tensor(output_shape, sfix)
|
||||
self.strides = strides
|
||||
self.ksize = ksize
|
||||
self.padding = padding
|
||||
self.nabla_X = Tensor(shape, sfix)
|
||||
self.nabla_Y = Tensor(output_shape, sfix)
|
||||
self.N = shape[0]
|
||||
self.comparisons = MultiArray([self.N, self.X.sizes[3],
|
||||
ksize[1] * ksize[2]], sint)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, strides=%s, ksize=%s, padding=%s)' % \
|
||||
(type(self).__name__, self.X.sizes, self.strides,
|
||||
self.ksize, self.padding)
|
||||
|
||||
def _forward(self, batch):
|
||||
def process(pool, bi, k, i, j):
|
||||
def m(a, b):
|
||||
@@ -1165,7 +1205,7 @@ class Add(NoVariableLayer):
|
||||
self.Y[batch[0]].assign_vector(tmp, base)
|
||||
|
||||
class FusedBatchNorm(Layer):
|
||||
""" Fixed-point fused batch normalization layer.
|
||||
""" Fixed-point fused batch normalization layer (inference only).
|
||||
|
||||
:param shape: input/output shape (tuple/list of four int)
|
||||
"""
|
||||
@@ -1192,6 +1232,153 @@ class FusedBatchNorm(Layer):
|
||||
self.X[batch[0]][i][j].get_vector() * self.weights.get_vector()
|
||||
+ self.bias.get_vector())
|
||||
|
||||
class BatchNorm(Layer):
|
||||
""" Fixed-point batch normalization layer.
|
||||
|
||||
:param shape: input/output shape (tuple/list of four int)
|
||||
:param approx: use approximate square root
|
||||
|
||||
"""
|
||||
thetas = lambda self: (self.weights, self.bias)
|
||||
nablas = lambda self: (self.nabla_weights, self.nabla_bias)
|
||||
|
||||
def __init__(self, shape, approx=True, args=None):
|
||||
assert len(shape) in (2, 3, 4)
|
||||
if len(shape) == 4:
|
||||
shape = [shape[0], shape[1] * shape[2], shape[3]]
|
||||
elif len(shape) == 2:
|
||||
shape = [shape[0], 1, shape[1]]
|
||||
tensors = (Tensor(shape, sfix) for i in range(4))
|
||||
self.X, self.Y, self.nabla_X, self.nabla_Y = tensors
|
||||
arrays = (sfix.Array(shape[2]) for i in range(4))
|
||||
self.var, self.mu, self.weights, self.bias = arrays
|
||||
arrays = (sfix.Array(shape[2]) for i in range(4))
|
||||
self.mu_hat, self.var_hat, self.nabla_weights, self.nabla_bias = arrays
|
||||
self.epsilon = 2 ** (-sfix.f + 1)
|
||||
self.momentum = 0.1
|
||||
if args != None:
|
||||
approx = 'precisebn' not in args
|
||||
self.approx = approx
|
||||
if approx:
|
||||
print('Approximate square root inverse in batch normalization')
|
||||
self.InvertSqrt = mpc_math.InvertSqrt
|
||||
else:
|
||||
print('Precise square root inverse in batch normalization')
|
||||
self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, approx=%s)' % \
|
||||
(type(self).__name__, self.X.sizes, self.approx)
|
||||
|
||||
def reset(self):
|
||||
self.bias.assign_all(0)
|
||||
self.weights.assign_all(1)
|
||||
self.mu_hat.assign_all(0)
|
||||
self.var_hat.assign_all(0)
|
||||
|
||||
def _output(self, batch, mu, var):
|
||||
factor = sfix.Array(len(mu))
|
||||
factor[:] = self.InvertSqrt(var[:] + self.epsilon)
|
||||
@for_range_opt_multithread(self.n_threads,
|
||||
[len(batch), self.X.sizes[1]])
|
||||
def _(i, j):
|
||||
tmp = self.weights[:] * (self.X[i][j][:] - self.mu[:]) * factor[:]
|
||||
self.Y[i][j][:] = self.bias[:] + tmp
|
||||
|
||||
def forward(self, batch, training=False):
|
||||
if training:
|
||||
d = self.X.sizes[1]
|
||||
d_in = self.X.sizes[2]
|
||||
s = sfix.Array(d_in)
|
||||
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
|
||||
def _(i, j):
|
||||
return (self.X[batch[i]][j].get_vector())
|
||||
s.assign(_())
|
||||
@multithread(self.n_threads, d_in)
|
||||
def _(base, size):
|
||||
self.mu.assign_vector(
|
||||
s.get_vector(base, size) / (len(batch) * d), base)
|
||||
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
|
||||
def _(i, j):
|
||||
item = self.X[batch[i]][j].get_vector()
|
||||
return ((item - self.mu[:]) ** 2)
|
||||
self.var.assign(_())
|
||||
@multithread(self.n_threads, d_in)
|
||||
def _(base, size):
|
||||
self.var.assign_vector(
|
||||
self.var.get_vector(base, size) / (len(batch) * d - 1),
|
||||
base)
|
||||
for x, y, in (self.mu_hat, self.mu), (self.var_hat, self.var):
|
||||
x[:] = self.momentum * y[:] + (1 - self.momentum) * x[:]
|
||||
self._output(batch, self.mu, self.var)
|
||||
if self.print_random_update:
|
||||
i = regint.get_random(64) % len(batch)
|
||||
j = regint.get_random(64) % d
|
||||
k = regint.get_random(64) % d_in
|
||||
for x in self.mu, self.var:
|
||||
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
|
||||
print_ln('%s at (%s, %s, %s): in=%s out=%s',
|
||||
str(self.Y), i, j, k, self.X[i][j][k].reveal(),
|
||||
self.Y[i][j][k].reveal())
|
||||
else:
|
||||
self._output(batch, self.mu_hat, self.var_hat)
|
||||
|
||||
def backward(self, batch, compute_nabla_X=True):
|
||||
factor = Array.create_from(
|
||||
self.InvertSqrt(self.var[:] + self.epsilon))
|
||||
mynYf = self.X.same_shape()
|
||||
gamnY = self.X.same_shape()
|
||||
gamnYd = self.X.same_shape()
|
||||
nYdf = self.X.same_shape()
|
||||
d = self.X.sizes[1]
|
||||
d_in = self.X.sizes[2]
|
||||
@for_range_opt_multithread(self.n_threads, [len(batch), d])
|
||||
def _(i, j):
|
||||
tmp = self.weights[:] * self.nabla_Y[i][j][:]
|
||||
gamnY[i][j] = tmp
|
||||
gamnYd[i][j] = tmp * (self.X[i][j][:] - self.mu[:])
|
||||
mynYf[i][j] = tmp * factor[:]
|
||||
nYdf[i][j] = self.nabla_Y[i][j][:] * \
|
||||
(self.X[i][j][:] - self.mu[:]) * factor[:]
|
||||
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
|
||||
def _(i, j):
|
||||
return (self.nabla_Y[i][j][:])
|
||||
self.nabla_bias.assign(_())
|
||||
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
|
||||
def _(i, j):
|
||||
return (nYdf[i][j])
|
||||
self.nabla_weights.assign(_())
|
||||
factor3 = Array.create_from(factor[:] ** 3)
|
||||
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
|
||||
def _(i, j):
|
||||
return (mynYf[i][j])
|
||||
s1 = Array.create_from(_())
|
||||
@multithread(self.n_threads, len(s1))
|
||||
def _(base, size):
|
||||
s1.assign_vector(s1.get_vector(base, size) / (len(batch) * d), base)
|
||||
@map_sum_simple(self.n_threads, [len(batch), d], sfix, d_in)
|
||||
def _(i, j):
|
||||
return (gamnYd[i][j][:] * factor3[:])
|
||||
s2 = Array.create_from(_())
|
||||
@multithread(self.n_threads, len(s2))
|
||||
def _(base, size):
|
||||
s2.assign_vector(
|
||||
s2.get_vector(base, size) / (len(batch) * d - 1), base)
|
||||
@for_range_opt_multithread(self.n_threads, [len(batch), d])
|
||||
def _(i, j):
|
||||
self.nabla_X[i][j][:] = mynYf[i][j][:] \
|
||||
- s1[:] - (self.X[i][j][:] - self.mu[:]) * s2[:]
|
||||
if self.print_random_update:
|
||||
print_ln('backward %s', self)
|
||||
i = regint.get_random(64) % len(batch)
|
||||
j = regint.get_random(64) % d
|
||||
k = regint.get_random(64) % d_in
|
||||
for x in self.nabla_bias, self.nabla_weights:
|
||||
print_ln('%s at %s: %s', str(x), k, x[k].reveal())
|
||||
print_ln('%s at (%s, %s, %s): in=%s out=%s', str(self.Y), i, j, k,
|
||||
self.nabla_Y[i][j][k].reveal(),
|
||||
self.nabla_X[i][j][k].reveal())
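To make the backward pass above easier to follow, here is the same per-channel arithmetic written out for plain Python floats (illustrative only; the layer computes these sums over secret-shared fixed-point vectors and divides by len(batch) * d and len(batch) * d - 1 respectively):

import math

def batchnorm_backward(xs, dys, w, eps=1e-5):
    # xs: one channel's inputs over the batch, dys: gradients w.r.t. outputs
    n = len(xs)
    mu = sum(xs) / n
    var = sum((x - mu) ** 2 for x in xs) / (n - 1)
    f = 1 / math.sqrt(var + eps)                 # InvertSqrt(var + epsilon)

    nabla_bias = sum(dys)
    nabla_weight = sum(dy * (x - mu) * f for x, dy in zip(xs, dys))
    s1 = sum(w * dy * f for dy in dys) / n
    s2 = sum(w * dy * (x - mu) * f ** 3 for x, dy in zip(xs, dys)) / (n - 1)
    nabla_xs = [w * dy * f - s1 - (x - mu) * s2 for x, dy in zip(xs, dys)]
    return nabla_bias, nabla_weight, nabla_xs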
|
||||
|
||||
class QuantBase(object):
|
||||
bias_before_reduction = True
|
||||
|
||||
@@ -1298,6 +1485,8 @@ class ConvBase(BaseLayer):
|
||||
self.padding.append(pad_total // 2)
|
||||
elif padding == 'VALID':
|
||||
self.padding = [0, 0]
|
||||
elif isinstance(padding, int):
|
||||
self.padding = [padding, padding]
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
@@ -1323,6 +1512,12 @@ class ConvBase(BaseLayer):
|
||||
assert(len(output_shape) == 4)
|
||||
assert(len(weight_shape) == 4)
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, %s, %s, %s, %s, padding=%s, tf_weight_format=%s)' % \
|
||||
(type(self).__name__, self.X.sizes, self.weight_shape,
|
||||
self.bias_shape, self.Y.sizes, self.stride, repr(self.padding),
|
||||
self.tf_weight_format)
|
||||
|
||||
def input_from(self, player, raw=False):
|
||||
self.input_params_from(player)
|
||||
self.weights.input_from(player, budget=100000, raw=raw)
|
||||
@@ -1545,20 +1740,20 @@ class FixConv2d(Conv2d, FixBase):
|
||||
self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i)
|
||||
|
||||
if compute_nabla_X:
|
||||
assert tuple(self.padding) == (0, 0)
|
||||
assert tuple(self.stride) == (1, 1)
|
||||
reverse_weights = MultiArray(
|
||||
[n_channels_in, weights_h, weights_w, n_channels_out], sfix)
|
||||
@for_range(n_channels_out)
|
||||
def _(i):
|
||||
@for_range_opt_multithread(self.n_threads, n_channels_in)
|
||||
def _(l):
|
||||
@for_range(weights_h)
|
||||
def _(j):
|
||||
@for_range(weights_w)
|
||||
def _(k):
|
||||
@for_range(n_channels_in)
|
||||
def _(l):
|
||||
reverse_weights[l][weights_h-j-1][k][i] = \
|
||||
self.weights[i][j][weights_w-k-1][l]
|
||||
addresses = regint.inc(n_channels_out,
|
||||
self.weights[0][j][weights_w-k-1].get_address(l),
|
||||
reduce(operator.mul, self.weights.sizes[1:]))
|
||||
reverse_weights[l][weights_h-j-1][k].assign_vector(
|
||||
self.weights.value_type.load_mem(addresses))
|
||||
padded_w = inputs_w + 2 * padding_w
|
||||
padded_h = inputs_h + 2 * padding_h
|
||||
if padding_h or padding_w:
|
||||
@@ -1579,14 +1774,16 @@ class FixConv2d(Conv2d, FixBase):
|
||||
unreduced_sfix._new(res).reduce_after_mul(),
|
||||
i, None, None, j)
|
||||
if padding_h or padding_w:
|
||||
@for_range(N)
|
||||
@for_range_opt_multithread(self.n_threads, N)
|
||||
def _(i):
|
||||
@for_range(inputs_h)
|
||||
def _(j):
|
||||
@for_range(inputs_w)
|
||||
def _(k):
|
||||
jj = j + padding_w
|
||||
kk = k + padding_w
|
||||
self.nabla_X[i][j][k].assign_vector(
|
||||
output[i][j][k].get_vector())
|
||||
output[i][jj][kk].get_vector())
|
||||
|
||||
if self.debug_output:
|
||||
@for_range(len(batch))
|
||||
@@ -1806,6 +2003,7 @@ class Optimizer:
|
||||
self.report_loss = report_loss
|
||||
self.X_by_label = None
|
||||
self.print_update_average = False
|
||||
self.print_random_update = False
|
||||
self.print_losses = False
|
||||
self.print_loss_reduction = False
|
||||
self.i_epoch = MemValue(0)
|
||||
@@ -1846,6 +2044,7 @@ class Optimizer:
|
||||
|
||||
def batch_for(self, layer, batch):
|
||||
if layer in (self.layers[0], self.layers[-1]):
|
||||
assert not isinstance(layer, BatchNorm)
|
||||
return batch
|
||||
else:
|
||||
batch = regint.Array(len(batch))
|
||||
@@ -1876,6 +2075,21 @@ class Optimizer:
|
||||
if i != len(self.layers) - 1 or run_last:
|
||||
layer.forward(batch=self.batch_for(layer, batch),
|
||||
training=training)
|
||||
if self.print_random_update:
|
||||
print_ln('forward layer %s', layer)
|
||||
l = min(100, layer.Y[i].total_size())
|
||||
i = regint.get_random(64) % len(batch)
|
||||
if l < 100:
|
||||
j = 0
|
||||
else:
|
||||
j = regint.get_random(64) % \
|
||||
(layer.Y[i].total_size() - l)
|
||||
print_ln('forward layer %s at (%s, %s): %s', layer, i, j,
|
||||
layer.Y[i].to_array().get_vector(j, l).reveal())
|
||||
i = regint.get_random(64) % layer.Y[0].total_size()
|
||||
print_ln('forward layer %s vertical at %s: %s', layer, i,
|
||||
[layer.Y[j].to_array()[i].reveal()
|
||||
for j in range(len(batch))])
|
||||
if self.time_layers:
|
||||
stop_timer(100 + i)
|
||||
break_point()
|
||||
@@ -1979,7 +2193,11 @@ class Optimizer:
|
||||
label * n)
|
||||
self.forward(batch=batch, training=True)
|
||||
self.backward(batch=batch)
|
||||
if self.time_layers:
|
||||
start_timer(1000)
|
||||
self.update(i, batch=batch)
|
||||
if self.time_layers:
|
||||
stop_timer(1000)
|
||||
loss_sum.iadd(self.layers[-1].l)
|
||||
if self.print_loss_reduction:
|
||||
before = self.layers[-1].average_loss(N)
|
||||
@@ -2070,6 +2288,8 @@ class Optimizer:
|
||||
if 'nomom' in program.args:
|
||||
self.momentum = 0
|
||||
self.print_losses = 'print_losses' in program.args
|
||||
self.print_random_update = 'print_random_update' in program.args
|
||||
Layer.print_random_update = self.print_random_update
|
||||
self.time_layers = 'time_layers' in program.args
|
||||
self.revealing_correctness = not 'no_acc' in program.args
|
||||
self.layers[-1].compute_loss = not 'no_loss' in program.args
|
||||
@@ -2099,6 +2319,16 @@ class Optimizer:
|
||||
print_ln('loss %s', self.layers[-1].l.reveal())
|
||||
self.output_weights()
|
||||
return
|
||||
if 'bench10' in program.args or 'bench1' in program.args:
|
||||
n = 1 if 'bench1' in program.args else 10
|
||||
print('benchmarking %s iterations' % n)
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
batch = Array.create_from(regint.inc(batch_size))
|
||||
self.forward(batch=batch, training=True)
|
||||
self.backward(batch=batch)
|
||||
self.update(0, batch=batch)
|
||||
return
|
||||
@for_range(n_runs)
|
||||
def _(i):
|
||||
if not acc_first:
|
||||
@@ -2115,6 +2345,7 @@ class Optimizer:
|
||||
cfix(self.n_correct, k=63, f=31) / n_trained,
|
||||
self.n_correct, n_trained)
|
||||
if test_X and test_Y:
|
||||
print('use test set')
|
||||
n_test = len(test_Y)
|
||||
n_correct, loss = self.reveal_correctness(test_X, test_Y,
|
||||
acc_batch_size)
|
||||
@@ -2211,7 +2442,8 @@ class Adam(Optimizer):
|
||||
util.max, abs_g.get_vector())
|
||||
scale = MemValue(sfix._new(library.AppRcr(
|
||||
max_g.v, max_g.k, max_g.f, simplex_flag=True)))
|
||||
@multithread(self.n_threads, m.total_size())
|
||||
@multithread(self.n_threads, m.total_size(),
|
||||
max_size=get_program().budget)
|
||||
def _(base, size):
|
||||
m_part = m.get_vector(base, size)
|
||||
v_part = v.get_vector(base, size)
|
||||
@@ -2333,20 +2565,33 @@ class SGD(Optimizer):
|
||||
print_ln_if((x > limit) + (x < -limit),
|
||||
'theta epoch=%s %s index=%s %s',
|
||||
i_epoch.read(), str(theta), i, x)
|
||||
index = regint.get_random(64) % len(a)
|
||||
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index,
|
||||
aa[1][index], aa[0][index], aa[2][index])
|
||||
if self.print_random_update:
|
||||
print_ln('update')
|
||||
l = min(100, nabla.total_size())
|
||||
if l < 100:
|
||||
index = 0
|
||||
else:
|
||||
index = regint.get_random(64) % (nabla.total_size() - l)
|
||||
print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta),
|
||||
index, nabla.to_array().get_vector(index, l).reveal(),
|
||||
delta_theta.to_array().get_vector(index, l).reveal(),
|
||||
theta.to_array().get_vector(index, l).reveal())
|
||||
self.gamma.imul(1 - 10 ** - 6)
|
||||
|
||||
def apply_padding(input_shape, kernel_size, strides, padding):
|
||||
if isinstance(padding, int):
|
||||
input_shape = [x + 2 * padding for x in input_shape]
|
||||
padding = 'valid'
|
||||
if padding == 'valid':
|
||||
return (input_shape[0] - kernel_size[0] + 1) // strides[0], \
|
||||
res = (input_shape[0] - kernel_size[0] + 1) // strides[0], \
|
||||
(input_shape[1] - kernel_size[1] + 1) // strides[1],
|
||||
assert min(res) > 0, (input_shape, kernel_size, strides, padding)
|
||||
return res
|
||||
elif padding == 'same':
|
||||
return (input_shape[1]) // strides[0], \
|
||||
(input_shape[2]) // strides[1],
|
||||
return (input_shape[0]) // strides[0], \
|
||||
(input_shape[1]) // strides[1],
|
||||
else:
|
||||
raise Exception('invalid padding: ' + padding)
|
||||
raise Exception('invalid padding: %s' % padding)
|
||||
|
||||
class keras:
|
||||
class layers:
|
||||
@@ -2354,7 +2599,7 @@ class keras:
|
||||
Dense = lambda *args, **kwargs: ('dense', args, kwargs)
|
||||
|
||||
def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid',
|
||||
activation=None):
|
||||
activation=None, input_shape=None):
|
||||
return 'conv2d', {'filters': filters, 'kernel_size': kernel_size,
|
||||
'strides': strides, 'padding': padding,
|
||||
'activation': activation}
|
||||
@@ -2369,6 +2614,13 @@ class keras:
|
||||
raise Exception('rate needs to be a power of two')
|
||||
return 'dropout', rate
|
||||
|
||||
def Activation(activation):
|
||||
assert(activation == 'relu')
|
||||
return activation,
|
||||
|
||||
def BatchNormalization():
|
||||
return 'batchnorm',
|
||||
|
||||
class optimizers:
|
||||
SGD = lambda *args, **kwargs: ('sgd', args, kwargs)
|
||||
Adam = lambda *args, **kwargs: ('adam', args, kwargs)
|
||||
@@ -2383,12 +2635,25 @@ class keras:
|
||||
def compile(self, optimizer):
|
||||
self.optimizer = optimizer
|
||||
|
||||
def compile_by_args(self, program):
|
||||
if 'adam' in program.args:
|
||||
self.optimizer = 'adam', [], {}
|
||||
elif 'amsgrad' in program.args:
|
||||
self.optimizer = 'adam', [], {'amsgrad': True}
|
||||
else:
|
||||
self.optimizer = 'sgd', [], {}
|
||||
|
||||
@property
|
||||
def trainable_variables(self):
|
||||
if self.opt == None:
|
||||
raise Exception('need to run build() or fit() first')
|
||||
return list(self.opt.thetas)
|
||||
|
||||
def summary(self):
|
||||
sizes = [var.total_size() for var in self.trainable_variables]
|
||||
print(sizes)
|
||||
print('Trainable params:', sum(sizes))
|
||||
|
||||
def build(self, input_shape, batch_size=128):
|
||||
data_input_shape = input_shape
|
||||
if self.opt != None and \
|
||||
@@ -2415,12 +2680,11 @@ class keras:
|
||||
if i == len(self.layers) - 1:
|
||||
if layer[2].get('activation', 'softmax') in \
|
||||
('softmax', 'sigmoid'):
|
||||
del layer[2]['activation']
|
||||
layer[2].pop('activation', None)
|
||||
layers.append(Dense(N, n_units, layer[1][0],
|
||||
**layer[2]))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
elif name == 'conv2d':
|
||||
if len(layers) != 0:
|
||||
input_shape = layers[-1].Y.sizes
|
||||
input_shape = list(input_shape) + \
|
||||
[1] * (4 - len(input_shape))
|
||||
print (layer[1])
|
||||
@@ -2437,9 +2701,13 @@ class keras:
|
||||
output_shape = [batch_size] + list(
|
||||
apply_padding(input_shape[1:3], kernel_size,
|
||||
strides, padding)) + [filters]
|
||||
padding = padding.upper() if isinstance(padding, str) \
|
||||
else padding
|
||||
layers.append(FixConv2d(input_shape, weight_shape,
|
||||
(filters,), output_shape,
|
||||
strides, padding.upper()))
|
||||
strides, padding))
|
||||
input_shape = output_shape
|
||||
print('conv output shape', output_shape)
|
||||
elif name == 'maxpool':
|
||||
pool_size = layer[1]['pool_size']
|
||||
strides = layer[1]['strides']
|
||||
@@ -2450,16 +2718,23 @@ class keras:
|
||||
strides = (strides, strides)
|
||||
if strides == None:
|
||||
strides = pool_size
|
||||
layers.append(MaxPool(layers[-1].Y.sizes,
|
||||
layers.append(MaxPool(input_shape,
|
||||
[1] + list(strides) + [1],
|
||||
[1] + list(pool_size) + [1],
|
||||
padding.upper()))
|
||||
padding))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
elif name == 'dropout':
|
||||
layers.append(Dropout(batch_size, reduce(
|
||||
operator.mul, layers[-1].Y.sizes[1:]),
|
||||
alpha=layer[1]))
|
||||
input_shape = layers[-1].Y.sizes
|
||||
elif name == 'flatten':
|
||||
pass
|
||||
elif name == 'relu':
|
||||
layers.append(Relu(layers[-1].Y.sizes))
|
||||
elif name == 'batchnorm':
|
||||
input_shape = layers[-1].Y.sizes
|
||||
layers.append(BatchNorm(layers[-1].Y.sizes))
|
||||
else:
|
||||
raise Exception(layer[0] + ' not supported')
|
||||
if layers[-1].d_out == 1:
|
||||
|
||||
@@ -1493,6 +1493,7 @@ class PackedIndexStructure(object):
|
||||
self.l[i] = [0] * self.elements_per_block
|
||||
time()
|
||||
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
|
||||
print_ln('packed ORAM init done')
|
||||
print('index initialized, size', size)
|
||||
def translate_index(self, index):
|
||||
""" Bit slicing *index* according parameters. Output is tuple
|
||||
|
||||
@@ -580,10 +580,19 @@ class Program(object):

    @staticmethod
    def read_tapes(schedule):
        m = re.search(r'([^/]*)\.mpc', schedule)
        if m:
            schedule = m.group(1)
        if not os.path.exists(schedule):
            schedule = 'Programs/Schedules/%s.sch' % schedule

        lines = open(schedule).readlines()
        try:
            lines = open(schedule).readlines()
        except FileNotFoundError:
            print('%s not found, have you compiled the program?' % schedule,
                  file=sys.stderr)
            sys.exit(1)

        for tapename in lines[2].split(' '):
            yield tapename.strip()
@@ -1675,6 +1675,13 @@ class localint(Tape._no_truth):
    __ne__ = lambda self, other: localint(self._v != other)

class personal(Tape._no_truth):
    """ Value known to one player. Supports operations with public
    values and personal values known to the same player. Can be used
    with :py:func:`~Compiler.library.print_ln_to`.

    :param player: player (int)
    :param value: cleartext value (cint, cfix, cfloat) or array thereof
    """
    def __init__(self, player, value):
        assert value is not NotImplemented
        assert not isinstance(value, _secret)

@@ -1685,8 +1692,24 @@ class personal(Tape._no_truth):
        self._v = value

    def binary_output(self):
        """ Write binary output to
        ``Player-Data/Binary-Output-P<playerno>-<threadno>`` if
        supported by underlying type. Player must be known at compile time."""
        self._v.binary_output(self.player)

    def reveal_to(self, player):
        """ Pass personal value to another player. """
        if isinstance(self._v, Array):
            source = self._v[:]
        else:
            source = self._v
        source = cint.conv(source)
        res = cint(size=source.size)
        sendpersonal(source.size, player, res, self.player, source)
        if isinstance(self._v, Array):
            res = Array.create_from(res)
        return personal(player, res)

    def bit_decompose(self, length):
        return [personal(self.player, x) for x in self._v.bit_decompose(length)]
@@ -1858,8 +1881,13 @@ class _secret(_register, _secret_structure):
    @vectorized_classmethod
    @set_instruction_type
    def get_random_input_mask_for(cls, player):
        res = cls()
        inputmask(res, player)
        """ Secret random input mask according to security model.

        :return: mask (sint), mask (personal cint)
        :param size: vector size (int, default 1)
        """
        res = cls(), personal(player, cls.clear_type())
        inputmask(res[0], res[1]._v, player)
        return res

    @classmethod

@@ -2071,15 +2099,13 @@ class _secret(_register, _secret_structure):
    @set_instruction_type
    def reveal_to(self, player):
        """ Reveal secret value to :py:obj:`player`.
        Result written to ``Player-Data/Private-Output-P<player>``

        :param player: int
        :returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
        :returns: :py:class:`personal`
        """
        masked = self.__class__()
        res = personal(player, self.clear_type())
        startprivateoutput(masked, self, player)
        stopprivateoutput(res._v, masked.reveal(), player)
        mask = self.get_random_input_mask_for(player)
        masked = self + mask[0]
        res = personal(player, masked.reveal() - mask[1])
        return res
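The new body of reveal_to above realises private output by masking: the parties obtain a fresh input mask (a secret sharing of r together with r itself as a personal value of the receiver), publicly open x + r, and only the receiver can subtract r again. A toy of the underlying arithmetic over Z_p with plain integers (illustrative only; no actual secret sharing, helper name made up):

import random

p = 2 ** 61 - 1                          # toy prime modulus

def masked_private_output(x, receiver):
    r = random.randrange(p)              # mask: shared as sint, known to receiver
    opened = (x + r) % p                 # x + r leaks nothing, safe to open
    return {receiver: (opened - r) % p}  # only the receiver can unmask

secret = 42
assert masked_private_output(secret, receiver=1)[1] == secret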
@@ -2633,21 +2659,20 @@ class sint(_secret, _int):
    @vectorize
    def reveal_to(self, player):
        """ Reveal secret value to :py:obj:`player`.
        Result potentially written to
        ``Player-Data/Private-Output-P<player>``, but not if
        :py:obj:`player` is a :py:class:`regint`.

        :param player: public integer (int/regint/cint):
        :returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
        :param player: public integer (int/regint/cint)
        :returns: :py:class:`personal`
        """
        if not util.is_constant(player) or self.size > 1:
        if not util.is_constant(player):
            secret_mask = sint()
            player_mask = cint()
            inputmaskreg(secret_mask, player_mask, regint.conv(player))
            return personal(player,
                            (self + secret_mask).reveal() - player_mask)
        else:
            return super(sint, self).reveal_to(player)
            res = personal(player, self.clear_type())
            privateoutput(self.size, player, res._v, self)
            return res

    def private_division(self, divisor, active=True, dividend_length=None,
                         divisor_length=None):
@@ -4366,12 +4391,9 @@ class sfix(_fix):

    def reveal_to(self, player):
        """ Reveal secret value to :py:obj:`player`.
        Raw representation possibly written to
        ``Player-Data/Private-Output-P<player>``, but not if
        :py:obj:`player` is a :py:class:`regint`.

        :param player: public integer (int/regint/cint)
        :returns: value to be used with :py:func:`~Compiler.library.print_ln_to`
        :returns: :py:class:`personal`
        """
        return personal(player, cfix._new(self.v.reveal_to(player)._v,
                                          self.k, self.f))
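Taken together, the changes above give reveal_to a uniform return type. A hypothetical high-level snippet showing how the resulting personal values can be printed with print_ln_to and forwarded with personal.reveal_to (a sketch assuming at least three parties, not taken from the repository's examples):

a = sint.get_input_from(0)       # private input of player 0
b = sfix.get_input_from(1)       # private fixed-point input of player 1

x = a.reveal_to(1)               # personal value, known only to player 1
print_ln_to(1, 'a = %s', x)      # printed by player 1 only

y = x.reveal_to(2)               # forward the personal value to player 2
print_ln_to(2, 'a = %s', y)

z = b.reveal_to(0)               # cfix wrapped in a personal for player 0
print_ln_to(0, 'b = %s', z)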
@@ -5221,6 +5243,9 @@ class Array(_vectorizable):
|
||||
return self.assign(value, addresses)
|
||||
self._store(value, self.get_address(index))
|
||||
|
||||
def to_array(self):
|
||||
return self
|
||||
|
||||
def get_sub(self, start, stop=None):
|
||||
if stop is None:
|
||||
stop = start
|
||||
@@ -5471,6 +5496,10 @@ class Array(_vectorizable):
|
||||
""" Insecure shuffle in place. """
|
||||
self.assign_vector(self.get(regint.inc(len(self)).shuffle()))
|
||||
|
||||
def randomize(self, *args):
|
||||
""" Randomize according to data type. """
|
||||
self.assign_vector(self.value_type.get_random(*args, size=len(self)))
|
||||
|
||||
def reveal(self):
|
||||
""" Reveal the whole array.
|
||||
|
||||
@@ -5596,6 +5625,9 @@ class SubMultiArray(_vectorizable):
|
||||
def __iter__(self):
|
||||
return (self[i] for i in range(len(self)))
|
||||
|
||||
def to_array(self):
|
||||
return Array(self.total_size(), self.value_type, address=self.address)
|
||||
|
||||
def assign_all(self, value):
|
||||
""" Assign the same value to all entries.
|
||||
|
||||
@@ -5958,6 +5990,7 @@ class SubMultiArray(_vectorizable):
|
||||
"""
|
||||
assert len(self.sizes) == 2
|
||||
assert len(other.sizes) == 2
|
||||
assert other.address != None
|
||||
if indices is None:
|
||||
assert self.sizes[1] == other.sizes[1]
|
||||
indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]]
|
||||
@@ -6145,6 +6178,16 @@ class SubMultiArray(_vectorizable):
|
||||
n = self.sizes[0]
|
||||
return self.array.get(regint.inc(n, 0, n + 1))
|
||||
|
||||
def randomize(self, *args):
|
||||
""" Randomize according to data type. """
|
||||
if self.total_size() < program.options.budget:
|
||||
self.assign_vector(
|
||||
self.value_type.get_random(*args, size=self.total_size()))
|
||||
else:
|
||||
@library.for_range(self.sizes[0])
|
||||
def _(i):
|
||||
self[i].randomize(*args)
|
||||
|
||||
def reveal_list(self):
|
||||
""" Reveal as list. """
|
||||
return list(self.get_vector().reveal())
|
||||
@@ -6251,6 +6294,22 @@ class Matrix(MultiArray):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
address=address)

def get_column(self, index):
""" Get column as vector.

:param index: regint/cint/int
"""
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.sizes[1])
return self.value_type.load_mem(addresses)
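
Editor's note: a worked instance of the strided address pattern used by get_column (plain Python; the 3x4 shape and base address are assumptions, not from the diff).

    # For a 3x4 matrix starting at base address a, column 2 is read from
    # addresses a+2, a+6, a+10, i.e. regint.inc(3, a + 2, 4).
    a = 1000
    print([a + 2 + i * 4 for i in range(3)])   # [1002, 1006, 1010]
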
def get_column_by_row_indices(self, rows, column):
|
||||
assert self.value_type.n_elements() == 1
|
||||
addresses = rows * self.sizes[1] + \
|
||||
regint.inc(len(rows), self.address + column, 0)
|
||||
return self.value_type.load_mem(addresses)
|
||||
|
||||
def set_column(self, index, vector):
|
||||
""" Change column.
|
||||
|
||||
|
||||
@@ -47,11 +47,18 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G)
|
||||
}
|
||||
|
||||
void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
|
||||
{
|
||||
Rq_Element a(*this);
|
||||
a.randomize(G);
|
||||
partial_key_gen(sk, a, G, noise_boost);
|
||||
}
|
||||
|
||||
void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
|
||||
int noise_boost)
|
||||
{
|
||||
FHE_PK& PK = *this;
|
||||
|
||||
// Generate the main public key
|
||||
PK.a0.randomize(G);
|
||||
a0 = a;
|
||||
|
||||
// b0=a0*s+p*e0
|
||||
Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation);
|
||||
@@ -77,9 +84,6 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost)
|
||||
mul(es,es,PK.pr);
|
||||
add(PK.Sw_b,PK.Sw_b,es);
|
||||
|
||||
// Lowering level as we only decrypt at level 0
|
||||
sk.lower_level();
|
||||
|
||||
// bs=bs-p1*s^2
|
||||
Rq_Element s2;
|
||||
mul(s2,sk,sk); // Mult at level 0
|
||||
@@ -334,7 +338,7 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk,
|
||||
template<class FD>
|
||||
void FHE_SK::check(const FHE_PK& pk, const FD& FieldD)
|
||||
{
|
||||
check(*params, pk, pr);
|
||||
check(*params, pk, FieldD.get_prime());
|
||||
pk.check_noise(*this);
|
||||
if (decrypt(pk.encrypt(Plaintext_<FD>(FieldD)), FieldD) !=
|
||||
Plaintext_<FD>(FieldD))
|
||||
|
||||
@@ -150,6 +150,8 @@ class FHE_PK
|
||||
|
||||
Rq_Element sample_secret_key(PRNG& G);
|
||||
void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1);
|
||||
void partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G,
|
||||
int noise_boost = 1);
|
||||
|
||||
void check_noise(const FHE_SK& sk) const;
|
||||
void check_noise(const Rq_Element& x, bool check_modulo = false) const;
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
#include "FHE/Ring_Element.h"
|
||||
#include "Tools/Exceptions.h"
|
||||
|
||||
FHE_Params::FHE_Params(int n_mults) :
|
||||
FFTData(n_mults + 1), Chi(0.7), sec_p(-1), matrix_dim(1)
|
||||
{
|
||||
}
|
||||
|
||||
void FHE_Params::set(const Ring& R,
|
||||
const vector<bigint>& primes)
|
||||
{
|
||||
@@ -24,6 +29,14 @@ void FHE_Params::set_sec(int sec)
|
||||
throw runtime_error("distributed decryption bound is zero");
|
||||
}
|
||||
|
||||
void FHE_Params::set_matrix_dim(int matrix_dim)
|
||||
{
|
||||
assert(matrix_dim > 0);
|
||||
if (FFTData[0].get_prime() != 0)
|
||||
throw runtime_error("cannot change matrix dimension after parameter generation");
|
||||
this->matrix_dim = matrix_dim;
|
||||
}
|
||||
|
||||
bigint FHE_Params::Q() const
|
||||
{
|
||||
bigint res = FFTData[0].get_prime();
|
||||
@@ -40,6 +53,7 @@ void FHE_Params::pack(octetStream& o) const
|
||||
Chi.pack(o);
|
||||
Bval.pack(o);
|
||||
o.store(sec_p);
|
||||
o.store(matrix_dim);
|
||||
}
|
||||
|
||||
void FHE_Params::unpack(octetStream& o)
|
||||
@@ -52,6 +66,7 @@ void FHE_Params::unpack(octetStream& o)
|
||||
Chi.unpack(o);
|
||||
Bval.unpack(o);
|
||||
o.get(sec_p);
|
||||
o.get(matrix_dim);
|
||||
}
|
||||
|
||||
bool FHE_Params::operator!=(const FHE_Params& other) const
|
||||
|
||||
@@ -26,10 +26,11 @@ class FHE_Params
|
||||
// Data for distributed decryption
|
||||
int sec_p;
|
||||
bigint Bval;
|
||||
int matrix_dim;
|
||||
|
||||
public:
|
||||
|
||||
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {}
|
||||
FHE_Params(int n_mults = 1);
|
||||
|
||||
int n_mults() const { return FFTData.size() - 1; }
|
||||
|
||||
@@ -37,6 +38,9 @@ class FHE_Params
|
||||
void set(const vector<bigint>& primes);
|
||||
void set_sec(int sec);
|
||||
|
||||
void set_matrix_dim(int matrix_dim);
|
||||
int get_matrix_dim() const { return matrix_dim; }
|
||||
|
||||
const vector<FFT_Data>& FFTD() const { return FFTData; }
|
||||
|
||||
const bigint& p0() const { return FFTData[0].get_prime(); }
|
||||
|
||||
@@ -47,7 +47,7 @@ bool same_word_length(int l1, int l2)
|
||||
|
||||
template <>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, FFT_Data& FTD, bool round_up)
|
||||
FHE_Params& params, FFT_Data& FTD, bool round_up, int n)
|
||||
{
|
||||
int m = 1024;
|
||||
int lgp = plaintext_length;
|
||||
@@ -58,7 +58,7 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
while (true)
|
||||
{
|
||||
tmp_params = params;
|
||||
SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec,
|
||||
SemiHomomorphicNoiseBounds nb(p, phi_N(m), n, sec,
|
||||
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params);
|
||||
bigint p1 = 2 * p * m, p0 = p;
|
||||
while (nb.min_p0(params.n_mults() > 0, p1) > p0)
|
||||
@@ -89,14 +89,14 @@ int generate_semi_setup(int plaintext_length, int sec,
|
||||
|
||||
template <>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, P2Data& P2D, bool round_up)
|
||||
FHE_Params& params, P2Data& P2D, bool round_up, int n)
|
||||
{
|
||||
if (params.n_mults() > 0)
|
||||
throw runtime_error("only implemented for 0-level BGV");
|
||||
gf2n_short::init_field(plaintext_length);
|
||||
int m;
|
||||
char_2_dimension(m, plaintext_length);
|
||||
SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec,
|
||||
SemiHomomorphicNoiseBounds nb(2, phi_N(m), n, sec,
|
||||
numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params);
|
||||
int lgp0 = numBits(nb.min_p0(false, 0));
|
||||
int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up);
|
||||
@@ -590,6 +590,9 @@ void char_2_dimension(int& m, int& lg2)
|
||||
m=5797;
|
||||
lg2=40;
|
||||
break;
|
||||
case 16:
|
||||
m = 13107;
|
||||
break;
|
||||
default:
|
||||
throw runtime_error("field size not supported");
|
||||
break;
|
||||
|
||||
@@ -52,7 +52,7 @@ void generate_setup(int nparties, int lgp, int lg2,
|
||||
// semi-homomorphic, includes slack
|
||||
template <class FD>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
FHE_Params& params, FD& FieldD, bool round_up);
|
||||
FHE_Params& params, FD& FieldD, bool round_up, int n = 1);
|
||||
|
||||
// field-independent semi-homomorphic setup
|
||||
int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,
|
||||
|
||||
@@ -39,6 +39,7 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.));
B_clean = max(B_clean_not_top_gear, B_clean_top_gear);
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
int matrix_dim = params.get_matrix_dim();
#ifdef NOISY
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
cout << "V_s: " << V_s << endl;
@@ -48,9 +49,11 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
cout << "log(slack): " << slack << endl;
cout << "B_clean: " << B_clean << endl;
cout << "B_scale: " << B_scale << endl;
cout << "matrix dimension: " << matrix_dim << endl;
#endif

drown = 1 + n * (bigint(1) << sec);
assert(matrix_dim > 0);
drown = 1 + matrix_dim * n * (bigint(1) << sec);
}
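
Editor's note: a worked instance of the updated drowning bound drown = 1 + matrix_dim * n * 2^sec (plain Python arithmetic; the example values of matrix_dim, n and sec are assumptions, not from the diff).

    sec, n = 40, 2
    assert 1 + 1 * n * 2**sec == 1 + n * 2**sec   # matrix_dim = 1: same bound as before
    print(1 + 1000 * n * 2**sec)                  # matrix_dim = 1000: bound grows by the matrix dimension
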
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)
|
||||
|
||||
@@ -50,6 +50,7 @@ void Ring_Element::prepare_push()
|
||||
|
||||
void Ring_Element::allocate()
|
||||
{
|
||||
assert(FFTD);
|
||||
element.resize(FFTD->phi_m());
|
||||
}
|
||||
|
||||
|
||||
@@ -109,6 +109,13 @@ void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
|
||||
}
|
||||
}
|
||||
|
||||
void Rq_Element::add(octetStream& os)
|
||||
{
|
||||
Rq_Element tmp(*this);
|
||||
tmp.unpack(os);
|
||||
*this += tmp;
|
||||
}
|
||||
|
||||
void Rq_Element::randomize(PRNG& G,int l)
|
||||
{
|
||||
set_level(l);
|
||||
@@ -246,7 +253,7 @@ void Rq_Element::Scale(const bigint& p)
|
||||
|
||||
// Now add delta back onto a0
|
||||
Rq_Element bb(b0,b1);
|
||||
add(*this,*this,bb);
|
||||
::add(*this,*this,bb);
|
||||
|
||||
// Now divide by p1 mod p0
|
||||
modp p1_inv,pp;
|
||||
|
||||
@@ -93,12 +93,14 @@ protected:
|
||||
friend void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b);
|
||||
friend void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b);
|
||||
|
||||
void add(octetStream& os);
|
||||
|
||||
template<class S>
|
||||
Rq_Element& operator+=(const vector<S>& other);
|
||||
|
||||
Rq_Element& operator+=(const Rq_Element& other) { add(*this, *this, other); return *this; }
|
||||
Rq_Element& operator+=(const Rq_Element& other) { ::add(*this, *this, other); return *this; }
|
||||
|
||||
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); add(res, *this, b); return res; }
|
||||
Rq_Element operator+(const Rq_Element& b) const { Rq_Element res(*this); ::add(res, *this, b); return res; }
|
||||
Rq_Element operator-(const Rq_Element& b) const { Rq_Element res(*this); sub(res, *this, b); return res; }
|
||||
template <class T>
|
||||
Rq_Element operator*(const T& b) const { Rq_Element res(*this); mul(res, *this, b); return res; }
|
||||
@@ -176,7 +178,7 @@ Rq_Element& Rq_Element::operator+=(const vector<S>& other)
|
||||
{
|
||||
Rq_Element tmp = *this;
|
||||
tmp.from(Iterator<S>(other), lev);
|
||||
add(*this, *this, tmp);
|
||||
::add(*this, *this, tmp);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@@ -203,7 +203,7 @@ template<class FD>
|
||||
void PartSetup<FD>::secure_init(Player& P, MachineBase& machine,
|
||||
int plaintext_length, int sec)
|
||||
{
|
||||
::secure_init(*this, P, machine, plaintext_length, sec);
|
||||
::secure_init(*this, P, machine, plaintext_length, sec, params);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
|
||||
@@ -130,6 +130,13 @@ void Multiplier<FD>::report_size(ReportType type, MemoryUsage& res)
|
||||
res += memory_usage;
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
const vector<Ciphertext>& Multiplier<FD>::get_multiplicands(
|
||||
const vector<vector<Ciphertext> >& others_ct, const FHE_PK&)
|
||||
{
|
||||
return others_ct[P.get_full_player().get_player(-P.get_offset())];
|
||||
}
|
||||
|
||||
|
||||
template class Multiplier<FFT_Data>;
|
||||
template class Multiplier<P2Data>;
|
||||
|
||||
@@ -55,6 +55,9 @@ public:
|
||||
size_t report_size(ReportType type);
|
||||
void report_size(ReportType type, MemoryUsage& res);
|
||||
size_t report_volatile() { return volatile_capacity; }
|
||||
|
||||
const vector<Ciphertext>& get_multiplicands(
|
||||
const vector<vector<Ciphertext>>& others_ct, const FHE_PK&);
|
||||
};
|
||||
|
||||
#endif /* FHEOFFLINE_MULTIPLIER_H_ */
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "Math/Setup.h"
|
||||
#include "FHEOffline/Proof.h"
|
||||
#include "FHEOffline/PairwiseMachine.h"
|
||||
#include "FHEOffline/TemiSetup.h"
|
||||
#include "Tools/Commit.h"
|
||||
#include "Tools/Bundle.h"
|
||||
#include "Processor/OnlineOptions.h"
|
||||
@@ -53,7 +54,7 @@ void PairwiseSetup<FD>::init(const Player& P, int sec, int plaintext_length,
|
||||
template <class FD>
|
||||
void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec)
|
||||
{
|
||||
::secure_init(*this, P, machine, plaintext_length, sec);
|
||||
::secure_init(*this, P, machine, plaintext_length, sec, params);
|
||||
alpha = FieldD;
|
||||
machine.sk = FHE_SK(params, FieldD.get_prime());
|
||||
for (auto& pk : machine.other_pks)
|
||||
@@ -62,13 +63,14 @@ void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int pla
|
||||
|
||||
template <class T, class U>
|
||||
void secure_init(T& setup, Player& P, U& machine,
|
||||
int plaintext_length, int sec)
|
||||
int plaintext_length, int sec, FHE_Params& params)
|
||||
{
|
||||
machine.sec = sec;
|
||||
sec = max(sec, 40);
|
||||
machine.drown_sec = sec;
|
||||
string filename = PREP_DIR + T::name() + "-"
|
||||
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
|
||||
+ to_string(params.get_matrix_dim()) + "-"
|
||||
+ OnlineOptions::singleton.prime.get_str() + "-"
|
||||
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
|
||||
+ to_string(P.my_num()) + "-" + to_string(P.num_players());
|
||||
@@ -85,7 +87,6 @@ void secure_init(T& setup, Player& P, U& machine,
|
||||
{
|
||||
cout << "Finding parameters for security " << sec << " and field size ~2^"
|
||||
<< plaintext_length << endl;
|
||||
setup.params = setup.params.n_mults();
|
||||
setup.generate(P, machine, plaintext_length, sec);
|
||||
setup.check(P, machine);
|
||||
octetStream os;
|
||||
@@ -208,5 +209,8 @@ void PairwiseSetup<FD>::set_alphai(T alphai)
|
||||
template class PairwiseSetup<FFT_Data>;
|
||||
template class PairwiseSetup<P2Data>;
|
||||
|
||||
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int);
|
||||
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int);
|
||||
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
|
||||
template void secure_init(TemiSetup<FFT_Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
template void secure_init(TemiSetup<P2Data>&, Player&, MachineBase&, int, int, FHE_Params& params);
|
||||
|
||||
@@ -15,7 +15,7 @@ class MachineBase;
|
||||
|
||||
template <class T, class U>
|
||||
void secure_init(T& setup, Player& P, U& machine,
|
||||
int plaintext_length, int sec);
|
||||
int plaintext_length, int sec, FHE_Params& params);
|
||||
|
||||
template <class FD>
|
||||
class PairwiseSetup
|
||||
|
||||
@@ -18,7 +18,12 @@ void SimpleDistDecrypt<FD>::reshare(Plaintext<typename FD::T, FD, typename FD::S
|
||||
EncCommitBase<typename FD::T, FD, typename FD::S>& EC)
|
||||
{
|
||||
(void)EC;
|
||||
m = reshare(cm);
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
Plaintext_<FD> SimpleDistDecrypt<FD>::reshare(const Ciphertext& cm)
|
||||
{
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
this->f.randomize(G, Full);
|
||||
@@ -27,10 +32,13 @@ void SimpleDistDecrypt<FD>::reshare(Plaintext<typename FD::T, FD, typename FD::S
|
||||
this->run(cm);
|
||||
|
||||
// Step 4
|
||||
Plaintext_<FD> m(this->f.get_field());
|
||||
if (this->P.my_num()==0)
|
||||
{ sub(m,this->mf,this->f); }
|
||||
else
|
||||
{ m=this->f; m.negate(); }
|
||||
|
||||
return m;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ public:
|
||||
void reshare(Plaintext<typename FD::T, FD, typename FD::S>& m,
|
||||
const Ciphertext& cm,
|
||||
EncCommitBase<typename FD::T, FD, typename FD::S>& EC);
|
||||
Plaintext_<FD> reshare(const Ciphertext& cm);
|
||||
};
|
||||
|
||||
#endif /* FHEOFFLINE_SIMPLEDISTDECRYPT_H_ */
|
||||
|
||||
59
FHEOffline/TemiSetup.cpp
Normal file
59
FHEOffline/TemiSetup.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
/*
|
||||
* TemiSetup.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "TemiSetup.h"
|
||||
#include "PairwiseSetup.h"
|
||||
#include "FHE/NTL-Subs.h"
|
||||
#include "Protocols/HemiOptions.h"
|
||||
|
||||
template<class FD>
|
||||
TemiSetup<FD>::TemiSetup()
|
||||
{
|
||||
this->params = FHE_Params(0);
|
||||
this->pk = {this->params, 0};
|
||||
this->sk = {this->params, 0};
|
||||
this->calpha = this->params;
|
||||
this->params.set_matrix_dim(
|
||||
HemiOptions::singleton.plain_matmul ?
|
||||
1 : OnlineOptions::singleton.batch_size);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void TemiSetup<FD>::secure_init(Player& P, int plaintext_length)
|
||||
{
|
||||
MachineBase machine;
|
||||
::secure_init(*this, P, machine, plaintext_length, 0, this->params);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void TemiSetup<FD>::generate(Player& P, MachineBase&,
|
||||
int plaintext_length, int sec)
|
||||
{
|
||||
generate_semi_setup(plaintext_length, sec, this->params, this->FieldD,
|
||||
false, P.num_players());
|
||||
this->sk = {this->params, this->FieldD.get_prime()};
|
||||
this->pk = {this->params, this->FieldD.get_prime()};
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void TemiSetup<FD>::key_and_mac_generation(Player& P, MachineBase&, int,
|
||||
true_type)
|
||||
{
|
||||
Rq_Element a(this->params);
|
||||
GlobalPRNG GG(P);
|
||||
a.randomize(GG);
|
||||
SeededPRNG G;
|
||||
auto sk = this->pk.sample_secret_key(G);
|
||||
this->sk.assign(sk);
|
||||
this->pk.partial_key_gen(sk, a, G);
|
||||
TreeSum<Rq_Element> ts;
|
||||
vector<Rq_Element> pks;
|
||||
pks.push_back(this->pk.b());
|
||||
ts.run(pks, P);
|
||||
this->pk.assign(this->pk.a(), pks[0]);
|
||||
}
|
||||
|
||||
template class TemiSetup<FFT_Data>;
|
||||
template class TemiSetup<P2Data>;
|
||||
34
FHEOffline/TemiSetup.h
Normal file
34
FHEOffline/TemiSetup.h
Normal file
@@ -0,0 +1,34 @@
|
||||
/*
|
||||
* TemiSetup.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef FHEOFFLINE_TEMISETUP_H_
|
||||
#define FHEOFFLINE_TEMISETUP_H_
|
||||
|
||||
#include "FHE/FHE_Keys.h"
|
||||
#include "FHEOffline/SimpleMachine.h"
|
||||
|
||||
template<class FD>
|
||||
class TemiSetup : public PartSetup<FD>
|
||||
{
|
||||
public:
|
||||
static string name()
|
||||
{
|
||||
return "TemiParams";
|
||||
}
|
||||
|
||||
static string protocol_name(int)
|
||||
{
|
||||
return "Temi";
|
||||
}
|
||||
|
||||
TemiSetup();
|
||||
|
||||
void secure_init(Player& P, int plaintext_length);
|
||||
void generate(Player& P, MachineBase&, int plaintext_length, int sec);
|
||||
|
||||
void key_and_mac_generation(Player& P, MachineBase&, int, true_type);
|
||||
};
|
||||
|
||||
#endif /* FHEOFFLINE_TEMISETUP_H_ */
|
||||
@@ -47,11 +47,11 @@ inline void Memory<T>::check_index(Integer index) const
|
||||
ss << T::type_string() << " memory overflow: " << i << "/" << vector<T>::size();
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
#endif
|
||||
#ifdef DEBUG_MEMORY
|
||||
cout << typeid(T).name() << " at " << this << " index " << i << ": "
|
||||
<< vector<T>::operator[](i) << endl;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -122,6 +122,7 @@ public:
|
||||
static const bool dishonest_majority = false;
|
||||
static const bool variable_players = false;
|
||||
static const bool needs_ot = false;
|
||||
static const bool has_mac = false;
|
||||
|
||||
static string type_string() { return "replicated secret"; }
|
||||
static string phase_name() { return "Replicated computation"; }
|
||||
|
||||
@@ -49,6 +49,7 @@ public:
|
||||
static const bool dishonest_majority = T::dishonest_majority;
|
||||
static const bool variable_players = T::variable_players;
|
||||
static const bool needs_ot = T::needs_ot;
|
||||
static const bool has_mac = T::has_mac;
|
||||
static const bool expensive_triples = false;
|
||||
|
||||
static const int default_length = 64;
|
||||
|
||||
@@ -55,7 +55,7 @@
|
||||
X(BITDECC, PROC.bitdecc(EXTRA, C0)) \
|
||||
X(SHRCBI, C0 = PC1 >> IMM) \
|
||||
X(SHLCBI, C0 = PC1 << IMM) \
|
||||
X(LDBITS, S0.load_clear(REG1, IMM)) \
|
||||
X(LDBITS, S0.load_clear(REG1, int(IMM))) \
|
||||
X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \
|
||||
X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \
|
||||
X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "Protocols/Shamir.hpp"
|
||||
#include "Protocols/ShamirMC.hpp"
|
||||
#include "Protocols/MaliciousShamirMC.hpp"
|
||||
#include "Protocols/MaliciousShamirPO.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/Spdz2kPrep.hpp"
|
||||
|
||||
37
Machines/temi-party.cpp
Normal file
37
Machines/temi-party.cpp
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* temi-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/TemiShare.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "GC/SemiSecret.h"
|
||||
#include "GC/SemiPrep.h"
|
||||
|
||||
#include "Processor/FieldMachine.hpp"
|
||||
#include "Protocols/TemiPrep.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Protocols/SemiPrep.hpp"
|
||||
#include "Protocols/SemiInput.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/MalRepRingPrep.hpp"
|
||||
#include "Protocols/Hemi.hpp"
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/SemiHonestRepPrep.h"
|
||||
#include "Math/gfp.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
HemiOptions::singleton = {opt, argc, argv};
|
||||
DishonestMajorityFieldMachine<TemiShare, TemiShare, gf2n_short>(argc, argv,
|
||||
opt);
|
||||
}
|
||||
5
Makefile
5
Makefile
@@ -61,7 +61,7 @@ arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy
|
||||
binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr
|
||||
|
||||
all: overdrive she-offline
|
||||
arithmetic: hemi-party.x soho-party.x gear
|
||||
arithmetic: semi-he gear
|
||||
|
||||
-include $(DEPS)
|
||||
include $(wildcard *.d static/*.d)
|
||||
@@ -87,6 +87,7 @@ she-offline: Check-Offline.x spdz2-offline.x
|
||||
|
||||
overdrive: simple-offline.x pairwise-offline.x cnc-offline.x gear
|
||||
gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x
|
||||
semi-he: hemi-party.x soho-party.x temi-party.x
|
||||
|
||||
rep-field: malicious-rep-field-party.x replicated-field-party.x ps-rep-field-party.x
|
||||
|
||||
@@ -210,6 +211,7 @@ static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
|
||||
semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o
|
||||
hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
|
||||
temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
|
||||
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
|
||||
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
|
||||
chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER)
|
||||
@@ -217,6 +219,7 @@ lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/Lo
|
||||
highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o
|
||||
atlas-party.x: GC/AtlasSecret.o
|
||||
static/hemi-party.x: $(FHEOBJS)
|
||||
static/temi-party.x: $(FHEOBJS)
|
||||
static/soho-party.x: $(FHEOBJS)
|
||||
static/cowgear-party.x: $(FHEOBJS)
|
||||
static/chaigear-party.x: $(FHEOBJS)
|
||||
|
||||
@@ -14,11 +14,6 @@ using namespace std;
|
||||
#include "Tools/random.h"
|
||||
#include "field_types.h"
|
||||
|
||||
template<class T> class ReplicatedMC;
|
||||
template<class T> class ReplicatedInput;
|
||||
template<class T> class ReplicatedPrivateOutput;
|
||||
template<class T> class Replicated;
|
||||
|
||||
template <class T, int L>
|
||||
class FixedVec
|
||||
{
|
||||
|
||||
@@ -233,7 +233,7 @@ inline void Zp_Data::Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t*
|
||||
if (mpn_cmp(ans+T,prA,T+1)>=0)
|
||||
{ mpn_sub_fixed_n<T>(z,ans+T,prA); }
|
||||
else
|
||||
{ inline_mpn_copyi(z,ans+T,T); }
|
||||
{ inline_mpn_copyi<T>(z,ans+T); }
|
||||
#else
|
||||
Mont_Mult(z, x, y, t);
|
||||
#endif
|
||||
|
||||
@@ -18,15 +18,21 @@ bool gf2n_<U>::useC;

word gf2n_short_table[256][256];

#define num_2_fields 6
#define num_2_fields 7

/* Require
* 2*(n-1)-64+t1<64
*/
int fields_2[num_2_fields][4] = {
{4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10},{63,1,0,0},{128,7,2,1},
};

int fields_2[num_2_fields][4] =
{
{ 4, 1, 0, 0 },
{ 8, 4, 3, 1 },
{ 16, 5, 3, 1 },
{ 28, 1, 0, 0 },
{ 40, 20, 15, 10 },
{ 63, 1, 0, 0 },
{ 128, 7, 2, 1 },
};
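
Editor's note: a quick arithmetic check that the newly added GF(2^16) entry satisfies the stated requirement (plain Python; not part of the diff).

    # Requirement 2*(n-1) - 64 + t1 < 64 for the new field entry {16, 5, 3, 1}:
    n, t1 = 16, 5
    assert 2 * (n - 1) - 64 + t1 < 64   # 2*15 - 64 + 5 = -29, so the bound holds
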
template<class U>
|
||||
void gf2n_<U>::init_tables()
|
||||
|
||||
@@ -24,6 +24,12 @@ inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src, mp_size_t si
|
||||
avx_memcpy(dest, src, size * sizeof(mp_limb_t));
|
||||
}
|
||||
|
||||
template<int N>
|
||||
inline void inline_mpn_copyi(mp_limb_t* dest, const mp_limb_t* src)
|
||||
{
|
||||
avx_memcpy<N * sizeof(mp_limb_t)>(dest, src);
|
||||
}
|
||||
|
||||
inline void debug_print(const char* name, const mp_limb_t* x, int n)
|
||||
{
|
||||
(void)name, (void)x, (void)n;
|
||||
|
||||
@@ -542,6 +542,7 @@ public:
|
||||
int other_player_num() const { return P.get_player(offset); }
|
||||
int num_players() const { return 2; }
|
||||
int get_offset() const { return offset; }
|
||||
Player& get_full_player() const { return P; }
|
||||
|
||||
void send(octetStream& o) const { P.send_to(P.get_player(offset), o); }
|
||||
void reverse_send(octetStream& o) const { P.send_to(P.get_player(-offset), o); }
|
||||
|
||||
@@ -206,6 +206,18 @@ void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
receiver_outputs[i + j].set_byte(k, receiver_keys[j][k]);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef BASE_OT_DEBUG
|
||||
for (j = 0; j < 4; j++)
|
||||
for (k = 0; k < AES_BLK_SIZE; k++)
|
||||
{
|
||||
printf("%4d-th receiver key:", i+j);
|
||||
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,12 +256,6 @@ void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
for (k = 0; k < HASHBYTES; k++) printf("%.2X", sender_keys[1][j][k]);
|
||||
printf("\n");
|
||||
}
|
||||
if (ot_role & RECEIVER)
|
||||
{
|
||||
printf("%4d-th receiver key:", i+j);
|
||||
for (k = 0; k < HASHBYTES; k++) printf("%.2X", receiver_keys[j][k]);
|
||||
printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
@@ -25,7 +25,7 @@ void Binary_File_IO::write_to_file(const string filename,

if (start_pos != -1)
{
long write_pos = start_pos * T::size();
long write_pos = file_signature<T>().get_total_length() + start_pos * T::size();
// fill with zeros if needed
for (long i = outf.tellp(); i < write_pos; i++)
outf.put(0);
@@ -50,10 +50,13 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
inf.open(filename, ios::in | ios::binary);
if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); }

check_file_signature<T>(inf, filename).get_length();
auto data_start = inf.tellg();

int size_in_bytes = T::size() * buffer.size();
int n_read = 0;
char read_buffer[size_in_bytes];
inf.seekg(start_posn * T::size());
inf.seekg(start_posn * T::size(), iostream::cur);
do
{
inf.read(read_buffer + n_read, size_in_bytes - n_read);
@@ -62,7 +65,9 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
if (inf.eof())
{
stringstream ss;
ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes).";
ss << "Got to EOF when reading from disk (expecting " << size_in_bytes
<< " bytes from " << (long(data_start) + start_posn * T::size())
<< ").";
throw file_error(ss.str());
}
if (inf.fail())
@@ -74,7 +79,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer,
}
while (n_read < size_in_bytes);

end_posn = inf.tellg() / T::size();
end_posn = (inf.tellg() - data_start) / T::size();
assert (end_posn == start_posn + int(buffer.size()));

//Check if at end of file by getting 1 more char.
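
Editor's note: a sketch of the persistence file layout implied by the offsets above (plain Python arithmetic; the signature and record lengths below are placeholders, the real values come from file_signature<T>() and T::size(), which are not given in this diff).

    # Layout after this change: a protocol signature header, then fixed-size share records.
    sig_len = 32          # placeholder for file_signature<T>().get_total_length()
    record_size = 16      # placeholder for T::size()
    start_pos = 5
    print(sig_len + start_pos * record_size)   # byte offset of the record at position 5
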
@@ -32,6 +32,15 @@ protected:
|
||||
Buffer<typename T::clear, typename T::clear> buffer;
|
||||
Timer timer;
|
||||
|
||||
// Send my inputs (not generally available)
|
||||
virtual void send_mine() { throw not_implemented(); }
|
||||
// Get share for next input of mine (not generally available)
|
||||
virtual T finalize_mine() { throw not_implemented(); }
|
||||
// Store share for next input from ``player`` from buffer ``o``
|
||||
// in ``target`` (not generally available)
|
||||
virtual void finalize_other(int, T&, octetStream&, int = -1)
|
||||
{ throw not_implemented(); }
|
||||
|
||||
public:
|
||||
vector<octetStream> os;
|
||||
int values_input;
|
||||
@@ -61,18 +70,12 @@ public:
|
||||
/// Schedule input from other player
|
||||
virtual void add_other(int player, int n_bits = -1) = 0;
|
||||
/// Schedule input from all players
|
||||
void add_from_all(const clear& input, int n_bits = -1);
|
||||
void add_from_all(const typename T::open_type& input, int n_bits = -1);
|
||||
|
||||
/// Send my inputs
|
||||
virtual void send_mine() = 0;
|
||||
/// Run input protocol for all players
|
||||
virtual void exchange();
|
||||
|
||||
/// Get share for next input of mine
|
||||
virtual T finalize_mine() = 0;
|
||||
/// Store share for next input from ``player`` from buffer ``o`` in ``target``
|
||||
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
|
||||
/// Get share for next input from ``player`
|
||||
/// Get share for next input from ``player``
|
||||
virtual T finalize(int player, int n_bits = -1);
|
||||
|
||||
void raw_input(SubProcessor<T>& proc, const vector<int>& args, int size);
|
||||
|
||||
@@ -113,7 +113,7 @@ void Input<T>::add_other(int player, int)
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void InputBase<T>::add_from_all(const clear& input, int n_bits)
|
||||
void InputBase<T>::add_from_all(const typename T::open_type& input, int n_bits)
|
||||
{
|
||||
for (int i = 0; i < P->num_players(); i++)
|
||||
if (i == P->my_num())
|
||||
|
||||
@@ -106,6 +106,7 @@ enum
|
||||
MATMULSM = 0xAB,
|
||||
CONV2DS = 0xAC,
|
||||
CHECK = 0xAF,
|
||||
PRIVATEOUTPUT = 0xAD,
|
||||
// Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
@@ -127,6 +128,7 @@ enum
|
||||
INPUTMIXEDREG = 0xF3,
|
||||
RAWINPUT = 0xF4,
|
||||
INPUTPERSONAL = 0xF5,
|
||||
SENDPERSONAL = 0xF6,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
|
||||
@@ -200,14 +200,17 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case USE:
|
||||
case USE_INP:
|
||||
case USE_EDABIT:
|
||||
case DIGESTC:
|
||||
case INPUTMASK:
|
||||
case GINPUTMASK:
|
||||
get_ints(r, s, 2);
|
||||
n = get_int(s);
|
||||
break;
|
||||
case STARTPRIVATEOUTPUT:
|
||||
case GSTARTPRIVATEOUTPUT:
|
||||
case STOPPRIVATEOUTPUT:
|
||||
case GSTOPPRIVATEOUTPUT:
|
||||
case DIGESTC:
|
||||
get_ints(r, s, 2);
|
||||
n = get_int(s);
|
||||
break;
|
||||
throw runtime_error("two-stage private output not supported any more");
|
||||
case USE_MATMUL:
|
||||
get_ints(r, s, 3);
|
||||
n = get_int(s);
|
||||
@@ -237,8 +240,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case PRINTREGB:
|
||||
case GPRINTREG:
|
||||
case LDINT:
|
||||
case INPUTMASK:
|
||||
case GINPUTMASK:
|
||||
case INV2M:
|
||||
case CONDPRINTSTR:
|
||||
case CONDPRINTSTRB:
|
||||
@@ -290,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case RAWINPUT:
|
||||
case GRAWINPUT:
|
||||
case INPUTPERSONAL:
|
||||
case SENDPERSONAL:
|
||||
case PRIVATEOUTPUT:
|
||||
case TRUNC_PR:
|
||||
case RUN_TAPE:
|
||||
num_var_args = get_int(s);
|
||||
@@ -599,6 +602,7 @@ int BaseInstruction::get_reg_type() const
|
||||
case PUBINPUT:
|
||||
case FLOATOUTPUT:
|
||||
case READSOCKETC:
|
||||
case PRIVATEOUTPUT:
|
||||
return CINT;
|
||||
default:
|
||||
if (is_gf2n_instruction())
|
||||
@@ -738,10 +742,16 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
skip = 1;
|
||||
break;
|
||||
case INPUTPERSONAL:
|
||||
case PRIVATEOUTPUT:
|
||||
size_offset = -2;
|
||||
offset = 2;
|
||||
skip = 4;
|
||||
break;
|
||||
case SENDPERSONAL:
|
||||
size_offset = -2;
|
||||
offset = 2;
|
||||
skip = 5;
|
||||
break;
|
||||
case READSOCKETS:
|
||||
case READSOCKETC:
|
||||
case READSOCKETINT:
|
||||
@@ -939,13 +949,11 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
break;
|
||||
case INPUTMASK:
|
||||
Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, n);
|
||||
if (n == Proc.P.my_num())
|
||||
Proc.temp.rrp.output(Proc.private_output, false);
|
||||
Proc.write_Cp(r[1], Proc.temp.rrp);
|
||||
break;
|
||||
case GINPUTMASK:
|
||||
Proc2.DataF.get_input(Proc.get_S2_ref(r[0]), Proc.temp.ans2, n);
|
||||
if (n == Proc.P.my_num())
|
||||
Proc.temp.ans2.output(Proc.private_output, false);
|
||||
Proc.write_C2(r[1], Proc.temp.ans2);
|
||||
break;
|
||||
case INPUT:
|
||||
sint::Input::template input<IntInput<typename sint::clear>>(Proc.Procp, start, size);
|
||||
@@ -974,6 +982,12 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case INPUTPERSONAL:
|
||||
Proc.Procp.input_personal(start);
|
||||
return;
|
||||
case SENDPERSONAL:
|
||||
Proc.Procp.send_personal(start);
|
||||
return;
|
||||
case PRIVATEOUTPUT:
|
||||
Proc.Procp.private_output(start);
|
||||
return;
|
||||
// Note: Fp version has different semantics for NOTC than GNOTC
|
||||
case NOTC:
|
||||
to_bigint(Proc.temp.aa, Proc.read_Cp(r[1]));
|
||||
@@ -1202,18 +1216,6 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.binary_output.write((char*) &tmp, sizeof(double));
|
||||
}
|
||||
break;
|
||||
case STARTPRIVATEOUTPUT:
|
||||
Proc.privateOutputp.start(n,r[0],r[1]);
|
||||
break;
|
||||
case GSTARTPRIVATEOUTPUT:
|
||||
Proc.privateOutput2.start(n,r[0],r[1]);
|
||||
break;
|
||||
case STOPPRIVATEOUTPUT:
|
||||
Proc.privateOutputp.stop(n,r[0],r[1]);
|
||||
break;
|
||||
case GSTOPPRIVATEOUTPUT:
|
||||
Proc.privateOutput2.stop(n,r[0],r[1]);
|
||||
break;
|
||||
case PREP:
|
||||
Procp.DataF.get(Proc.Procp.get_S(), r, start, size);
|
||||
return;
|
||||
|
||||
@@ -97,12 +97,19 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
// initialize persistence if necessary
|
||||
for (auto& prog : progs)
|
||||
{
|
||||
if (prog.writes_persistance)
|
||||
if (prog.writes_persistence)
|
||||
{
|
||||
string filename = Binary_File_IO::filename(my_number);
|
||||
ifstream pers(filename);
|
||||
if (pers.fail())
|
||||
ofstream pers(filename, ios::binary);
|
||||
try
|
||||
{
|
||||
check_file_signature<sint>(pers, filename);
|
||||
}
|
||||
catch (signature_mismatch&)
|
||||
{
|
||||
ofstream pers(filename, ios::binary);
|
||||
file_signature<sint>().output(pers);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -418,12 +425,14 @@ void Machine<sint, sgf2n>::run()
|
||||
cerr << "Full broadcast" << endl;
|
||||
#endif
|
||||
|
||||
#ifdef CHOP_MEMORY
|
||||
// Reduce memory size to speed up
|
||||
unsigned max_size = 1 << 20;
|
||||
if (M2.size_s() > max_size)
|
||||
M2.resize_s(max_size);
|
||||
if (Mp.size_s() > max_size)
|
||||
Mp.resize_s(max_size);
|
||||
#endif
|
||||
|
||||
// Write out the memory to use next time
|
||||
ofstream outf(memory_filename(), ios::out | ios::binary);
|
||||
|
||||
@@ -44,9 +44,9 @@ class Memory
|
||||
static void check_index(const vector<U>& M, size_t i)
|
||||
{
|
||||
(void) M, (void) i;
|
||||
#ifdef NO_CHECK_INDEX
|
||||
#ifndef NO_CHECK_INDEX
|
||||
if (i >= M.size())
|
||||
throw overflow("memory", i, M.size());
|
||||
throw overflow(U::type_string() + " memory", i, M.size());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,9 @@ void MemoryPart<T>::minimum_size(size_t size)
|
||||
{
|
||||
if (size > this->size())
|
||||
this->resize(size);
|
||||
#ifdef DEBUG_MEMORY_SIZE
|
||||
cerr << T::type_string() << " memory has now size " << this->size() << endl;
|
||||
#endif
|
||||
}
|
||||
catch (bad_alloc&)
|
||||
{
|
||||
@@ -58,9 +61,9 @@ istream& operator>>(istream& s,Memory<T>& M)
|
||||
int len;
|
||||
|
||||
s >> len;
|
||||
M.resize_s(len);
|
||||
M.MS.minimum_size(len);
|
||||
s >> len;
|
||||
M.resize_c(len);
|
||||
M.MC.minimum_size(len);
|
||||
s.seekg(1, istream::cur);
|
||||
|
||||
for (unsigned int i=0; i<M.MS.size(); i++)
|
||||
|
||||
@@ -17,16 +17,16 @@ class PrivateOutput
|
||||
typedef typename T::open_type open_type;
|
||||
|
||||
SubProcessor<T>& proc;
|
||||
typename T::MAC_Check MC;
|
||||
deque<open_type> masks;
|
||||
|
||||
public:
|
||||
PrivateOutput(SubProcessor<T>& proc) : proc(proc) { };
|
||||
PrivateOutput(SubProcessor<T>& proc);
|
||||
~PrivateOutput();
|
||||
|
||||
void start(int player, int target, int source);
|
||||
void stop(int player, int dest, int source);
|
||||
|
||||
T start(int player, const T& source);
|
||||
typename T::clear stop(int player, const typename T::clear& masked);
|
||||
void prepare_sending(const T& source, int player);
|
||||
void exchange();
|
||||
typename T::clear finalize(int player);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_PRIVATEOUTPUT_H_ */
|
||||
|
||||
@@ -7,13 +7,21 @@
|
||||
#include "Processor.h"
|
||||
|
||||
template<class T>
|
||||
void PrivateOutput<T>::start(int player, int target, int source)
|
||||
PrivateOutput<T>::PrivateOutput(SubProcessor<T>& proc) :
|
||||
proc(proc), MC(proc.MC.get_alphai())
|
||||
{
|
||||
proc.get_S_ref(target) = start(player, proc.get_S_ref(source));
|
||||
MC.init_open(proc.P);
|
||||
MC.set_prep(proc.DataF);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
T PrivateOutput<T>::start(int player, const T& source)
|
||||
PrivateOutput<T>::~PrivateOutput()
|
||||
{
|
||||
MC.Check(proc.P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void PrivateOutput<T>::prepare_sending(const T& source, int player)
|
||||
{
|
||||
assert (player < proc.P.num_players());
|
||||
open_type mask;
|
||||
@@ -24,26 +32,25 @@ T PrivateOutput<T>::start(int player, const T& source)
|
||||
if (player == proc.P.my_num())
|
||||
masks.push_back(mask);
|
||||
|
||||
return res;
|
||||
MC.prepare_open(res);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void PrivateOutput<T>::stop(int player, int dest, int source)
|
||||
void PrivateOutput<T>::exchange()
|
||||
{
|
||||
auto& value = proc.get_C_ref(dest);
|
||||
value = stop(player, proc.get_C_ref(source));
|
||||
if (proc.Proc)
|
||||
value.output(proc.Proc->private_output, false);
|
||||
MC.exchange(proc.P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
typename T::clear PrivateOutput<T>::stop(int player, const typename T::clear& source)
|
||||
typename T::clear PrivateOutput<T>::finalize(int player)
|
||||
{
|
||||
typename T::clear value;
|
||||
auto res = MC.finalize_open();
|
||||
|
||||
if (player == proc.P.my_num())
|
||||
{
|
||||
value = source - masks.front();
|
||||
res -= masks.front();
|
||||
masks.pop_front();
|
||||
}
|
||||
return value;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -71,6 +71,8 @@ public:
|
||||
void conv2ds(const Instruction& instruction);
|
||||
|
||||
void input_personal(const vector<int>& args);
|
||||
void send_personal(const vector<int>& args);
|
||||
void private_output(const vector<int>& args);
|
||||
|
||||
CheckVector<T>& get_S()
|
||||
{
|
||||
@@ -110,7 +112,6 @@ public:
|
||||
ifstream private_input;
|
||||
ifstream public_input;
|
||||
ofstream public_output;
|
||||
ofstream private_output;
|
||||
ofstream binary_output;
|
||||
|
||||
int sent, rounds;
|
||||
@@ -172,9 +173,6 @@ class Processor : public ArithmeticProcessor
|
||||
SubProcessor<sgf2n> Proc2;
|
||||
SubProcessor<sint> Procp;
|
||||
|
||||
typename sgf2n::PrivateOutput privateOutput2;
|
||||
typename sint::PrivateOutput privateOutputp;
|
||||
|
||||
unsigned int PC;
|
||||
TempVars<sint, sgf2n> temp;
|
||||
|
||||
|
||||
@@ -4,9 +4,8 @@
|
||||
#include "Processor/Processor.h"
|
||||
#include "Processor/Program.h"
|
||||
#include "GC/square64.h"
|
||||
#include "SpecificPrivateOutput.h"
|
||||
|
||||
#include "Protocols/ReplicatedInput.hpp"
|
||||
#include "Protocols/ReplicatedPrivateOutput.hpp"
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
#include "GC/Processor.hpp"
|
||||
#include "GC/ShareThread.hpp"
|
||||
@@ -63,7 +62,6 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
share_thread(DataF.DataFb, P, machine.get_bit_mac_key()),
|
||||
Procb(machine.bit_memories),
|
||||
Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P),
|
||||
privateOutput2(Proc2),privateOutputp(Procp),
|
||||
external_clients(P.my_num()),
|
||||
binary_file_io(Binary_File_IO())
|
||||
{
|
||||
@@ -74,7 +72,6 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
private_input_filename = (get_filename(PREP_DIR "Private-Input-",true));
|
||||
private_input.open(private_input_filename.c_str());
|
||||
public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out);
|
||||
private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out);
|
||||
binary_output.open(
|
||||
get_parameterized_filename(P.my_num(), thread_num,
|
||||
PREP_DIR "Binary-Output"), ios_base::out);
|
||||
@@ -654,6 +651,37 @@ void SubProcessor<T>::input_personal(const vector<int>& args)
|
||||
S[args[i + 2] + j] = input.finalize(args[i + 1]);
|
||||
}
|
||||
|
||||
template<class T>
void SubProcessor<T>::private_output(const vector<int>& args)
{
typename T::PrivateOutput output(*this);
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
{
int player = args[i + 1];
output.prepare_sending(S.at(args[i + 3] + j), player);
}
output.exchange();
for (size_t i = 0; i < args.size(); i += 4)
for (int j = 0; j < args[i]; j++)
C.at(args[i + 2] + j) = output.finalize(args[i + 1]);
}
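
Editor's note: the argument groups consumed here follow the privateoutput call shown earlier in Compiler/types.py (size, player, clear destination, secret source); the register numbers below are assumed for illustration only.

    # One PRIVATEOUTPUT argument group as consumed above (illustrative values):
    # [size, player, dest clear register, source secret register]
    args = [4, 0, 100, 200]
    # prepare_sending is called on S[200..203] for player 0,
    # and after exchange() the opened results land in C[100..103].
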
template<class T>
|
||||
void SubProcessor<T>::send_personal(const vector<int>& args)
|
||||
{
|
||||
octetStreams to_send(P), to_receive(P);
|
||||
for (size_t i = 0; i < args.size(); i += 5)
|
||||
if (args[i + 3] == P.my_num())
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
C[args[i + 4] + j].pack(to_send[args[i + 1]]);
|
||||
P.send_receive_all(to_send, to_receive);
|
||||
for (size_t i = 0; i < args.size(); i += 5)
|
||||
if (args[i + 1] == P.my_num())
|
||||
for (int j = 0; j < args[i]; j++)
|
||||
C[args[i + 2] + j].unpack(to_receive[args[i + 3]]);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
typename sint::clear Processor<sint, sgf2n>::get_inverse2(unsigned m)
|
||||
{
|
||||
|
||||
@@ -23,7 +23,7 @@ void Program::compute_constants()
|
||||
max_mem[reg_type] = max(max_mem[reg_type],
|
||||
p[i].get_mem(RegType(reg_type)));
|
||||
}
|
||||
writes_persistance |= p[i].opcode == WRITEFILESHARE;
|
||||
writes_persistence |= p[i].opcode == WRITEFILESHARE;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,10 +30,10 @@ class Program
|
||||
|
||||
public:
|
||||
|
||||
bool writes_persistance;
|
||||
bool writes_persistence;
|
||||
|
||||
Program(int nplayers) : offline_data_used(nplayers),
|
||||
unknown_usage(false), writes_persistance(false)
|
||||
unknown_usage(false), writes_persistence(false)
|
||||
{ compute_constants(); }
|
||||
|
||||
// Read in a program
|
||||
|
||||
65
Processor/SpecificPrivateOutput.h
Normal file
65
Processor/SpecificPrivateOutput.h
Normal file
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
* SpecificPrivateOutput.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_SPECIFICPRIVATEOUTPUT_H_
|
||||
#define PROCESSOR_SPECIFICPRIVATEOUTPUT_H_
|
||||
|
||||
template<class T>
|
||||
class SpecificPrivateOutput
|
||||
{
|
||||
deque<T> secrets;
|
||||
vector<typename T::PO*> pos;
|
||||
Player& P;
|
||||
vector<bool> active;
|
||||
|
||||
public:
|
||||
SpecificPrivateOutput(SubProcessor<T>& proc) :
|
||||
P(proc.P)
|
||||
{
|
||||
for (int i = 0; i < P.num_players(); i++)
|
||||
pos.push_back(new typename T::PO(proc.P));
|
||||
active.resize(P.num_players());
|
||||
}
|
||||
|
||||
~SpecificPrivateOutput()
|
||||
{
|
||||
for (auto& x : pos)
|
||||
delete x;
|
||||
}
|
||||
|
||||
void prepare_sending(const T& secret, int player)
|
||||
{
|
||||
pos[player]->prepare_sending(secret, player);
|
||||
if (P.my_num() == player)
|
||||
secrets.push_back(secret);
|
||||
active[player] = true;
|
||||
}
|
||||
|
||||
void exchange()
|
||||
{
|
||||
for (int i = 0; i < this->P.num_players(); i++)
|
||||
if (active[i])
|
||||
{
|
||||
if (i == this->P.my_num())
|
||||
pos[i]->receive();
|
||||
else
|
||||
pos[i]->send(i);
|
||||
}
|
||||
}
|
||||
|
||||
typename T::clear finalize(int player)
|
||||
{
|
||||
if (player == this->P.my_num())
|
||||
{
|
||||
T secret = secrets.front();
|
||||
secrets.pop_front();
|
||||
return pos[player]->finalize(secret);
|
||||
}
|
||||
else
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_SPECIFICPRIVATEOUTPUT_H_ */
|
||||
100
Programs/Source/falcon_alex.mpc
Normal file
100
Programs/Source/falcon_alex.mpc
Normal file
@@ -0,0 +1,100 @@
|
||||
from Compiler.ml import keras
|
||||
import Compiler.ml as tf
|
||||
|
||||
try:
|
||||
n_epochs = int(program.args[1])
|
||||
except (ValueError, IndexError):
|
||||
n_epochs = 10
|
||||
|
||||
try:
|
||||
batch_size = int(program.args[2])
|
||||
except (ValueError, IndexError):
|
||||
batch_size = 128
|
||||
|
||||
try:
|
||||
n_threads = int(program.args[3])
|
||||
except (ValueError, IndexError):
|
||||
n_threads = 36
|
||||
|
||||
#Instantiation
|
||||
AlexNet = []
|
||||
|
||||
padding = 'same'
|
||||
batchnorm = 'batchnorm' in program.args
|
||||
|
||||
#1st Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=96, input_shape=(32,32,3), kernel_size=(11,11), strides=(4,4), padding=9))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
AlexNet.append(keras.layers.MaxPooling2D(pool_size=3, strides=(2,2)))
|
||||
if batchnorm:
|
||||
AlexNet.append(keras.layers.BatchNormalization())
|
||||
|
||||
#2nd Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1,1), padding=1))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
if batchnorm:
|
||||
AlexNet.append(keras.layers.BatchNormalization())
|
||||
AlexNet.append(keras.layers.MaxPooling2D(pool_size=(2,2), strides=1))
|
||||
|
||||
#3rd Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
|
||||
#4th Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), padding=1))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
|
||||
#5th Convolutional Layer
|
||||
AlexNet.append(keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), padding=1))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
|
||||
#Passing it to a Fully Connected layer
|
||||
# 1st Fully Connected Layer
|
||||
AlexNet.append(keras.layers.Dense(256))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
|
||||
if 'dropout' in program.args:
|
||||
AlexNet.append(keras.layers.Dropout(0.5))
|
||||
|
||||
#2nd Fully Connected Layer
|
||||
AlexNet.append(keras.layers.Dense(256))
|
||||
AlexNet.append(keras.layers.Activation('relu'))
|
||||
|
||||
if 'dropout' in program.args:
|
||||
AlexNet.append(keras.layers.Dropout(0.5))
|
||||
|
||||
#Output Layer
|
||||
AlexNet.append(keras.layers.Dense(10))
|
||||
|
||||
|
||||
tf.set_n_threads(n_threads)
|
||||
program.options_from_args()
|
||||
sfix.set_precision_from_args(program, adapt_ring=True)
|
||||
|
||||
training_samples = MultiArray([50000, 32, 32, 3], sfix)
|
||||
training_labels = MultiArray([50000, 10], sint)
|
||||
|
||||
test_samples = MultiArray([10000, 32, 32, 3], sfix)
|
||||
test_labels = MultiArray([10000, 10], sint)
|
||||
|
||||
if 'no_acc' not in program.args:
|
||||
training_labels.input_from(0)
|
||||
training_samples.input_from(0)
|
||||
|
||||
test_labels.input_from(0)
|
||||
test_samples.input_from(0)
|
||||
|
||||
model = tf.keras.models.Sequential(AlexNet)
|
||||
|
||||
model.compile_by_args(program)
|
||||
|
||||
model.build(training_samples.sizes)
|
||||
model.summary()
|
||||
|
||||
opt = model.fit(
|
||||
training_samples,
|
||||
training_labels,
|
||||
epochs=n_epochs,
|
||||
batch_size=batch_size,
|
||||
validation_data=(test_samples, test_labels)
|
||||
)
|
||||
45
Programs/Source/keras_cifar_lenet.mpc
Normal file
45
Programs/Source/keras_cifar_lenet.mpc
Normal file
@@ -0,0 +1,45 @@
|
||||
# this trains LeNet on CIFAR-10 with a dropout layer
|
||||
# see https://github.com/csiro-mlai/mnist-mpc for data preparation
|
||||
|
||||
program.options_from_args()
|
||||
|
||||
training_samples = MultiArray([50000, 32, 32, 3], sfix)
|
||||
training_labels = MultiArray([50000, 10], sint)
|
||||
|
||||
test_samples = MultiArray([10000, 32, 32, 3], sfix)
|
||||
test_labels = MultiArray([10000, 10], sint)
|
||||
|
||||
training_labels.input_from(0)
|
||||
training_samples.input_from(0)
|
||||
|
||||
test_labels.input_from(0)
|
||||
test_samples.input_from(0)
|
||||
|
||||
from Compiler import ml
|
||||
tf = ml
|
||||
ml.set_n_threads(36)
|
||||
|
||||
layers = [
|
||||
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
|
||||
tf.keras.layers.MaxPooling2D(2),
|
||||
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
|
||||
tf.keras.layers.MaxPooling2D(2),
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dropout(0.5),
|
||||
tf.keras.layers.Dense(500, activation='relu'),
|
||||
tf.keras.layers.Dense(10, activation='softmax')
|
||||
]
|
||||
|
||||
model = tf.keras.models.Sequential(layers)
|
||||
|
||||
optim = tf.keras.optimizers.Adam(amsgrad=True)
|
||||
|
||||
model.compile(optimizer=optim)
|
||||
|
||||
opt = model.fit(
|
||||
training_samples,
|
||||
training_labels,
|
||||
epochs=10,
|
||||
batch_size=128,
|
||||
validation_data=(test_samples, test_labels)
|
||||
)
|
||||
@@ -21,7 +21,8 @@ tf = ml
|
||||
layers = [
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dense(128),
|
||||
tf.keras.layers.Activation('relu'),
|
||||
tf.keras.layers.Dense(10, activation='softmax')
|
||||
]
|
||||
|
||||
|
||||
@@ -20,8 +20,21 @@ tf = ml
|
||||
|
||||
layers = [
|
||||
tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
|
||||
]
|
||||
|
||||
if 'batchnorm' in program.args:
|
||||
layers += [tf.keras.layers.BatchNormalization()]
|
||||
|
||||
layers += [
|
||||
tf.keras.layers.MaxPooling2D(2),
|
||||
tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
|
||||
]
|
||||
|
||||
|
||||
if 'batchnorm' in program.args:
|
||||
layers += [tf.keras.layers.BatchNormalization()]
|
||||
|
||||
layers += [
|
||||
tf.keras.layers.MaxPooling2D(2),
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dropout(0.5),
|
||||
|
||||
@@ -21,6 +21,8 @@ elif 'debug' in program.args:
|
||||
n_test = 100
|
||||
elif 'debug5000' in program.args:
|
||||
N = n_test = 5000
|
||||
elif 'mini' in program.args:
|
||||
N = n_test = 10
|
||||
else:
|
||||
N = 60000
|
||||
n_test = 10000
|
||||
@@ -39,6 +41,7 @@ except:
|
||||
batch_size = N
|
||||
|
||||
N = min(N, 10000)
|
||||
batch_size = min(batch_size, N)
|
||||
ml.Layer.back_batch_size = batch_size
|
||||
|
||||
try:
|
||||
@@ -71,6 +74,9 @@ else:
|
||||
ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml),
|
||||
ml.Dense(N, n_inner, 10, debug=debug_ml)]
|
||||
|
||||
if 'batchnorm' in program.args:
|
||||
layers.insert(1, ml.BatchNorm([N, n_inner]))
|
||||
|
||||
if 'dropout' in program.args:
|
||||
for i in range(len(layers) - 1, 0, -1):
|
||||
layers.insert(i, ml.Dropout(N, n_inner))
|
||||
|
||||
@@ -53,7 +53,7 @@ except:
|
||||
ml.Layer.back_batch_size = batch_size
|
||||
|
||||
layers = [
|
||||
ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [n_examples, 24, 24, 20], (1, 1), 'VALID'),
|
||||
ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [N, 24, 24, 20], (1, 1), 'VALID'),
|
||||
ml.MaxPool([N, 24, 24, 20]),
|
||||
ml.Relu([N, 12, 12, 20]),
|
||||
ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'),
|
||||
@@ -66,6 +66,12 @@ layers = [
|
||||
|
||||
layers += [ml.MultiOutput.from_args(program, n_examples, 10)]
|
||||
|
||||
if 'batchnorm' in program.args:
|
||||
for arg in program.args:
|
||||
assert not arg.startswith('dropout')
|
||||
layers.insert(4, ml.BatchNorm([N, 8, 8, 50], args=program.args))
|
||||
layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args))
|
||||
|
||||
if 'dropout' in program.args or 'dropout2' in program.args:
|
||||
layers.insert(8, ml.Dropout(N, 500))
|
||||
elif 'dropout.25' in program.args:
|
||||
|
||||
@@ -85,6 +85,12 @@ void Atlas<T>::exchange()
|
||||
resharing.add_mine(e);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < min(masks.size(), size_t(P.num_players())); i++)
|
||||
{
|
||||
int j = (base_king + i) % P.num_players();
|
||||
resharing.add_sender(j);
|
||||
}
|
||||
|
||||
resharing.exchange();
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ HemiMatrixPrep<T>& Hemi<T>::get_matrix_prep(const array<int, 3>& dims,
|
||||
if (matrix_preps.find(dims) == matrix_preps.end())
|
||||
matrix_preps.insert({dims,
|
||||
new HemiMatrixPrep<T>(dims[0], dims[1], dims[2],
|
||||
dynamic_cast<HemiPrep<T>&>(processor.DataF))});
|
||||
dynamic_cast<typename T::LivePrep&>(processor.DataF))});
|
||||
return *matrix_preps.at(dims);
|
||||
}
|
||||
|
||||
|
||||
@@ -18,17 +18,18 @@ template<class T>
|
||||
class HemiMatrixPrep : public BufferPrep<ShareMatrix<T>>
|
||||
{
|
||||
typedef BufferPrep<ShareMatrix<T>> super;
|
||||
typedef typename T::LivePrep LivePrep;
|
||||
|
||||
int n_rows, n_inner, n_cols;
|
||||
bool swapped;
|
||||
DataPositions* usage;
|
||||
|
||||
HemiPrep<T>* prep;
|
||||
LivePrep* prep;
|
||||
|
||||
HemiMatrixPrep(const HemiMatrixPrep&) = delete;
|
||||
|
||||
public:
|
||||
HemiMatrixPrep(int n_rows, int n_inner, int n_cols, HemiPrep<T>& prep) :
|
||||
HemiMatrixPrep(int n_rows, int n_inner, int n_cols, LivePrep& prep) :
|
||||
super(*(usage = new DataPositions)), n_rows(n_rows), n_inner(n_inner),
|
||||
n_cols(n_cols), prep(&prep)
|
||||
{
|
||||
|
||||
@@ -87,11 +87,10 @@ void HemiMatrixPrep<T>::buffer_triples()
|
||||
|
||||
assert(prep);
|
||||
auto& multipliers = prep->get_multipliers();
|
||||
assert(prep->pairwise_machine);
|
||||
auto& FTD = prep->pairwise_machine->setup_p.FieldD;
|
||||
auto& pk = prep->pairwise_machine->pk;
|
||||
auto& FTD = prep->get_FTD();
|
||||
auto& pk = prep->get_pk();
|
||||
int n_matrices = FTD.num_slots() / n_rows;
|
||||
#ifdef VERBOSE
|
||||
#ifdef VERBOSE_HE
|
||||
fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner,
|
||||
n_inner, n_cols);
|
||||
fflush(stderr);
|
||||
@@ -103,20 +102,23 @@ void HemiMatrixPrep<T>::buffer_triples()
|
||||
AddableVector<ValueMatrix<gfpvar>> C(n_matrices);
|
||||
MatrixRandMultJob job(C, A, B);
|
||||
|
||||
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
|
||||
if (T::local_mul)
|
||||
{
|
||||
auto& queues = BaseMachine::s().queues;
|
||||
int start = queues.distribute(job, n_matrices);
|
||||
job.begin = start;
|
||||
job.end = n_matrices;
|
||||
matrix_rand_mult(job);
|
||||
queues.wrap_up(job);
|
||||
}
|
||||
else
|
||||
{
|
||||
job.begin = 0;
|
||||
job.end = n_matrices;
|
||||
matrix_rand_mult(job);
|
||||
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
|
||||
{
|
||||
auto& queues = BaseMachine::s().queues;
|
||||
int start = queues.distribute(job, n_matrices);
|
||||
job.begin = start;
|
||||
job.end = n_matrices;
|
||||
matrix_rand_mult(job);
|
||||
queues.wrap_up(job);
|
||||
}
|
||||
else
|
||||
{
|
||||
job.begin = 0;
|
||||
job.end = n_matrices;
|
||||
matrix_rand_mult(job);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef VERBOSE_HE
|
||||
@@ -130,26 +132,35 @@ void HemiMatrixPrep<T>::buffer_triples()
    assert(prep->proc);
    auto& P = prep->proc->P;

    Bundle<octetStream> bundle(P);
    bundle.mine.store(diag.ciphertexts);
    P.unchecked_broadcast(bundle);
    vector<vector<Ciphertext>> others_ct;
    for (auto& os : bundle)

    if (T::local_mul or OnlineOptions::singleton.direct)
    {
        others_ct.push_back({});
        os.get(others_ct.back(), Ciphertext(pk));
        Bundle<octetStream> bundle(P);
        bundle.mine.store(diag.ciphertexts);
        P.unchecked_broadcast(bundle);
        for (auto& os : bundle)
        {
            others_ct.push_back({});
            os.get(others_ct.back(), Ciphertext(pk));
        }
    }
    else
    {
        others_ct.push_back(diag.ciphertexts);
        TreeSum<Ciphertext>().run(others_ct[0], P);
    }

    for (int j = 0; j < n_cols; j++)
        for (auto m : multipliers)
        {
#ifdef VERBOSE
#ifdef VERBOSE_HE
            fprintf(stderr, "column %d with party offset %d at %f\n", j,
                    m->get_offset(), timer.elapsed());
            fflush(stderr);
#endif
            Ciphertext C(pk);
            auto& multiplicands = others_ct[P.get_player(-m->get_offset())];
            auto& multiplicands = m->get_multiplicands(others_ct, pk);
            if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
            {
                auto& queues = BaseMachine::s().queues;
@@ -160,7 +171,7 @@ void HemiMatrixPrep<T>::buffer_triples()
                CipherPlainMultJob job(products, multiplicands, multiplicands2, true);
                int start = queues.distribute(job, n_inner);
#ifdef VERBOSE_HE
                fprintf(stderr, "from %d in central thread\n", start);
                fprintf(stderr, "from %d in central thread at %f\n", start, timer.elapsed());
                fflush(stderr);
#endif
                for (int i = start; i < n_inner; i++)
@@ -185,7 +196,10 @@ void HemiMatrixPrep<T>::buffer_triples()
            m->add(products[j], C, BOTH, n_inner);
        }

        C += diag.dediag(products, n_matrices);
        if (T::local_mul)
            C += diag.dediag(products, n_matrices);
        else
            C = diag.dediag(products, n_matrices);

        for (int i = 0; i < n_matrices; i++)
            if (swapped)

@@ -34,6 +34,9 @@ public:
    static void basic_setup(Player& P);
    static void teardown();

    static const FHE_PK& get_pk();
    static const FD& get_FTD();

    HemiPrep(SubProcessor<T>* proc, DataPositions& usage) :
            BufferPrep<T>(usage),
            BitPrep<T>(proc, usage), RingPrep<T>(proc, usage),

@@ -34,6 +34,20 @@ void HemiPrep<T>::basic_setup(Player& P)
    T::clear::template init<typename FD::T>();
}

template<class T>
const FHE_PK& HemiPrep<T>::get_pk()
{
    assert(pairwise_machine);
    return pairwise_machine->pk;
}

template<class T>
const typename T::clear::FD& HemiPrep<T>::get_FTD()
{
    assert(pairwise_machine);
    return pairwise_machine->setup<FD>().FieldD;
}


template<class T>
HemiPrep<T>::~HemiPrep()

@@ -27,6 +27,7 @@ public:
    typedef HemiPrep<This> LivePrep;

    static const bool needs_ot = false;
    static const bool local_mul = true;
    static true_type triple_matmul;

    HemiShare()

@@ -140,12 +140,12 @@ void KeyGenProtocol<X, L>::output_to(int player, vector<open_type>& opened,
        vector<share_type>& shares)
{
    PrivateOutput<share_type> po(*proc);
    vector<share_type> masked;
    for (auto& share : shares)
        masked.push_back(po.start(player, share));
    MC->POpen(opened, masked, P);
        po.prepare_sending(share, player);
    po.exchange();
    opened.resize(shares.size());
    for (auto& x : opened)
        x = po.stop(player, x);
        x = po.finalize(player);
}

template<int L>

@@ -52,6 +52,7 @@ public:
    virtual ~TreeSum();

    void run(vector<T>& values, const Player& P);
    T run(const T& value, const Player& P);

    octetStream& get_buffer() { return os; }

@@ -210,6 +211,14 @@ void TreeSum<T>::run(vector<T>& values, const Player& P)
    finish(values, P);
}

template<class T>
T TreeSum<T>::run(const T& value, const Player& P)
{
    vector<T> values = {value};
    run(values, P);
    return values[0];
}

template<class T>
size_t TreeSum<T>::report_size(ReportType type)
{
@@ -244,14 +253,6 @@ void add_openings(vector<T>& values, const Player& P, int sum_players, int last_
        MC.player_timers[sender].start();
        P.wait_receive(sender, oss[j]);
        MC.player_timers[sender].stop();
        if ((unsigned)oss[j].get_length() < values.size() * T::size())
        {
            stringstream ss;
            ss << "Not enough information received, expected "
                    << values.size() * T::size() << " bytes, got "
                    << oss[j].get_length();
            throw Processor_Error(ss.str());
        }
        MC.timers[SUM].start();
        for (unsigned int i=0; i<values.size(); i++)
        {

@@ -127,6 +127,7 @@ void MAC_Check_<U>::Check(const Player& P)
    auto& vals = this->vals;
    auto& macs = this->macs;
    auto& popen_cnt = this->popen_cnt;
    assert(int(macs.size()) <= popen_cnt);

    if (popen_cnt < 10)
    {

@@ -12,6 +12,8 @@ using namespace std;
#include "Networking/Player.h"
#include "Tools/PointerVector.h"

template<class T> class Preprocessing;

/**
 * Abstract base class for opening protocols
 */
@@ -61,6 +63,8 @@ public:
    virtual void CheckFor(const typename T::open_type& value, const vector<T>& shares, const Player& P);

    virtual const Player& get_check_player(const Player& P) const { return P; }

    virtual void set_prep(Preprocessing<T>&) {}
};

#endif /* PROTOCOLS_MAC_CHECK_BASE_H_ */

@@ -17,6 +17,7 @@ class MalRepRingShare : public MaliciousRep3Share<SignedZ2<K>>
{
    typedef SignedZ2<K> T;
    typedef MaliciousRep3Share<T> super;
    typedef MalRepRingShare This;

public:
    const static int BIT_LENGTH = K;
@@ -26,7 +27,8 @@ public:
    typedef HashMaliciousRepMC<MalRepRingShare> MAC_Check;
    typedef MAC_Check Direct_MC;
    typedef ReplicatedInput<MalRepRingShare> Input;
    typedef ::PrivateOutput<MalRepRingShare> PrivateOutput;
    typedef ReplicatedPO<This> PO;
    typedef SpecificPrivateOutput<This> PrivateOutput;
    typedef MalRepRingPrepWithBits<MalRepRingShare> LivePrep;
    typedef MaliciousRep3Share<Z2<K + S>> prep_type;
    typedef Z2<S> random_type;

@@ -13,6 +13,7 @@ template<class T> class Beaver;
template<class T> class MaliciousRepPrepWithBits;
template<class T> class MaliciousRepPO;
template<class T> class MaliciousRepPrep;
template<class T> class SpecificPrivateOutput;

namespace GC
{
@@ -30,8 +31,8 @@ public:
    typedef HashMaliciousRepMC<MaliciousRep3Share<T>> MAC_Check;
    typedef MAC_Check Direct_MC;
    typedef ReplicatedInput<MaliciousRep3Share<T>> Input;
    typedef ::PrivateOutput<MaliciousRep3Share<T>> PrivateOutput;
    typedef MaliciousRepPO<MaliciousRep3Share> PO;
    typedef SpecificPrivateOutput<This> PrivateOutput;
    typedef Rep3Share<T> Honest;
    typedef MaliciousRepPrepWithBits<MaliciousRep3Share> LivePrep;
    typedef MaliciousRepPrep<MaliciousRep3Share> TriplePrep;

@@ -9,13 +9,14 @@
template<class T>
class MaliciousShamirPO
{
protected:
    Player& P;

    octetStream to_send;
    vector<octetStream> to_receive;

    vector<typename T::open_type> shares;
    MaliciousShamirMC<T> MC;
    typename T::Direct_MC MC;

public:
    MaliciousShamirPO(Player& P);

@@ -13,6 +13,7 @@
template<class T> class MaliciousRepPrepWithBits;
template<class T> class MaliciousRepPrep;
template<class T> class MaliciousShamirPO;
template<class T> class SpecificPrivateOutput;

namespace GC
{
@@ -23,14 +24,15 @@ template<class T>
class MaliciousShamirShare : public ShamirShare<T>
{
    typedef ShamirShare<T> super;
    typedef MaliciousShamirShare This;

public:
    typedef Beaver<MaliciousShamirShare<T>> Protocol;
    typedef MaliciousShamirMC<MaliciousShamirShare> MAC_Check;
    typedef MAC_Check Direct_MC;
    typedef ShamirInput<MaliciousShamirShare> Input;
    typedef ::PrivateOutput<MaliciousShamirShare> PrivateOutput;
    typedef MaliciousShamirPO<MaliciousShamirShare> PO;
    typedef SpecificPrivateOutput<This> PrivateOutput;
    typedef ShamirShare<T> Honest;
    typedef MaliciousRepPrepWithBits<MaliciousShamirShare> LivePrep;
    typedef MaliciousRepPrep<MaliciousShamirShare> TriplePrep;

@@ -76,12 +76,6 @@ public:
        return string(1, T::type_char());
    }

    static void read_or_generate_mac_key(string, Player&, mac_key_type& key)
    {
        SeededPRNG G;
        key.randomize(G);
    }

    MamaShare()
    {
    }

@@ -15,6 +15,7 @@ template<class T>
class PostSacriRepFieldShare : public MaliciousRep3Share<T>
{
    typedef MaliciousRep3Share<T> super;
    typedef PostSacriRepFieldShare This;

public:
    typedef typename super::clear clear;
@@ -23,7 +24,8 @@ public:
    typedef HashMaliciousRepMC<PostSacriRepFieldShare> MAC_Check;
    typedef MAC_Check Direct_MC;
    typedef ReplicatedInput<PostSacriRepFieldShare> Input;
    typedef ::PrivateOutput<PostSacriRepFieldShare> PrivateOutput;
    typedef ReplicatedPO<This> PO;
    typedef SpecificPrivateOutput<This> PrivateOutput;
    typedef MaliciousRepPrepWithBits<PostSacriRepFieldShare> LivePrep;

    PostSacriRepFieldShare()

@@ -17,6 +17,7 @@ template<int K, int S>
class PostSacriRepRingShare : public Rep3Share2<K>
{
    typedef Rep3Share2<K> super;
    typedef PostSacriRepRingShare This;

public:
    static const int BIT_LENGTH = K;
@@ -33,7 +34,8 @@ public:
    typedef HashMaliciousRepMC<PostSacriRepRingShare> MAC_Check;
    typedef MAC_Check Direct_MC;
    typedef ReplicatedInput<PostSacriRepRingShare> Input;
    typedef ::PrivateOutput<PostSacriRepRingShare> PrivateOutput;
    typedef ReplicatedPO<This> PO;
    typedef SpecificPrivateOutput<This> PrivateOutput;
    typedef MalRepRingPrepWithBits<PostSacriRepRingShare> LivePrep;

    typedef GC::MaliciousRepSecret bit_type;

@@ -42,8 +42,13 @@ public:
    {
    }

    ~ProtocolSet()
    /**
     * Run all protocol checks
     */
    void check()
    {
        protocol.check();
        output.Check(processor.P);
    }
};

@@ -73,6 +78,15 @@ public:
            *thread.protocol), input(output, prep, P)
    {
    }

    /**
     * Run all protocol checks
     */
    void check()
    {
        protocol.check();
        output.Check(protocol.P);
    }
};

/**
@@ -102,6 +116,15 @@ public:
            arithmetic.protocol), input(arithmetic.input)
    {
    }

    /**
     * Run all protocol checks
     */
    void check()
    {
        arithmetic.check();
        binary.check();
    }
};

#endif /* PROTOCOLS_PROTOCOLSET_H_ */

@@ -15,7 +15,8 @@

template<class T> class ReplicatedPrep;
template<class T> class ReplicatedRingPrep;
template<class T> class PrivateOutput;
template<class T> class ReplicatedPO;
template<class T> class SpecificPrivateOutput;

template<class T, int L>
class RepShare : public FixedVec<T, L>, public ShareInterface
@@ -99,6 +100,7 @@ template<class T>
class Rep3Share : public RepShare<T, 2>
{
    typedef RepShare<T, 2> super;
    typedef Rep3Share This;

public:
    typedef T clear;
@@ -107,7 +109,8 @@ public:
    typedef ReplicatedMC<Rep3Share> MAC_Check;
    typedef MAC_Check Direct_MC;
    typedef ReplicatedInput<Rep3Share> Input;
    typedef ::PrivateOutput<Rep3Share> PrivateOutput;
    typedef ReplicatedPO<This> PO;
    typedef SpecificPrivateOutput<This> PrivateOutput;
    typedef ReplicatedPrep<Rep3Share> LivePrep;
    typedef ReplicatedRingPrep<Rep3Share> TriplePrep;
    typedef Rep3Share Honest;

@@ -24,7 +24,8 @@ public:
    typedef ReplicatedMC<Rep3Share2> MAC_Check;
    typedef MAC_Check Direct_MC;
    typedef ReplicatedInput<Rep3Share2> Input;
    typedef ::PrivateOutput<Rep3Share2> PrivateOutput;
    typedef ReplicatedPO<This> PO;
    typedef SpecificPrivateOutput<This> PrivateOutput;
    typedef ReplicatedPrep2k<Rep3Share2> LivePrep;
    typedef Rep3Share2 Honest;
    typedef SignedZ2<K> clear;

@@ -31,7 +31,6 @@ public:
    void add_mine(const typename T::open_type& input, int n_bits = -1);
    void add_other(int player, int n_bits = -1);

    void send_mine();
    void exchange();

    T finalize_mine();

@@ -64,12 +64,6 @@ void Rep4Input<T>::add_other(int player, int)
    results[player].push_back(res);
}

template<class T>
void Rep4Input<T>::send_mine()
{
    throw not_implemented();
}

template<class T>
void Rep4Input<T>::exchange()
{

@@ -19,10 +19,6 @@ using namespace std;
template<class T> class SubProcessor;
template<class T> class ReplicatedMC;
template<class T> class ReplicatedInput;
template<class T> class ReplicatedPrivateOutput;
template<class T> class Share;
template<class T> class Rep3Share;
template<class T> class MAC_Check_Base;
template<class T> class Preprocessing;
class Instruction;

@@ -141,9 +137,6 @@ class Replicated : public ReplicatedBase, public ProtocolBase<T>
    void trunc_pr(const vector<int>& regs, int size, U& proc, false_type);

public:
    typedef ReplicatedMC<T> MAC_Check;
    typedef ReplicatedInput<T> Input;

    static const bool uses_triples = false;

    Replicated(Player& P);

@@ -10,6 +10,7 @@
#include "Processor/Processor.h"
#include "Processor/TruncPrTuple.h"
#include "Tools/benchmarking.h"
#include "Tools/Bundle.h"

#include "ReplicatedInput.h"
#include "Rep3Share2k.h"
@@ -162,14 +163,13 @@ void Replicated<T>::prepare_mul(const T& x,
}

template<class T>
inline void Replicated<T>::prepare_reshare(const typename T::clear& share,
void Replicated<T>::prepare_reshare(const typename T::clear& share,
        int n)
{
    auto add_share = share;
    typename T::value_type tmp[2];
    for (int i = 0; i < 2; i++)
        tmp[i].randomize(shared_prngs[i], n);
    add_share += tmp[0] - tmp[1];
    auto add_share = share + tmp[0] - tmp[1];
    add_share.pack(os[0], n);
    add_shares.push_back(add_share);
}

@@ -56,16 +56,24 @@ BufferPrep<T>::~BufferPrep()
            << " bit generation" << endl;
#endif

    auto field_type = T::clear::field_type();
    auto& my_usage = this->usage.files.at(field_type);

    this->print_left("triples", triples.size() * T::default_length, type_string,
            this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE)
            * T::default_length);

    size_t used_bits = my_usage.at(DATA_BIT);
    if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac)
        // add dabits with computation modulo power of two but without MAC
        used_bits += my_usage.at(DATA_DABIT);
    this->print_left("bits", bits.size(), type_string, used_bits);

#define X(KIND, TYPE) \
    this->print_left(#KIND, KIND.size(), type_string, \
            this->usage.files.at(T::clear::field_type()).at(TYPE));
    X(squares, DATA_SQUARE)
    X(inverses, DATA_INVERSE)
    X(bits, DATA_BIT)
    X(dabits, DATA_DABIT)
#undef X

@@ -601,17 +609,6 @@ void buffer_bits_from_players(vector<vector<T>>& player_bits,
    for (int i = 0; i < n_relevant_players; i++)
        for (auto& x : player_bits[i])
            x = input.finalize((base_player + i) % P.num_players(), n_bits);
#if !defined(__clang__) && (__GNUC__ == 6)
    // mitigate compiler bug
    Bundle<octetStream> bundle(P);
    P.unchecked_broadcast(bundle);
#endif
#ifdef DEBUG_BIT_SACRIFICE
    typename T::MAC_Check MC;
    for (int i = 0; i < n_relevant_players; i++)
        for (auto& x : player_bits[i])
            assert((MC.open(x, P) == 0) or (MC.open(x, P) == 1));
#endif
}

template<class T>
@@ -1164,18 +1161,18 @@ void BufferPrep<T>::buffer_inputs_as_usual(int player, SubProcessor<T>* proc)
            typename T::clear r;
            r.randomize(G);
            input.add_mine(r);
            this->inputs[player].push_back({input.finalize_mine(), r});
            this->inputs[player].push_back({input.finalize(player), r});
        }
        input.send_mine();
        input.exchange();
    }
    else
    {
        octetStream os;
        P.receive_player(player, os);
        T share;
        for (int i = 0; i < buffer_size; i++)
            input.add_other(player);
        input.exchange();
        for (int i = 0; i < buffer_size; i++)
        {
            input.finalize_other(player, share, os);
            auto share = input.finalize(player);
            this->inputs[player].push_back({share, 0});
        }
    }

@@ -1,26 +0,0 @@
/*
 * ReplicatedPrivateOutput.h
 *
 */

#ifndef PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_
#define PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_

template<class T>
class SubProcessor;
template<class T>
class Share;

template <class T>
class ReplicatedPrivateOutput
{
    SubProcessor<T>& proc;

public:
    ReplicatedPrivateOutput(SubProcessor<T>& proc);

    void start(int player, int target, int source);
    void stop(int player, int source);
};

#endif /* PROTOCOLS_REPLICATEDPRIVATEOUTPUT_H_ */
@@ -1,30 +0,0 @@
/*
 * ReplicatedPrivateOutput.cpp
 *
 */

#include "ReplicatedPrivateOutput.h"
#include "Processor/Processor.h"
#include "Math/FixedVec.h"
#include "Math/Integer.h"

template<class T>
inline ReplicatedPrivateOutput<T>::ReplicatedPrivateOutput(
        SubProcessor<T>& proc) :
    proc(proc)
{
}

template<class T>
void ReplicatedPrivateOutput<T>::start(int player, int target,
        int source)
{
    (void)player, (void)target, (void)source;
    throw runtime_error("not implemented, use PrivateOutput");
}

template<class T>
void ReplicatedPrivateOutput<T>::stop(int player, int source)
{
    (void)player, (void)source;
}
@@ -71,6 +71,12 @@ public:
                    proc.get_S()[info.source_base + i] >> info.m;
        }
    }

    void buffer_random()
    {
        for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
            this->random.push_back(G.get<T>());
    }
};

#endif /* PROTOCOLS_SEMI_H_ */

@@ -14,34 +14,33 @@ template<class T> class SemiMC;
 * Additive secret sharing input protocol
 */
template<class T>
class SemiInput : public IndividualInput<T>
class SemiInput : public InputBase<T>
{
    SeededPRNG secure_prng;
    vector<SeededPRNG> send_prngs;
    vector<PRNG> recv_prngs;
    Player& P;
    vector<PointerVector<T>> shares;

public:
    SemiInput(SubProcessor<T>& proc, SemiMC<T>& MC) :
            IndividualInput<T>(proc)
    SemiInput(SubProcessor<T>& proc, SemiMC<T>&) :
            SemiInput(&proc, proc.P)
    {
        (void) MC;
    }

    SemiInput(SubProcessor<T>* proc, Player& P) :
            IndividualInput<T>(proc, P)
    {
    }
    SemiInput(SubProcessor<T>* proc, Player& P);

    SemiInput(typename T::MAC_Check& MC, Preprocessing<T>& prep, Player& P) :
            SemiInput(P)
            SemiInput(0, P)
    {
        (void) MC, (void) prep;
    }

    SemiInput(Player& P) :
            IndividualInput<T>(0, P)
    {
    }

    void reset(int player);
    void add_mine(const typename T::clear& input, int n_bits = -1);
    void add_other(int player, int n_bits = -1);
    void exchange();
    void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);
    T finalize_mine();
};

#endif /* PROTOCOLS_SEMIINPUT_H_ */

@@ -11,22 +11,64 @@
#include "ShamirInput.hpp"

template<class T>
void SemiInput<T>::add_mine(const typename T::clear& input, int n_bits)
SemiInput<T>::SemiInput(SubProcessor<T>* proc, Player& P) :
        InputBase<T>(proc), P(P)
{
    shares.resize(P.num_players());
    vector<octetStream> to_send(P.num_players()), to_receive;
    for (int i = 0; i < P.num_players(); i++)
    {
        send_prngs.push_back({});
        to_send[i].append(send_prngs.back().get_seed(), SEED_SIZE);
    }
    P.send_receive_all(to_send, to_receive);
    recv_prngs.resize(P.num_players());
    for (int i = 0; i < P.num_players(); i++)
        if (i != P.my_num())
            recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE));
    this->reset_all(P);
}

template<class T>
void SemiInput<T>::reset(int player)
{
    shares[player].clear();
}

template<class T>
void SemiInput<T>::add_mine(const typename T::clear& input, int)
{
    auto& P = this->P;
    typename T::open_type sum, share;
    for (int i = 0; i < P.num_players(); i++)
    {
        if (i < P.num_players() - 1)
            share.randomize(secure_prng, n_bits);
        else
            share = input - sum;
        sum += share;
        if (i == P.my_num())
            this->shares.push_back(share);
        else
            share.pack(this->os[i], n_bits);
        if (i != P.my_num())
            sum += send_prngs[i].template get<typename T::open_type>();
    }
    shares[P.my_num()].push_back(input - sum);
}

template<class T>
void SemiInput<T>::add_other(int, int)
{
}

template<class T>
void SemiInput<T>::exchange()
{
}

template<class T>
void SemiInput<T>::finalize_other(int player, T& target, octetStream&,
        int)
{
    target = recv_prngs[player].template get<T>();
}

template<class T>
T SemiInput<T>::finalize_mine()
{
    return shares[P.my_num()].next();
}

#endif

@@ -27,7 +27,6 @@ class Shamir : public ProtocolBase<T>
{
    typedef typename T::open_type::Scalar U;

    octetStreams os;
    vector<U> reconstruction;
    U rec_factor;
    ShamirInput<T>* resharing;

@@ -69,8 +69,6 @@ int Shamir<T>::get_n_relevant_players()
template<class T>
void Shamir<T>::reset()
{
    os.reset(P);

    if (resharing == 0)
    {
        resharing = new ShamirInput<T>(0, P);
@@ -78,6 +76,9 @@ void Shamir<T>::reset()

    for (int i = 0; i < P.num_players(); i++)
        resharing->reset(i);

    for (int i = 0; i < n_mul_players; i++)
        resharing->add_sender(i);
}

template<class T>
@@ -92,37 +93,27 @@ template<class T>
void Shamir<T>::prepare_mul(const T& x, const T& y, int n)
{
    (void) n;
    auto add_share = x * y * rec_factor;
    if (P.my_num() < n_mul_players)
        resharing->add_mine(add_share);
        resharing->add_mine(x * y * rec_factor);
}

template<class T>
void Shamir<T>::exchange()
{
    vector<bool> senders(P.num_players(), false);
    for (int i = 0; i < n_mul_players; i++)
        senders[i] = true;
    P.send_receive_all(senders, resharing->os, os);
    assert(resharing);
    resharing->exchange();
}

template<class T>
void Shamir<T>::start_exchange()
{
    if (P.my_num() < n_mul_players)
        for (int offset = 1; offset < P.num_players(); offset++)
            P.send_relative(offset, resharing->os[P.get_player(offset)]);
    resharing->start_exchange();
}

template<class T>
void Shamir<T>::stop_exchange()
{
    for (int offset = 1; offset < P.num_players(); offset++)
    {
        int receive_from = P.get_player(-offset);
        if (receive_from < n_mul_players)
            P.receive_player(receive_from, os[receive_from]);
    }
    resharing->stop_exchange();
}

template<class T>
@@ -136,15 +127,8 @@ template<class T>
T Shamir<T>::finalize(int n_relevant_players)
{
    ShamirShare<U> res = U(0);
    if (P.my_num() < n_relevant_players)
        res = resharing->finalize_mine();
    for (int i = 0; i < n_relevant_players; i++)
        if (i != P.my_num())
        {
            T tmp;
            resharing->finalize_other(i, tmp, os[i]);
            res += tmp;
        }
        res += resharing->finalize(i);
    return res;
}

@@ -259,7 +243,7 @@ vector<T> Shamir<T>::get_randoms(PRNG& G, int t)
    input.reset_all(P);
    int buffer_size = OnlineOptions::singleton.batch_size;
    for (int i = 0; i < buffer_size; i += hyper.size())
        input.add_mine(G.get<U>());
        input.add_from_all(G.get<U>());
    input.exchange();
    vector<U> inputs;
    vector<T> random;

@@ -21,10 +21,11 @@ class IndividualInput : public PrepLessInput<T>
protected:
    Player& P;
    octetStreams os;
    vector<bool> senders;

public:
    IndividualInput(SubProcessor<T>* proc, Player& P) :
            PrepLessInput<T>(proc), P(P)
            PrepLessInput<T>(proc), P(P), senders(P.num_players())
    {
        this->reset_all(P);
    }
@@ -34,10 +35,14 @@ public:
    }

    void reset(int player);
    void add_sender(int player);
    void add_other(int player, int n_bits = -1);
    void send_mine();
    void exchange();
    void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);

    void start_exchange();
    void stop_exchange();
};

/**

@@ -20,6 +20,8 @@ void IndividualInput<U>::reset(int player)
    this->i_share = 0;
    os.reset(P);
}

    senders[player] = false;
}

template<class T>
@@ -68,12 +70,20 @@ void ShamirInput<T>::add_mine(const typename T::open_type& input, int n_bits)
        else
            x.pack(this->os[i]);
    }

    this->senders[P.my_num()] = true;
}

template<class U>
void IndividualInput<U>::add_sender(int player)
{
    senders[player] = true;
}

template<class U>
void IndividualInput<U>::add_other(int player, int)
{
    (void) player;
    add_sender(player);
}

template<class U>
@@ -87,7 +97,26 @@ void IndividualInput<U>::send_mine()
template<class T>
void IndividualInput<T>::exchange()
{
    P.send_receive_all(os, InputBase<T>::os);
    P.send_receive_all(senders, os, InputBase<T>::os);
}

template<class T>
void IndividualInput<T>::start_exchange()
{
    if (senders[P.my_num()])
        for (int offset = 1; offset < P.num_players(); offset++)
            P.send_relative(offset, os[P.get_player(offset)]);
}

template<class T>
void IndividualInput<T>::stop_exchange()
{
    for (int offset = 1; offset < P.num_players(); offset++)
    {
        int receive_from = P.get_player(-offset);
        if (senders[receive_from])
            P.receive_player(receive_from, InputBase<T>::os[receive_from]);
    }
}

template<class T>