mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 05:03:59 -05:00
Bristol Fashion.
This commit is contained in:
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -4,3 +4,6 @@
|
||||
[submodule "mpir"]
|
||||
path = mpir
|
||||
url = git://github.com/wbhart/mpir.git
|
||||
[submodule "Programs/Circuits"]
|
||||
path = Programs/Circuits
|
||||
url = https://github.com/mkskeller/bristol-fashion
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.1.6 (Apr 2, 2020)
|
||||
|
||||
- Bristol Fashion circuits
|
||||
- Semi-honest computation with somewhat homomorphic encryption
|
||||
- Use SSL for client connections
|
||||
- Client facilities for all arithmetic protocols
|
||||
|
||||
## 0.1.5 (Mar 20, 2020)
|
||||
|
||||
- Faster conversion between arithmetic and binary secret sharing using [extended daBits](https://eprint.iacr.org/2020/338)
|
||||
|
||||
@@ -67,19 +67,33 @@ class BinaryVectorInstruction(base.Instruction):
|
||||
def copy(self, size, subs):
|
||||
return type(self)(*self.get_new_args(size, subs))
|
||||
|
||||
class NonVectorInstruction(base.Instruction):
|
||||
is_vec = lambda self: False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert(args[0].n <= args[0].unit)
|
||||
super(NonVectorInstruction, self).__init__(*args, **kwargs)
|
||||
|
||||
class NonVectorInstruction1(base.Instruction):
|
||||
is_vec = lambda self: False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert(args[1].n <= args[1].unit)
|
||||
super(NonVectorInstruction1, self).__init__(*args, **kwargs)
|
||||
|
||||
class xors(BinaryVectorInstruction):
|
||||
code = opcodes['XORS']
|
||||
arg_format = tools.cycle(['int','sbw','sb','sb'])
|
||||
|
||||
class xorm(base.Instruction):
|
||||
class xorm(NonVectorInstruction):
|
||||
code = opcodes['XORM']
|
||||
arg_format = ['int','sbw','sb','cb']
|
||||
|
||||
class xorcb(base.Instruction):
|
||||
class xorcb(NonVectorInstruction):
|
||||
code = opcodes['XORCB']
|
||||
arg_format = ['cbw','cb','cb']
|
||||
|
||||
class xorcbi(base.Instruction):
|
||||
class xorcbi(NonVectorInstruction):
|
||||
code = opcodes['XORCBI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
@@ -101,47 +115,48 @@ class andm(BinaryVectorInstruction):
|
||||
code = opcodes['ANDM']
|
||||
arg_format = ['int','sbw','sb','cb']
|
||||
|
||||
class addcb(base.Instruction):
|
||||
class addcb(NonVectorInstruction):
|
||||
code = opcodes['ADDCB']
|
||||
arg_format = ['cbw','cb','cb']
|
||||
|
||||
class addcbi(base.Instruction):
|
||||
class addcbi(NonVectorInstruction):
|
||||
code = opcodes['ADDCBI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class mulcbi(base.Instruction):
|
||||
class mulcbi(NonVectorInstruction):
|
||||
code = opcodes['MULCBI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class bitdecs(base.VarArgsInstruction):
|
||||
class bitdecs(NonVectorInstruction, base.VarArgsInstruction):
|
||||
code = opcodes['BITDECS']
|
||||
arg_format = tools.chain(['sb'], itertools.repeat('sbw'))
|
||||
|
||||
class bitcoms(base.VarArgsInstruction):
|
||||
class bitcoms(NonVectorInstruction, base.VarArgsInstruction):
|
||||
code = opcodes['BITCOMS']
|
||||
arg_format = tools.chain(['sbw'], itertools.repeat('sb'))
|
||||
|
||||
class bitdecc(base.VarArgsInstruction):
|
||||
class bitdecc(NonVectorInstruction, base.VarArgsInstruction):
|
||||
code = opcodes['BITDECC']
|
||||
arg_format = tools.chain(['cb'], itertools.repeat('cbw'))
|
||||
|
||||
class shrcbi(base.Instruction):
|
||||
class shrcbi(NonVectorInstruction):
|
||||
code = opcodes['SHRCBI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class shlcbi(base.Instruction):
|
||||
class shlcbi(NonVectorInstruction):
|
||||
code = opcodes['SHLCBI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class ldbits(base.Instruction):
|
||||
class ldbits(NonVectorInstruction):
|
||||
code = opcodes['LDBITS']
|
||||
arg_format = ['sbw','i','i']
|
||||
|
||||
class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
|
||||
base.VectorInstruction):
|
||||
code = opcodes['LDMSB']
|
||||
arg_format = ['sbw','int']
|
||||
|
||||
class stmsb(base.DirectMemoryWriteInstruction):
|
||||
class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
|
||||
code = opcodes['STMSB']
|
||||
arg_format = ['sb','int']
|
||||
# def __init__(self, *args, **kwargs):
|
||||
@@ -149,19 +164,20 @@ class stmsb(base.DirectMemoryWriteInstruction):
|
||||
# import inspect
|
||||
# self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
|
||||
class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
|
||||
base.VectorInstruction):
|
||||
code = opcodes['LDMCB']
|
||||
arg_format = ['cbw','int']
|
||||
|
||||
class stmcb(base.DirectMemoryWriteInstruction):
|
||||
class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
|
||||
code = opcodes['STMCB']
|
||||
arg_format = ['cb','int']
|
||||
|
||||
class ldmsbi(base.ReadMemoryInstruction):
|
||||
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
|
||||
code = opcodes['LDMSBI']
|
||||
arg_format = ['sbw','ci']
|
||||
|
||||
class stmsbi(base.WriteMemoryInstruction):
|
||||
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
|
||||
code = opcodes['STMSBI']
|
||||
arg_format = ['sb','ci']
|
||||
|
||||
@@ -185,15 +201,15 @@ class stmsdci(base.WriteMemoryInstruction):
|
||||
code = opcodes['STMSDCI']
|
||||
arg_format = tools.cycle(['cb','cb'])
|
||||
|
||||
class convsint(base.Instruction):
|
||||
class convsint(NonVectorInstruction1):
|
||||
code = opcodes['CONVSINT']
|
||||
arg_format = ['int','sbw','ci']
|
||||
|
||||
class convcint(base.Instruction):
|
||||
class convcint(NonVectorInstruction):
|
||||
code = opcodes['CONVCINT']
|
||||
arg_format = ['cbw','ci']
|
||||
|
||||
class convcbit(base.Instruction):
|
||||
class convcbit(NonVectorInstruction1):
|
||||
code = opcodes['CONVCBIT']
|
||||
arg_format = ['ciw','cb']
|
||||
|
||||
@@ -222,18 +238,19 @@ class split(base.Instruction):
|
||||
super(split_class, self).__init__(*args, **kwargs)
|
||||
assert (len(args) - 2) % args[0] == 0
|
||||
|
||||
class movsb(base.Instruction):
|
||||
class movsb(NonVectorInstruction):
|
||||
code = opcodes['MOVSB']
|
||||
arg_format = ['sbw','sb']
|
||||
|
||||
class trans(base.VarArgsInstruction):
|
||||
code = opcodes['TRANS']
|
||||
is_vec = lambda self: True
|
||||
def __init__(self, *args):
|
||||
self.arg_format = ['int'] + ['sbw'] * args[0] + \
|
||||
['sb'] * (len(args) - 1 - args[0])
|
||||
super(trans, self).__init__(*args)
|
||||
|
||||
class bitb(base.Instruction):
|
||||
class bitb(NonVectorInstruction):
|
||||
code = opcodes['BITB']
|
||||
arg_format = ['sbw']
|
||||
|
||||
@@ -245,20 +262,22 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
|
||||
__slots__ = []
|
||||
code = opcodes['INPUTB']
|
||||
arg_format = tools.cycle(['p','int','int','sbw'])
|
||||
is_vec = lambda self: True
|
||||
|
||||
class print_regb(base.IOInstruction):
|
||||
class print_regb(base.VectorInstruction, base.IOInstruction):
|
||||
code = opcodes['PRINTREGB']
|
||||
arg_format = ['cb','i']
|
||||
def __init__(self, reg, comment=''):
|
||||
super(print_regb, self).__init__(reg, self.str_to_int(comment))
|
||||
|
||||
class print_reg_plainb(base.IOInstruction):
|
||||
class print_reg_plainb(NonVectorInstruction, base.IOInstruction):
|
||||
code = opcodes['PRINTREGPLAINB']
|
||||
arg_format = ['cb']
|
||||
|
||||
class print_reg_signed(base.IOInstruction):
|
||||
code = opcodes['PRINTREGSIGNED']
|
||||
arg_format = ['int','cb']
|
||||
is_vec = lambda self: True
|
||||
|
||||
class print_float_plainb(base.IOInstruction):
|
||||
__slots__ = []
|
||||
|
||||
@@ -77,6 +77,9 @@ class bits(Tape.Register, _structure, _bit):
|
||||
def n_elements():
|
||||
return 1
|
||||
@classmethod
|
||||
def mem_size(cls):
|
||||
return math.ceil(cls.n / cls.unit)
|
||||
@classmethod
|
||||
def load_mem(cls, address, mem_type=None, size=None):
|
||||
if size not in (None, 1):
|
||||
v = [cls.load_mem(address + i) for i in range(size)]
|
||||
@@ -101,9 +104,8 @@ class bits(Tape.Register, _structure, _bit):
|
||||
def copy(self):
|
||||
return type(self)(n=instructions_base.get_global_vector_size())
|
||||
def set_length(self, n):
|
||||
if n > self.max_length:
|
||||
print(self.max_length)
|
||||
raise Exception('too long: %d' % n)
|
||||
if n > self.n:
|
||||
raise Exception('too long: %d/%d' % (n, self.n))
|
||||
self.n = n
|
||||
def set_size(self, size):
|
||||
pass
|
||||
@@ -135,7 +137,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
if self.n != None:
|
||||
suffix = '%d' % self.n
|
||||
if type(self).n != None and type(self).n != self.n:
|
||||
suffice += '/%d' % type(self).n
|
||||
suffix += '/%d' % type(self).n
|
||||
else:
|
||||
suffix = 'undef'
|
||||
return '%s(%s)' % (super(bits, self).__repr__(), suffix)
|
||||
@@ -237,6 +239,7 @@ class sbits(bits):
|
||||
bitdec = inst.bitdecs
|
||||
bitcom = inst.bitcoms
|
||||
conv_regint = inst.convsint
|
||||
one_cache = {}
|
||||
@classmethod
|
||||
def conv_regint_by_bit(cls, n, res, other):
|
||||
tmp = cbits.get_type(n)()
|
||||
@@ -285,14 +288,12 @@ class sbits(bits):
|
||||
% (value, self.n))
|
||||
if self.n <= 32:
|
||||
inst.ldbits(self, self.n, value)
|
||||
elif self.n <= 64:
|
||||
self.load_other(regint(value, size=1))
|
||||
elif self.n <= 128:
|
||||
lower = sbits.get_type(64)(value % 2**64)
|
||||
upper = sbits.get_type(self.n - 64)(value >> 64)
|
||||
self.mov(self, lower + (upper << 64))
|
||||
else:
|
||||
raise NotImplementedError('more than 128 bits wanted')
|
||||
size = math.ceil(self.n / self.unit)
|
||||
tmp = regint(size=size)
|
||||
for i in range(size):
|
||||
tmp[i].load_int((value >> (i * 64)) % 2**64)
|
||||
self.load_other(tmp)
|
||||
def load_other(self, other):
|
||||
if isinstance(other, cbits) and self.n == other.n:
|
||||
inst.convcbit2s(self.n, self, other)
|
||||
@@ -393,11 +394,10 @@ class sbits(bits):
|
||||
# res = type(self)(n=self.n)
|
||||
# inst.nots(res, self)
|
||||
# return res
|
||||
if self.n == None or self.n > self.unit:
|
||||
one = self.get_type(self.n)()
|
||||
self.conv_regint_by_bit(self.n, one, regint(1, size=self.n))
|
||||
else:
|
||||
one = self.new(value=self.long_one(), n=self.n)
|
||||
key = self.n, library.get_block()
|
||||
if key not in self.one_cache:
|
||||
self.one_cache[key] = self.new(value=self.long_one(), n=self.n)
|
||||
one = self.one_cache[key]
|
||||
return self + one
|
||||
def __neg__(self):
|
||||
return self
|
||||
@@ -432,12 +432,12 @@ class sbits(bits):
|
||||
@classmethod
|
||||
def trans(cls, rows):
|
||||
rows = list(rows)
|
||||
if len(rows) == 1:
|
||||
if len(rows) == 1 and rows[0].n <= rows[0].unit:
|
||||
return rows[0].bit_decompose()
|
||||
n_columns = rows[0].n
|
||||
for row in rows:
|
||||
assert(row.n == n_columns)
|
||||
if n_columns == 1:
|
||||
if n_columns == 1 and len(rows) <= cls.unit:
|
||||
return [cls.bit_compose(rows)]
|
||||
else:
|
||||
res = [cls.new(n=len(rows)) for i in range(n_columns)]
|
||||
@@ -452,6 +452,10 @@ class sbits(bits):
|
||||
@staticmethod
|
||||
def ripple_carry_adder(*args, **kwargs):
|
||||
return sbitint.ripple_carry_adder(*args, **kwargs)
|
||||
def to_sint(self, n_bits):
|
||||
bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0]
|
||||
bits = sint(bits, size=n_bits)
|
||||
return sint.bit_compose(bits)
|
||||
|
||||
class sbitvec(_vec):
|
||||
@classmethod
|
||||
@@ -524,6 +528,8 @@ class sbitvec(_vec):
|
||||
return iter(self.v)
|
||||
def __len__(self):
|
||||
return len(self.v)
|
||||
def __getitem__(self, index):
|
||||
return self.v[index]
|
||||
@classmethod
|
||||
def conv(cls, other):
|
||||
return cls.from_vec(other.v)
|
||||
|
||||
@@ -210,85 +210,6 @@ class Merger:
|
||||
max_depth_of[v] = min(max_depth_of[u], max_depth_of[v])
|
||||
return max_depth_of
|
||||
|
||||
def merge_inputs(self):
|
||||
merges = defaultdict(list)
|
||||
remaining_input_nodes = []
|
||||
def do_merge(nodes):
|
||||
if len(nodes) > 1000:
|
||||
print('Merging %d inputs...' % len(nodes))
|
||||
self.do_merge(iter(nodes))
|
||||
for n in self.input_nodes:
|
||||
inst = self.instructions[n]
|
||||
merge = merges[inst.args[0],inst.__class__]
|
||||
if len(merge) == 0:
|
||||
remaining_input_nodes.append(n)
|
||||
merge.append(n)
|
||||
if len(merge) >= self.max_parallel_open:
|
||||
do_merge(merge)
|
||||
merge[:] = []
|
||||
for merge in reversed(sorted(merges.values())):
|
||||
if merge:
|
||||
do_merge(merge)
|
||||
self.input_nodes = remaining_input_nodes
|
||||
|
||||
def compute_preorder(self, merges, rev_depth_of):
|
||||
# find flexible nodes that can be on several levels
|
||||
# and find sources on level 0
|
||||
G = self.G
|
||||
merge_nodes_set = self.open_nodes
|
||||
depth_of = self.depths
|
||||
instructions = self.instructions
|
||||
flex_nodes = defaultdict(dict)
|
||||
starters = []
|
||||
for n in range(len(G)):
|
||||
if n not in merge_nodes_set and \
|
||||
depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction):
|
||||
#print n, depth_of[n], rev_depth_of[n]
|
||||
flex_nodes[depth_of[n]].setdefault(rev_depth_of[n], set()).add(n)
|
||||
elif len(G.pred[n]) == 0 and \
|
||||
not isinstance(self.instructions[n], RawInputInstruction):
|
||||
starters.append(n)
|
||||
if n % 10000000 == 0 and n > 0:
|
||||
print("Processed %d nodes at" % n, time.asctime())
|
||||
|
||||
inputs = defaultdict(list)
|
||||
for node in self.input_nodes:
|
||||
player = self.instructions[node].args[0]
|
||||
inputs[player].append(node)
|
||||
first_inputs = [l[0] for l in inputs.values()]
|
||||
other_inputs = []
|
||||
i = 0
|
||||
while True:
|
||||
i += 1
|
||||
found = False
|
||||
for l in inputs.values():
|
||||
if i < len(l):
|
||||
other_inputs.append(l[i])
|
||||
found = True
|
||||
if not found:
|
||||
break
|
||||
other_inputs.reverse()
|
||||
|
||||
preorder = []
|
||||
# magical preorder for topological search
|
||||
max_depth = max(merges)
|
||||
if max_depth > 10000:
|
||||
print("Computing pre-ordering ...")
|
||||
for i in range(max_depth, 0, -1):
|
||||
preorder.append(G.get_attr(merges[i], 'stop'))
|
||||
for j in flex_nodes[i-1].values():
|
||||
preorder.extend(j)
|
||||
preorder.extend(flex_nodes[0].get(i, []))
|
||||
preorder.append(merges[i])
|
||||
if i % 100000 == 0 and i > 0:
|
||||
print("Done level %d at" % i, time.asctime())
|
||||
preorder.extend(other_inputs)
|
||||
preorder.extend(starters)
|
||||
preorder.extend(first_inputs)
|
||||
if max_depth > 10000:
|
||||
print("Done at", time.asctime())
|
||||
return preorder
|
||||
|
||||
def longest_paths_merge(self):
|
||||
""" Attempt to merge instructions of type instruction_type (which are given in
|
||||
merge_nodes) using longest paths algorithm.
|
||||
@@ -301,7 +222,7 @@ class Merger:
|
||||
instructions = self.instructions
|
||||
merge_nodes = self.open_nodes
|
||||
depths = self.depths
|
||||
if not merge_nodes and not self.input_nodes:
|
||||
if not merge_nodes:
|
||||
return 0
|
||||
|
||||
# merge opens at same depth
|
||||
@@ -321,8 +242,6 @@ class Merger:
|
||||
(len(merge), t.__name__, i, len(merges)))
|
||||
self.do_merge(merge)
|
||||
|
||||
self.merge_inputs()
|
||||
|
||||
preorder = None
|
||||
|
||||
if len(instructions) > 100000:
|
||||
@@ -340,7 +259,6 @@ class Merger:
|
||||
options = self.options
|
||||
open_nodes = set()
|
||||
self.open_nodes = open_nodes
|
||||
self.input_nodes = []
|
||||
colordict = defaultdict(lambda: 'gray', asm_open='red',\
|
||||
ldi='lightblue', ldm='lightblue', stm='blue',\
|
||||
mov='yellow', mulm='orange', mulc='orange',\
|
||||
@@ -507,14 +425,7 @@ class Merger:
|
||||
elif isinstance(instr, PublicFileIOInstruction):
|
||||
keep_order(instr, n, instr.__class__)
|
||||
elif isinstance(instr, RawInputInstruction):
|
||||
keep_order(instr, n, instr.__class__, 0)
|
||||
self.input_nodes.append(n)
|
||||
G.add_node(n, merges=[])
|
||||
player = instr.args[0]
|
||||
if isinstance(instr, stopinput):
|
||||
add_edge(last[startinput_class][player], n)
|
||||
elif isinstance(instr, gstopinput):
|
||||
add_edge(last[gstartinput][player], n)
|
||||
keep_order(instr, n, instr.__class__)
|
||||
elif isinstance(instr, startprivateoutput_class):
|
||||
keep_order(instr, n, startprivateoutput_class, 2)
|
||||
elif isinstance(instr, stopprivateoutput_class):
|
||||
@@ -559,18 +470,14 @@ class Merger:
|
||||
unused_result = not G.degree(i) and len(list(inst.get_def())) \
|
||||
and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \
|
||||
and not isinstance(inst, (DoNotEliminateInstruction))
|
||||
stop_node = G.get_attr(i, 'stop')
|
||||
unused_startopen = stop_node != -1 and instructions[stop_node] is None
|
||||
def eliminate(i):
|
||||
G.remove_node(i)
|
||||
merge_nodes.discard(i)
|
||||
stats[type(instructions[i]).__name__] += 1
|
||||
instructions[i] = None
|
||||
if unused_result or unused_startopen:
|
||||
if unused_result:
|
||||
eliminate(i)
|
||||
count += 1
|
||||
if unused_startopen:
|
||||
open_count += len(inst.args)
|
||||
# remove unnecessary stack instructions
|
||||
# left by optimization with budget
|
||||
if isinstance(inst, popint_class) and \
|
||||
|
||||
201
Compiler/circuit.py
Normal file
201
Compiler/circuit.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
This module contains functionality using circuits in the so-called
|
||||
`Bristol Fashion`_ format. You can download a few examples including
|
||||
the ones used below into ``Programs/Circuits`` as follows::
|
||||
|
||||
make Programs/Circuits
|
||||
|
||||
.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC
|
||||
|
||||
"""
|
||||
|
||||
from Compiler.GC.types import sbitvec, sbits
|
||||
from Compiler.library import function_block
|
||||
from Compiler import util
|
||||
import itertools
|
||||
|
||||
class Circuit:
|
||||
"""
|
||||
Use a Bristol Fashion circuit in a high-level program. The
|
||||
following example adds signed 64-bit inputs from two different
|
||||
parties and prints the result::
|
||||
|
||||
from circuit import Circuit
|
||||
sb64 = sbits.get_type(64)
|
||||
adder = Circuit('adder64')
|
||||
a, b = [sbitvec(sb64.get_input_from(i)) for i in (0, 1)]
|
||||
print_ln('%s', adder(a, b).elements()[0].reveal())
|
||||
|
||||
Circuits can also be executed in parallel as the following example
|
||||
shows::
|
||||
|
||||
from circuit import Circuit
|
||||
sb128 = sbits.get_type(128)
|
||||
key = sb128(0x2b7e151628aed2a6abf7158809cf4f3c)
|
||||
plaintext = sb128(0x6bc1bee22e409f96e93d7e117393172a)
|
||||
n = 1000
|
||||
aes128 = Circuit('aes_128')
|
||||
ciphertexts = aes128(sbitvec([key] * n), sbitvec([plaintext] * n))
|
||||
ciphertexts.elements()[n - 1].reveal().print_reg()
|
||||
|
||||
This executes AES-128 1000 times in parallel and then outputs the
|
||||
last result, which should be ``0x3ad77bb40d7a3660a89ecaf32466ef97``,
|
||||
one of the test vectors for AES-128.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self.filename = 'Programs/Circuits/%s.txt' % name
|
||||
f = open(self.filename)
|
||||
self.functions = {}
|
||||
|
||||
def __call__(self, *inputs):
|
||||
return self.run(*inputs)
|
||||
|
||||
def run(self, *inputs):
|
||||
n = inputs[0][0].n
|
||||
if n not in self.functions:
|
||||
self.functions[n] = function_block(lambda *args:
|
||||
self.compile(*args))
|
||||
flat_res = self.functions[n](*itertools.chain(*inputs))
|
||||
res = []
|
||||
i = 0
|
||||
for l in self.n_output_wires:
|
||||
v = []
|
||||
for i in range(l):
|
||||
v.append(flat_res[i])
|
||||
i += 1
|
||||
res.append(sbitvec.from_vec(v))
|
||||
return util.untuplify(res)
|
||||
|
||||
def compile(self, *all_inputs):
|
||||
f = open(self.filename)
|
||||
lines = iter(f)
|
||||
next_line = lambda: next(lines).split()
|
||||
n_gates, n_wires = (int(x) for x in next_line())
|
||||
self.n_wires = n_wires
|
||||
input_line = [int(x) for x in next_line()]
|
||||
n_inputs = input_line[0]
|
||||
n_input_wires = input_line[1:]
|
||||
assert(n_inputs == len(n_input_wires))
|
||||
inputs = []
|
||||
s = 0
|
||||
for n in n_input_wires:
|
||||
inputs.append(all_inputs[s:s + n])
|
||||
s += n
|
||||
output_line = [int(x) for x in next_line()]
|
||||
n_outputs = output_line[0]
|
||||
self.n_output_wires = output_line[1:]
|
||||
assert(n_outputs == len(self.n_output_wires))
|
||||
next(lines)
|
||||
|
||||
wires = [None] * n_wires
|
||||
self.wires = wires
|
||||
i_wire = 0
|
||||
for input, input_wires in zip(inputs, n_input_wires):
|
||||
assert(len(input) == input_wires)
|
||||
for i, reg in enumerate(input):
|
||||
wires[i_wire] = reg
|
||||
i_wire += 1
|
||||
|
||||
for i in range(n_gates):
|
||||
line = next_line()
|
||||
t = line[-1]
|
||||
if t in ('XOR', 'AND'):
|
||||
assert line[0] == '2'
|
||||
assert line[1] == '1'
|
||||
assert len(line) == 6
|
||||
ins = [wires[int(line[2 + i])] for i in range(2)]
|
||||
if t == 'XOR':
|
||||
wires[int(line[4])] = ins[0] ^ ins[1]
|
||||
else:
|
||||
wires[int(line[4])] = ins[0] & ins[1]
|
||||
elif t == 'INV':
|
||||
assert line[0] == '1'
|
||||
assert line[1] == '1'
|
||||
assert len(line) == 5
|
||||
wires[int(line[3])] = ~wires[int(line[2])]
|
||||
|
||||
return self.wires[-sum(self.n_output_wires):]
|
||||
|
||||
Keccak_f = None
|
||||
|
||||
def sha3_256(x):
|
||||
"""
|
||||
This function implements SHA3-256 for inputs of up to 1080 bits::
|
||||
|
||||
from circuit import sha3_256
|
||||
a = sbitvec.from_vec([])
|
||||
b = sbitvec(sint(0xcc), 8)
|
||||
for x in a, b:
|
||||
sha3_256(x).elements()[0].reveal().print_reg()
|
||||
|
||||
This should output the first two test vectors of SHA3-256 in
|
||||
byte-reversed order::
|
||||
|
||||
0x5375f6fb6aa989b0c287a923afe81e79ff875921cacc956666d71ebff8c6ffa7
|
||||
0x17c7e0d65c285af8406d4f21c071851a312b739a8ecdf25c1270d31c39357067
|
||||
|
||||
Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
|
||||
implemented for computation modulo a power of two.
|
||||
"""
|
||||
|
||||
global Keccak_f
|
||||
if Keccak_f is None:
|
||||
# only one instance
|
||||
Keccak_f = Circuit('Keccak_f')
|
||||
|
||||
# whole bytes
|
||||
assert len(x.v) % 8 == 0
|
||||
# only one block
|
||||
r = 1088
|
||||
assert len(x.v) < 1088
|
||||
if x.v:
|
||||
n = x.v[0].n
|
||||
else:
|
||||
n = 1
|
||||
d = sbitvec([sbits.get_type(8)(0x06)] * n)
|
||||
sbn = sbits.get_type(n)
|
||||
padding = [sbn(0)] * (r - 8 - len(x.v))
|
||||
P_flat = x.v + d.v + padding
|
||||
assert len(P_flat) == r
|
||||
P_flat[-1] = ~P_flat[-1]
|
||||
w = 64
|
||||
P1 = [P_flat[i * w:(i + 1) * w] for i in range(r // w)]
|
||||
|
||||
S = [[[sbn(0) for i in range(w)] for i in range(5)] for i in range(5)]
|
||||
for x in range(5):
|
||||
for y in range(5):
|
||||
if x + 5 * y < r // w:
|
||||
for i in range(w):
|
||||
S[x][y][i] ^= P1[x + 5 * y][i]
|
||||
|
||||
def flatten(S):
|
||||
res = [None] * 1600
|
||||
for y in range(5):
|
||||
for x in range(5):
|
||||
for i in range(w):
|
||||
j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
|
||||
res[1600 - 1 - j] = S[x][y][i]
|
||||
return res
|
||||
|
||||
def unflatten(S_flat):
|
||||
res = [[[None] * w for j in range(5)] for i in range(5)]
|
||||
for y in range(5):
|
||||
for x in range(5):
|
||||
for i in range(w):
|
||||
j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
|
||||
res[x][y][i] = S_flat[1600 - 1 -j]
|
||||
return res
|
||||
|
||||
S = unflatten(Keccak_f(flatten(S)))
|
||||
|
||||
Z = []
|
||||
while len(Z) <= 256:
|
||||
for y in range(5):
|
||||
for x in range(5):
|
||||
if x + 5 * y < r // w:
|
||||
Z += S[y][x]
|
||||
if len(Z) <= 256:
|
||||
S = unflatten(Keccak_f(flatten(S)))
|
||||
return sbitvec.from_vec(Z[:256])
|
||||
@@ -262,6 +262,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
return r_dprime, r_prime, c, c_prime, u, t, c2k1
|
||||
|
||||
def MaskingBitsInRing(m, strict=False):
|
||||
program.curr_tape.require_bit_length(1)
|
||||
from Compiler.types import sint
|
||||
if program.use_edabit():
|
||||
return sint.get_edabit(m, strict)
|
||||
|
||||
@@ -9,19 +9,18 @@ import time
|
||||
import sys
|
||||
|
||||
|
||||
def run(args, options, param=-1, merge_opens=True, emulate=True, \
|
||||
reallocate=True, assemblymode=False, debug=False):
|
||||
def run(args, options, param=-1, merge_opens=True,
|
||||
reallocate=True, debug=False):
|
||||
""" Compile a file and output a Program object.
|
||||
|
||||
If merge_opens is set to True, will attempt to merge any parallelisable open
|
||||
instructions. """
|
||||
|
||||
prog = Program(args, options, param, assemblymode)
|
||||
prog = Program(args, options, param)
|
||||
instructions.program = prog
|
||||
instructions_base.program = prog
|
||||
types.program = prog
|
||||
comparison.program = prog
|
||||
prog.EMULATE = emulate
|
||||
prog.DEBUG = debug
|
||||
VARS['program'] = prog
|
||||
if options.binary:
|
||||
@@ -31,26 +30,9 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
|
||||
|
||||
print('Compiling file', prog.infile)
|
||||
|
||||
# no longer needed, but may want to support assembly in future (?)
|
||||
if assemblymode:
|
||||
prog.restart_main_thread()
|
||||
for i in range(INIT_REG_MAX):
|
||||
VARS['c%d'%i] = prog.curr_block.new_reg('c')
|
||||
VARS['s%d'%i] = prog.curr_block.new_reg('s')
|
||||
VARS['cg%d'%i] = prog.curr_block.new_reg('cg')
|
||||
VARS['sg%d'%i] = prog.curr_block.new_reg('sg')
|
||||
if i % 10000000 == 0 and i > 0:
|
||||
print("Initialized %d register variables at" % i, time.asctime())
|
||||
|
||||
# first pass determines how many assembler registers are used
|
||||
prog.FIRST_PASS = True
|
||||
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
|
||||
|
||||
if instructions_base.Instruction.count != 0:
|
||||
print('instructions count', instructions_base.Instruction.count)
|
||||
instructions_base.Instruction.count = 0
|
||||
prog.FIRST_PASS = False
|
||||
prog.reset_values()
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, 'Compiler')
|
||||
# create the tapes
|
||||
@@ -60,17 +42,14 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
|
||||
for tape in prog.tapes:
|
||||
tape.optimize(options)
|
||||
|
||||
# check program still does the same thing after optimizations
|
||||
if emulate:
|
||||
clearmem = list(prog.mem_c)
|
||||
sharedmem = list(prog.mem_s)
|
||||
prog.emulate()
|
||||
if prog.mem_c != clearmem or prog.mem_s != sharedmem:
|
||||
print('Warning: emulated memory values changed after compiler optimization')
|
||||
# raise CompilerError('Compiler optimization caused incorrect memory write.')
|
||||
|
||||
if prog.main_thread_running:
|
||||
prog.update_req(prog.curr_tape)
|
||||
|
||||
if prog.req_num:
|
||||
print('Program requires:')
|
||||
for x in prog.req_num.pretty():
|
||||
print(x)
|
||||
|
||||
if prog.verbose:
|
||||
print('Program requires:', repr(prog.req_num))
|
||||
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
from collections import defaultdict
|
||||
|
||||
#INIT_REG_MAX = 655360
|
||||
INIT_REG_MAX = 1310720
|
||||
REG_MAX = 2 ** 32
|
||||
USER_MEM = 8192
|
||||
TMP_MEM = 8192
|
||||
TMP_MEM_BASE = USER_MEM
|
||||
TMP_REG = 3
|
||||
TMP_REG_BASE = REG_MAX - TMP_REG
|
||||
|
||||
P_VALUES = { 32: 2147565569, \
|
||||
64: 9223372036855103489, \
|
||||
|
||||
@@ -17,7 +17,7 @@ class SparseDiGraph(object):
|
||||
""" max_nodes: maximum no of nodes
|
||||
default_attributes: dict of node attributes and default values """
|
||||
if default_attributes is None:
|
||||
default_attributes = { 'merges': None, 'stop': -1, 'start': -1 }
|
||||
default_attributes = { 'merges': None }
|
||||
self.default_attributes = default_attributes
|
||||
self.attribute_pos = dict(list(zip(list(default_attributes.keys()), list(range(len(default_attributes))))))
|
||||
self.n = max_nodes
|
||||
|
||||
@@ -34,9 +34,6 @@ class ldi(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDI']
|
||||
arg_format = ['cw','i']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = self.args[1]
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -45,9 +42,6 @@ class ldsi(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDSI']
|
||||
arg_format = ['sw','i']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = self.args[1]
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -57,9 +51,6 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMC']
|
||||
arg_format = ['cw','int']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_c[self.args[1]]
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
@@ -68,9 +59,6 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMS']
|
||||
arg_format = ['sw','int']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_s[self.args[1]]
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class stmc(base.DirectMemoryWriteInstruction):
|
||||
@@ -79,9 +67,6 @@ class stmc(base.DirectMemoryWriteInstruction):
|
||||
code = base.opcodes['STMC']
|
||||
arg_format = ['c','int']
|
||||
|
||||
def execute(self):
|
||||
program.mem_c[self.args[1]] = self.args[0].value
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class stms(base.DirectMemoryWriteInstruction):
|
||||
@@ -90,9 +75,6 @@ class stms(base.DirectMemoryWriteInstruction):
|
||||
code = base.opcodes['STMS']
|
||||
arg_format = ['s','int']
|
||||
|
||||
def execute(self):
|
||||
program.mem_s[self.args[1]] = self.args[0].value
|
||||
|
||||
@base.vectorize
|
||||
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """
|
||||
@@ -100,9 +82,6 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMINT']
|
||||
arg_format = ['ciw','int']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_i[self.args[1]]
|
||||
|
||||
@base.vectorize
|
||||
class stmint(base.DirectMemoryWriteInstruction):
|
||||
r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """
|
||||
@@ -110,18 +89,12 @@ class stmint(base.DirectMemoryWriteInstruction):
|
||||
code = base.opcodes['STMINT']
|
||||
arg_format = ['ci','int']
|
||||
|
||||
def execute(self):
|
||||
program.mem_i[self.args[1]] = self.args[0].value
|
||||
|
||||
# must have seperate instructions because address is always modp
|
||||
@base.vectorize
|
||||
class ldmci(base.ReadMemoryInstruction):
|
||||
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
|
||||
code = base.opcodes['LDMCI']
|
||||
arg_format = ['cw','ci']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_c[self.args[1].value]
|
||||
|
||||
@base.vectorize
|
||||
class ldmsi(base.ReadMemoryInstruction):
|
||||
@@ -129,53 +102,35 @@ class ldmsi(base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMSI']
|
||||
arg_format = ['sw','ci']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_s[self.args[1].value]
|
||||
|
||||
@base.vectorize
|
||||
class stmci(base.WriteMemoryInstruction):
|
||||
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
|
||||
code = base.opcodes['STMCI']
|
||||
arg_format = ['c','ci']
|
||||
|
||||
def execute(self):
|
||||
program.mem_c[self.args[1].value] = self.args[0].value
|
||||
|
||||
@base.vectorize
|
||||
class stmsi(base.WriteMemoryInstruction):
|
||||
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
|
||||
code = base.opcodes['STMSI']
|
||||
arg_format = ['s','ci']
|
||||
|
||||
def execute(self):
|
||||
program.mem_s[self.args[1].value] = self.args[0].value
|
||||
|
||||
@base.vectorize
|
||||
class ldminti(base.ReadMemoryInstruction):
|
||||
r""" Assigns register $ci_i$ the value in memory \verb+Ci[cj]+. """
|
||||
code = base.opcodes['LDMINTI']
|
||||
arg_format = ['ciw','ci']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_i[self.args[1].value]
|
||||
|
||||
@base.vectorize
|
||||
class stminti(base.WriteMemoryInstruction):
|
||||
r""" Sets \verb+Ci[cj]+ to be the value $ci_i$. """
|
||||
code = base.opcodes['STMINTI']
|
||||
arg_format = ['ci','ci']
|
||||
|
||||
def execute(self):
|
||||
program.mem_i[self.args[1].value] = self.args[0].value
|
||||
|
||||
@base.vectorize
|
||||
class gldmci(base.ReadMemoryInstruction):
|
||||
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
|
||||
code = base.opcodes['LDMCI'] + 0x100
|
||||
arg_format = ['cgw','ci']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_c[self.args[1].value]
|
||||
|
||||
@base.vectorize
|
||||
class gldmsi(base.ReadMemoryInstruction):
|
||||
@@ -183,27 +138,18 @@ class gldmsi(base.ReadMemoryInstruction):
|
||||
code = base.opcodes['LDMSI'] + 0x100
|
||||
arg_format = ['sgw','ci']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = program.mem_s[self.args[1].value]
|
||||
|
||||
@base.vectorize
|
||||
class gstmci(base.WriteMemoryInstruction):
|
||||
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
|
||||
code = base.opcodes['STMCI'] + 0x100
|
||||
arg_format = ['cg','ci']
|
||||
|
||||
def execute(self):
|
||||
program.mem_c[self.args[1].value] = self.args[0].value
|
||||
|
||||
@base.vectorize
|
||||
class gstmsi(base.WriteMemoryInstruction):
|
||||
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
|
||||
code = base.opcodes['STMSI'] + 0x100
|
||||
arg_format = ['sg','ci']
|
||||
|
||||
def execute(self):
|
||||
program.mem_s[self.args[1].value] = self.args[0].value
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class protectmems(base.Instruction):
|
||||
@@ -233,9 +179,6 @@ class movc(base.Instruction):
|
||||
code = base.opcodes['MOVC']
|
||||
arg_format = ['cw','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = self.args[1].value
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class movs(base.Instruction):
|
||||
@@ -244,9 +187,6 @@ class movs(base.Instruction):
|
||||
code = base.opcodes['MOVS']
|
||||
arg_format = ['sw','s']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = self.args[1].value
|
||||
|
||||
@base.vectorize
|
||||
class movint(base.Instruction):
|
||||
r""" Assigns register $ci_i$ the value in the register $ci_j$. """
|
||||
@@ -452,9 +392,6 @@ class divc(base.InvertInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['DIVC']
|
||||
arg_format = ['cw','c','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = self.args[1].value * pow(self.args[2].value, program.P-2, program.P) % program.P
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -464,9 +401,6 @@ class modc(base.Instruction):
|
||||
code = base.opcodes['MODC']
|
||||
arg_format = ['cw','c','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = self.args[1].value % self.args[2].value
|
||||
|
||||
@base.vectorize
|
||||
class inv2m(base.InvertInstruction):
|
||||
__slots__ = []
|
||||
@@ -498,9 +432,6 @@ class andc(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['ANDC']
|
||||
arg_format = ['cw','c','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value & self.args[2].value) % program.P
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -509,9 +440,6 @@ class orc(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['ORC']
|
||||
arg_format = ['cw','c','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value | self.args[2].value) % program.P
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -520,9 +448,6 @@ class xorc(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['XORC']
|
||||
arg_format = ['cw','c','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value ^ self.args[2].value) % program.P
|
||||
|
||||
@base.vectorize
|
||||
class notc(base.Instruction):
|
||||
@@ -530,9 +455,6 @@ class notc(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['NOTC']
|
||||
arg_format = ['cw','c', 'int']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (~self.args[1].value + 2 ** self.args[2]) % program.P
|
||||
|
||||
@base.vectorize
|
||||
class gnotc(base.Instruction):
|
||||
@@ -544,9 +466,6 @@ class gnotc(base.Instruction):
|
||||
def is_gf2n(self):
|
||||
return True
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = ~self.args[1].value
|
||||
|
||||
@base.vectorize
|
||||
class gbitdec(base.Instruction):
|
||||
r""" Store every $n$-th bit of $cg_i$ in $cg_j, \dots$. """
|
||||
@@ -672,8 +591,6 @@ class divci(base.InvertInstruction, base.ClearImmediate):
|
||||
r""" Clear division by immediate value $c_i=c_j/n$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['DIVCI']
|
||||
def execute(self):
|
||||
self.args[0].value = self.args[1].value * pow(self.args[2], program.P-2, program.P) % program.P
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -719,9 +636,6 @@ class shlc(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['SHLC']
|
||||
arg_format = ['cw','c','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value << self.args[2].value) % program.P
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -730,9 +644,6 @@ class shrc(base.Instruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['SHRC']
|
||||
arg_format = ['cw','c','c']
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value >> self.args[2].value) % program.P
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -764,11 +675,6 @@ class triple(base.DataInstruction):
|
||||
code = base.opcodes['TRIPLE']
|
||||
arg_format = ['sw','sw','sw']
|
||||
data_type = 'triple'
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = randint(0,program.P)
|
||||
self.args[1].value = randint(0,program.P)
|
||||
self.args[2].value = (self.args[0].value * self.args[1].value) % program.P
|
||||
|
||||
@base.vectorize
|
||||
class gbittriple(base.DataInstruction):
|
||||
@@ -804,9 +710,6 @@ class bit(base.DataInstruction):
|
||||
code = base.opcodes['BIT']
|
||||
arg_format = ['sw']
|
||||
data_type = 'bit'
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = randint(0,1)
|
||||
|
||||
@base.vectorize
|
||||
class dabit(base.DataInstruction):
|
||||
@@ -848,10 +751,6 @@ class square(base.DataInstruction):
|
||||
code = base.opcodes['SQUARE']
|
||||
arg_format = ['sw','sw']
|
||||
data_type = 'square'
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = randint(0,program.P)
|
||||
self.args[1].value = (self.args[0].value * self.args[0].value) % program.P
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -868,11 +767,6 @@ class inverse(base.DataInstruction):
|
||||
raise CompilerError('random inverse in ring not implemented')
|
||||
base.DataInstruction.__init__(self, *args, **kwargs)
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = randint(0,program.P)
|
||||
import gmpy
|
||||
self.args[1].value = int(gmpy.invert(self.args[0].value, program.P))
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class inputmask(base.Instruction):
|
||||
@@ -920,8 +814,6 @@ class asm_input(base.TextInputInstruction):
|
||||
for player in self.args[1::2]:
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
def execute(self):
|
||||
self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P
|
||||
|
||||
@base.vectorize
|
||||
class inputfix(base.TextInputInstruction):
|
||||
@@ -1006,46 +898,18 @@ class inputmixedreg(inputmixed_base):
|
||||
req_node.increment((self.field_type, 'input', 0), float('inf'))
|
||||
|
||||
@base.gf2n
|
||||
class startinput(base.RawInputInstruction):
|
||||
class rawinput(base.RawInputInstruction, base.Mergeable):
|
||||
r""" Receive inputs from player $p$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['STARTINPUT']
|
||||
arg_format = ['p', 'int']
|
||||
code = base.opcodes['RAWINPUT']
|
||||
arg_format = tools.cycle(['p','sw'])
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment((self.field_type, 'input', self.args[0]), \
|
||||
self.args[1])
|
||||
|
||||
def merge(self, other):
|
||||
self.args[1] += other.args[1]
|
||||
|
||||
class StopInputInstruction(base.RawInputInstruction):
|
||||
__slots__ = []
|
||||
|
||||
def merge(self, other):
|
||||
if self.get_size() != other.get_size():
|
||||
raise NotImplemented()
|
||||
else:
|
||||
self.args += other.args[1:]
|
||||
|
||||
class stopinput(StopInputInstruction):
|
||||
r""" Receive inputs from player $p$ and put in registers. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['STOPINPUT']
|
||||
arg_format = tools.chain(['p'], itertools.repeat('sw'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class gstopinput(StopInputInstruction):
|
||||
r""" Receive inputs from player $p$ and put in registers. """
|
||||
__slots__ = []
|
||||
code = 0x100 + base.opcodes['STOPINPUT']
|
||||
arg_format = tools.chain(['p'], itertools.repeat('sgw'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
for i in range(0, len(self.args), 2):
|
||||
player = self.args[i]
|
||||
req_node.increment((self.field_type, 'input', player), \
|
||||
self.get_size())
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -1054,9 +918,6 @@ class print_mem(base.IOInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['PRINTMEM']
|
||||
arg_format = ['c']
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -1069,9 +930,6 @@ class print_reg(base.IOInstruction):
|
||||
def __init__(self, reg, comment=''):
|
||||
super(print_reg_class, self).__init__(reg, self.str_to_int(comment))
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
class print_reg_plain(base.IOInstruction):
|
||||
@@ -1238,41 +1096,6 @@ class acceptclientconnection(base.IOInstruction):
|
||||
code = base.opcodes['ACCEPTCLIENTCONNECTION']
|
||||
arg_format = ['ciw', 'int']
|
||||
|
||||
class connectipv4(base.IOInstruction):
|
||||
"""Connect to server at IPv4 address in register \verb|cj| at given port. Write socket handle to register \verb|ci|"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['CONNECTIPV4']
|
||||
arg_format = ['ciw', 'ci', 'int']
|
||||
|
||||
class readclientpublickey(base.IOInstruction):
|
||||
"""Read a client public key as 8 32-bit ints for a specified client id"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['READCLIENTPUBLICKEY']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class initsecuresocket(base.IOInstruction):
|
||||
"""Read a client public key as 8 32-bit ints for a specified client id,
|
||||
negotiate a shared key via STS and use it for replay resistant comms"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['INITSECURESOCKET']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class respsecuresocket(base.IOInstruction):
|
||||
"""Read a client public key as 8 32-bit ints for a specified client id,
|
||||
negotiate a shared key via STS and use it for replay resistant comms"""
|
||||
__slots__ = []
|
||||
code = base.opcodes['RESPSECURESOCKET']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
|
||||
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class writesharestofile(base.IOInstruction):
|
||||
"""Write shares to a file"""
|
||||
__slots__ = []
|
||||
@@ -1392,12 +1215,6 @@ class eqzc(base.UnaryComparisonInstruction):
|
||||
r""" Clear comparison $c_i = (c_j \stackrel{?}{==} 0)$. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['EQZC']
|
||||
|
||||
def execute(self):
|
||||
if self.args[1].value == 0:
|
||||
self.args[0].value = 1
|
||||
else:
|
||||
self.args[0].value = 0
|
||||
|
||||
@base.vectorize
|
||||
class ltzc(base.UnaryComparisonInstruction):
|
||||
@@ -1435,9 +1252,6 @@ class jmp(base.JumpInstruction):
|
||||
arg_format = ['int']
|
||||
jump_arg = 0
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
class jmpi(base.JumpInstruction):
|
||||
""" Unconditional relative jump of $c_i+1$ instructions. """
|
||||
__slots__ = []
|
||||
@@ -1457,9 +1271,6 @@ class jmpnz(base.JumpInstruction):
|
||||
code = base.opcodes['JMPNZ']
|
||||
arg_format = ['ci', 'int']
|
||||
jump_arg = 1
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
class jmpeqz(base.JumpInstruction):
|
||||
r""" Jump $n+1$ instructions if $c_i == 0$. """
|
||||
@@ -1467,9 +1278,6 @@ class jmpeqz(base.JumpInstruction):
|
||||
code = base.opcodes['JMPEQZ']
|
||||
arg_format = ['ci', 'int']
|
||||
jump_arg = 1
|
||||
|
||||
def execute(self):
|
||||
pass
|
||||
|
||||
###
|
||||
### Conversions
|
||||
|
||||
@@ -115,6 +115,7 @@ opcodes = dict(
|
||||
INPUTFLOAT = 0xF1,
|
||||
INPUTMIXED = 0xF2,
|
||||
INPUTMIXEDREG = 0xF3,
|
||||
RAWINPUT = 0xF4,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
@@ -688,15 +689,12 @@ class Instruction(object):
|
||||
self.args = list(args)
|
||||
if not kwargs.get('copying', False):
|
||||
self.check_args()
|
||||
if not program.FIRST_PASS:
|
||||
if kwargs.get('add_to_prog', True):
|
||||
program.curr_block.instructions.append(self)
|
||||
if program.DEBUG:
|
||||
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
else:
|
||||
self.caller = None
|
||||
if program.EMULATE:
|
||||
self.execute()
|
||||
if kwargs.get('add_to_prog', True):
|
||||
program.curr_block.instructions.append(self)
|
||||
if program.DEBUG:
|
||||
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
else:
|
||||
self.caller = None
|
||||
|
||||
Instruction.count += 1
|
||||
if Instruction.count % 100000 == 0:
|
||||
@@ -717,10 +715,6 @@ class Instruction(object):
|
||||
def get_bytes(self):
|
||||
return bytearray(self.get_encoding())
|
||||
|
||||
def execute(self):
|
||||
""" Emulate execution of this instruction """
|
||||
raise NotImplementedError('execute method must be implemented')
|
||||
|
||||
def check_args(self):
|
||||
""" Check the args match up with that specified in arg_format """
|
||||
try:
|
||||
@@ -839,21 +833,12 @@ class VectorInstruction(Instruction):
|
||||
class AddBase(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value + self.args[2].value) % program.P
|
||||
|
||||
class SubBase(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value - self.args[2].value) % program.P
|
||||
|
||||
class MulBase(Instruction):
|
||||
__slots__ = []
|
||||
|
||||
def execute(self):
|
||||
self.args[0].value = (self.args[1].value * self.args[2].value) % program.P
|
||||
|
||||
###
|
||||
### Basic arithmetic with immediate values
|
||||
###
|
||||
@@ -861,9 +846,6 @@ class MulBase(Instruction):
|
||||
class ImmediateBase(Instruction):
|
||||
__slots__ = ['op']
|
||||
|
||||
def execute(self):
|
||||
exec('self.args[0].value = self.args[1].value.%s(self.args[2]) %% program.P' % self.op)
|
||||
|
||||
class SharedImmediate(ImmediateBase):
|
||||
__slots__ = []
|
||||
arg_format = ['sw', 's', 'i']
|
||||
@@ -1023,10 +1005,7 @@ class CISC(Instruction):
|
||||
def __init__(self, *args):
|
||||
self.args = args
|
||||
self.check_args()
|
||||
#if EMULATE:
|
||||
# self.expand()
|
||||
if not program.FIRST_PASS:
|
||||
self.expand()
|
||||
self.expand()
|
||||
|
||||
def expand(self):
|
||||
""" Expand this into a sequence of RISC instructions. """
|
||||
|
||||
@@ -342,8 +342,10 @@ class Function:
|
||||
x.reg_type)))
|
||||
runtime_args = [None] * len(args)
|
||||
for t in sorted(type_args, key=lambda x: x.reg_type):
|
||||
for i,i_arg in enumerate(type_args[t]):
|
||||
i = 0
|
||||
for i_arg in type_args[t]:
|
||||
runtime_args[i_arg] = t.load_mem(bases[t] + i)
|
||||
i += util.mem_size(t)
|
||||
return self.function(*(list(compile_args) + runtime_args))
|
||||
self.on_first_call(wrapped_function)
|
||||
self.type_args[len(args)] = type_args
|
||||
@@ -354,10 +356,12 @@ class Function:
|
||||
for i,reg_type in enumerate(sorted(type_args,
|
||||
key=lambda x: x.reg_type)):
|
||||
store_in_mem(bases[reg_type], base + i)
|
||||
for j,i_arg in enumerate(type_args[reg_type]):
|
||||
j = 0
|
||||
for i_arg in type_args[reg_type]:
|
||||
if get_reg_type(args[i_arg]) != reg_type:
|
||||
raise CompilerError('type mismatch')
|
||||
store_in_mem(args[i_arg], bases[reg_type] + j)
|
||||
j += util.mem_size(reg_type)
|
||||
return self.on_call(base, bases)
|
||||
|
||||
class FunctionTape(Function):
|
||||
|
||||
@@ -60,27 +60,21 @@ def sigmoid_prime(x):
|
||||
return sx * (1 - sx)
|
||||
|
||||
@vectorize
|
||||
def approx_sigmoid(x):
|
||||
if approx_sigmoid.special and \
|
||||
get_program().options.ring and get_program().use_edabit():
|
||||
l = int(get_program().options.ring)
|
||||
r, r_bits = sint.get_edabit(x.k, False)
|
||||
c = ((x.v - r) << (l - x.k)).reveal() >> (l - x.k)
|
||||
c_bits = c.bit_decompose(x.k)
|
||||
lower_overflow = CarryOutRawLE(c_bits[:x.f - 1], r_bits[:x.f - 1])
|
||||
higher_bits = sbitint.bit_adder(c_bits[x.f - 1:], r_bits[x.f - 1:],
|
||||
lower_overflow)
|
||||
sign = higher_bits[-1]
|
||||
higher_bits.pop(-1)
|
||||
aa = sign & ~util.tree_reduce(operator.and_, higher_bits)
|
||||
bb = ~sign & ~util.tree_reduce(operator.and_, [~x for x in higher_bits])
|
||||
a, b = (sint.conv(x) for x in (aa, bb))
|
||||
def approx_sigmoid(x, n=3):
|
||||
if n == 5:
|
||||
cuts = [-5, -2.5, 2.5, 5]
|
||||
le = [0] + [x <= cut for cut in cuts] + [1]
|
||||
select = [le[i + 1] - le[i] for i in range(5)]
|
||||
outputs = [cfix(10 ** -4),
|
||||
0.02776 * x + 0.145,
|
||||
0.17 * x + 0.5,
|
||||
0.02776 * x + 0.85498,
|
||||
cfix(1 - 10 ** -4)]
|
||||
return sum(a * b for a, b in zip(select, outputs))
|
||||
else:
|
||||
a = x < -0.5
|
||||
b = x > 0.5
|
||||
return a.if_else(0, b.if_else(1, 0.5 + x))
|
||||
|
||||
approx_sigmoid.special = False
|
||||
return a.if_else(0, b.if_else(1, 0.5 + x))
|
||||
|
||||
def lse_0_from_e_x(x, e_x):
|
||||
return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)
|
||||
@@ -144,7 +138,7 @@ class Output(Layer):
|
||||
|
||||
def eval(self, size, base=0):
|
||||
if self.approx:
|
||||
return approx_sigmoid(self.X.get_vector(base, size))
|
||||
return approx_sigmoid(self.X.get_vector(base, size), self.approx)
|
||||
else:
|
||||
return sigmoid_from_e_x(self.X.get_vector(base, size),
|
||||
self.e_x.get_vector(base, size))
|
||||
@@ -531,7 +525,7 @@ class QuantConv2d(QuantConvBase):
|
||||
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
||||
return weights_h * weights_w * n_channels_in
|
||||
|
||||
def forward(self, batch):
|
||||
def forward(self, batch=[None]):
|
||||
assert len(batch) == 1
|
||||
assert(self.weight_shape[0] == self.output_shape[-1])
|
||||
|
||||
|
||||
@@ -39,11 +39,11 @@ class Program(object):
|
||||
|
||||
These are created by executing a file containing appropriate instructions
|
||||
and threads. """
|
||||
def __init__(self, args, options, param=-1, assemblymode=False):
|
||||
def __init__(self, args, options, param=-1):
|
||||
self.options = options
|
||||
self.verbose = options.verbose
|
||||
self.args = args
|
||||
self.init_names(args, assemblymode)
|
||||
self.init_names(args)
|
||||
self.P = P_VALUES[param]
|
||||
self.param = param
|
||||
if (param != -1) + sum(x != 0 for x in(options.ring, options.field,
|
||||
@@ -65,8 +65,6 @@ class Program(object):
|
||||
self.tape_counter = 0
|
||||
self.tapes = []
|
||||
self._curr_tape = None
|
||||
self.EMULATE = True # defaults
|
||||
self.FIRST_PASS = False
|
||||
self.DEBUG = False
|
||||
self.main_thread_running = False
|
||||
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
|
||||
@@ -102,9 +100,8 @@ class Program(object):
|
||||
self.use_dabit = options.mixed
|
||||
self._edabit = options.edabit
|
||||
self._split = False
|
||||
self._square = False
|
||||
Program.prog = self
|
||||
|
||||
self.reset_values()
|
||||
|
||||
def get_args(self):
|
||||
return self.args
|
||||
@@ -131,7 +128,7 @@ class Program(object):
|
||||
res = max(res, sum(running.values()))
|
||||
return res
|
||||
|
||||
def init_names(self, args, assemblymode):
|
||||
def init_names(self, args):
|
||||
# ignore path to file - source must be in Programs/Source
|
||||
if 'Programs' in os.listdir(os.getcwd()):
|
||||
# compile prog in ./Programs/Source directory
|
||||
@@ -153,8 +150,6 @@ class Program(object):
|
||||
|
||||
if os.path.exists(args[0]):
|
||||
self.infile = args[0]
|
||||
elif assemblymode:
|
||||
self.infile = self.programs_dir + '/Source/' + progname + '.asm'
|
||||
else:
|
||||
self.infile = self.programs_dir + '/Source/' + progname + '.mpc'
|
||||
"""
|
||||
@@ -234,40 +229,6 @@ class Program(object):
|
||||
else:
|
||||
self.req_num += tape.req_num
|
||||
|
||||
def read_memory(self, filename):
|
||||
""" Read the clear and shared memory from a file """
|
||||
f = open(filename)
|
||||
n = int(next(f))
|
||||
self.mem_c = [0]*n
|
||||
self.mem_s = [0]*n
|
||||
mem = self.mem_c
|
||||
done_c = False
|
||||
for line in f:
|
||||
line = line.split(' ')
|
||||
a = int(line[0])
|
||||
b = int(line[1])
|
||||
if a != -1:
|
||||
mem[a] = b
|
||||
elif done_c:
|
||||
break
|
||||
else:
|
||||
mem = self.mem_s
|
||||
done_c = True
|
||||
|
||||
def get_memory(self, mem_type, i):
|
||||
if mem_type == 'c':
|
||||
return self.mem_c[i]
|
||||
elif mem_type == 's':
|
||||
return self.mem_s[i]
|
||||
raise CompilerError('Invalid memory type')
|
||||
|
||||
def reset_values(self):
|
||||
""" Reset register and memory values. """
|
||||
for tape in self.tapes:
|
||||
tape.reset_registers()
|
||||
self.mem_c = list(range(USER_MEM + TMP_MEM))
|
||||
self.mem_s = list(range(USER_MEM + TMP_MEM))
|
||||
|
||||
def write_bytes(self, outfile=None):
|
||||
""" Write all non-empty threads and schedule to files. """
|
||||
# runtime doesn't support 'new-style' parallelism yet
|
||||
@@ -329,17 +290,6 @@ class Program(object):
|
||||
tape.write_str(self.options.asmoutfile + '-' + tape.name)
|
||||
tape.purge()
|
||||
|
||||
def emulate(self):
|
||||
""" Emulate execution of entire program. """
|
||||
self.reset_values()
|
||||
for sch in self.schedule:
|
||||
if sch[0] == 'start':
|
||||
for tape in sch[1]:
|
||||
self._curr_tape = tape
|
||||
for block in tape.basicblocks:
|
||||
for line in block.instructions:
|
||||
line.execute()
|
||||
|
||||
def restart_main_thread(self):
|
||||
if self.main_thread_running:
|
||||
# wait for main thread to finish
|
||||
@@ -375,6 +325,10 @@ class Program(object):
|
||||
if size == 0:
|
||||
return
|
||||
if isinstance(mem_type, type):
|
||||
try:
|
||||
size *= math.ceil(mem_type.n / mem_type.unit)
|
||||
except AttributeError:
|
||||
pass
|
||||
self.types[mem_type.reg_type] = mem_type
|
||||
mem_type = mem_type.reg_type
|
||||
elif reg_type is not None:
|
||||
@@ -447,6 +401,12 @@ class Program(object):
|
||||
assert change in (2, 3)
|
||||
self._split = change
|
||||
|
||||
def use_square(self, change=None):
|
||||
if change is None:
|
||||
return self._square
|
||||
else:
|
||||
self._square = change
|
||||
|
||||
class Tape:
|
||||
""" A tape contains a list of basic blocks, onto which instructions are added. """
|
||||
def __init__(self, name, program):
|
||||
@@ -588,7 +548,6 @@ class Tape:
|
||||
#print 'Compiling basic block', sub.name
|
||||
|
||||
def init_registers(self):
|
||||
self.reset_registers()
|
||||
self.reg_counter = RegType.create_dict(lambda: 0)
|
||||
|
||||
def init_names(self, name):
|
||||
@@ -605,7 +564,6 @@ class Tape:
|
||||
for block in self.basicblocks:
|
||||
block.purge()
|
||||
self._is_empty = (len(self.basicblocks) == 0)
|
||||
del self.reg_values
|
||||
del self.basicblocks
|
||||
del self.active_basicblock
|
||||
self.purged = True
|
||||
@@ -840,13 +798,6 @@ class Tape:
|
||||
else:
|
||||
return self.reg_counter[reg_type]
|
||||
|
||||
def reset_registers(self):
|
||||
""" Reset register values to zero. """
|
||||
self.reg_values = RegType.create_dict(lambda: [])
|
||||
|
||||
def get_value(self, reg_type, i):
|
||||
return self.reg_values[reg_type][i]
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -883,12 +834,28 @@ class Tape:
|
||||
def cost(self):
|
||||
return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \
|
||||
if req[1] != 'input' and req[0] != 'edabit')
|
||||
def pretty(self):
|
||||
t = lambda x: 'integer' if x == 'modp' else x
|
||||
res = []
|
||||
for req, num in self.items():
|
||||
domain = t(req[0])
|
||||
n = '%12.0f' % num
|
||||
if req[1] == 'input':
|
||||
res += ['%s %s inputs from player %d' \
|
||||
% (n, domain, req[2])]
|
||||
elif domain.endswith('edabit'):
|
||||
if domain == 'sedabit':
|
||||
eda = 'strict edabits'
|
||||
else:
|
||||
eda = 'loose edabits'
|
||||
res += ['%s %s of length %d' % (n, eda, req[1])]
|
||||
elif req[0] != 'all':
|
||||
res += ['%s %s %ss' % (n, domain, req[1])]
|
||||
if self['all','round']:
|
||||
res += ['% 12.0f virtual machine rounds' % self['all','round']]
|
||||
return res
|
||||
def __str__(self):
|
||||
return ", ".join('%s inputs in %s from player %d' \
|
||||
% (num, req[0], req[2]) \
|
||||
if req[1] == 'input' \
|
||||
else '%s %ss in %s' % (num, req[1], req[0]) \
|
||||
for req,num in list(self.items()))
|
||||
return ', '.join(self.pretty())
|
||||
def __repr__(self):
|
||||
return repr(dict(self))
|
||||
|
||||
@@ -959,14 +926,12 @@ class Tape:
|
||||
"""
|
||||
Class for creating new registers. The register's index is automatically assigned
|
||||
based on the block's reg_counter dictionary.
|
||||
|
||||
The 'value' property is for emulation.
|
||||
"""
|
||||
__slots__ = ["reg_type", "program", "absolute_i", "relative_i", \
|
||||
"size", "vector", "vectorbase", "caller", \
|
||||
"can_eliminate"]
|
||||
|
||||
def __init__(self, reg_type, program, value=None, size=None, i=None):
|
||||
def __init__(self, reg_type, program, size=None, i=None):
|
||||
""" Creates a new register.
|
||||
reg_type must be one of those defined in RegType. """
|
||||
if Compiler.instructions_base.get_global_instruction_type() == 'gf2n':
|
||||
@@ -989,8 +954,6 @@ class Tape:
|
||||
else:
|
||||
self.i = float('inf')
|
||||
self.vector = []
|
||||
if value is not None:
|
||||
self.value = value
|
||||
self.can_eliminate = True
|
||||
if Program.prog.DEBUG:
|
||||
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
|
||||
@@ -1010,22 +973,9 @@ class Tape:
|
||||
def set_size(self, size):
|
||||
if self.size == size:
|
||||
return
|
||||
elif not self.program.program.options.assemblymode:
|
||||
else:
|
||||
raise CompilerError('Mismatch of instruction and register size:'
|
||||
' %s != %s' % (self.size, size))
|
||||
elif self.size == 1 and self.vectorbase is self:
|
||||
if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS:
|
||||
# create vector register in assembly mode
|
||||
self.size = size
|
||||
self.vector = [self]
|
||||
for i in range(1,size):
|
||||
reg = compilerLib.VARS['%s%d' % (self.reg_type, self.i + i)]
|
||||
reg.set_vectorbase(self)
|
||||
self.vector.append(reg)
|
||||
else:
|
||||
raise CompilerError('Cannot find %s in VARS' % str(self))
|
||||
else:
|
||||
raise CompilerError('Cannot reset size of vector register')
|
||||
|
||||
def set_vectorbase(self, vectorbase):
|
||||
if self.vectorbase is not self:
|
||||
@@ -1074,16 +1024,6 @@ class Tape:
|
||||
def copy(self):
|
||||
return Tape.Register(self.reg_type, Program.prog.curr_tape)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.program.reg_values[self.reg_type][self.i]
|
||||
|
||||
@value.setter
|
||||
def value(self, val):
|
||||
while (len(self.program.reg_values[self.reg_type]) <= self.i):
|
||||
self.program.reg_values[self.reg_type] += [0] * INIT_REG_MAX
|
||||
self.program.reg_values[self.reg_type][self.i] = val
|
||||
|
||||
@property
|
||||
def is_gf2n(self):
|
||||
return self.reg_type == RegType.ClearGF2N or \
|
||||
|
||||
@@ -232,7 +232,7 @@ def inputmixed(*args):
|
||||
if isinstance(args[-1], int):
|
||||
instructions.inputmixed(*args)
|
||||
else:
|
||||
instructions.inputmixedreg(*args)
|
||||
instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),)))
|
||||
|
||||
class _number(object):
|
||||
""" Number functionality. """
|
||||
@@ -762,7 +762,11 @@ class cint(_clear, _int):
|
||||
def load_int(self, val):
|
||||
if val:
|
||||
# +1 for sign
|
||||
program.curr_tape.require_bit_length(1 + int(math.ceil(math.log(abs(val)))))
|
||||
bit_length = 1 + int(math.ceil(math.log(abs(val))))
|
||||
if program.options.ring:
|
||||
assert(bit_length <= int(program.options.ring))
|
||||
elif program.param != -1 or program.options.field:
|
||||
program.curr_tape.require_bit_length(bit_length)
|
||||
if self.in_immediate_range(val):
|
||||
ldi(self, val)
|
||||
else:
|
||||
@@ -783,7 +787,7 @@ class cint(_clear, _int):
|
||||
sum += sign * chunk
|
||||
|
||||
@vectorize
|
||||
def to_regint(self, n_bits=None, dest=None):
|
||||
def to_regint(self, n_bits=64, dest=None):
|
||||
""" Convert to regint.
|
||||
|
||||
:param n_bits: bit length (int)
|
||||
@@ -1146,23 +1150,6 @@ class regint(_register, _int):
|
||||
else:
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def read_client_public_key(cls, client_id):
|
||||
""" Receive 8 register values from socket containing client public key."""
|
||||
res = [cls() for i in range(8)]
|
||||
readclientpublickey(client_id, *res)
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def init_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8):
|
||||
""" Use 8 register values containing client public key."""
|
||||
initsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8)
|
||||
|
||||
@vectorized_classmethod
|
||||
def resp_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8):
|
||||
""" Receive 8 register values from socket containing client public key."""
|
||||
respsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8)
|
||||
|
||||
@vectorize
|
||||
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
|
||||
writesocketint(client_id, message_type, self)
|
||||
@@ -1439,7 +1426,7 @@ class _secret(_register):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret input from player.
|
||||
|
||||
:param: player (compile-time int) """
|
||||
:param player: public (regint/cint/int) """
|
||||
res = cls()
|
||||
asm_input(res, player)
|
||||
return res
|
||||
@@ -1648,9 +1635,12 @@ class _secret(_register):
|
||||
@vectorize
|
||||
def square(self):
|
||||
""" Secret square. """
|
||||
res = self.__class__()
|
||||
sqrs(res, self)
|
||||
return res
|
||||
if program.use_square():
|
||||
res = self.__class__()
|
||||
sqrs(res, self)
|
||||
return res
|
||||
else:
|
||||
return self * self
|
||||
|
||||
@set_instruction_type
|
||||
@vectorize
|
||||
@@ -1712,7 +1702,7 @@ class sint(_secret, _int):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret input.
|
||||
|
||||
:param player: compile-time integer (int) """
|
||||
:param player: public (regint/cint/int) """
|
||||
res = cls()
|
||||
inputmixed('int', res, player)
|
||||
return res
|
||||
@@ -1757,8 +1747,7 @@ class sint(_secret, _int):
|
||||
@classmethod
|
||||
def get_raw_input_from(cls, player):
|
||||
res = cls()
|
||||
startinput(player, 1)
|
||||
stopinput(player, res)
|
||||
rawinput(player, res)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
@@ -2056,8 +2045,7 @@ class sgf2n(_secret, _gf2n):
|
||||
@classmethod
|
||||
def get_raw_input_from(cls, player):
|
||||
res = cls()
|
||||
gstartinput(player, 1)
|
||||
gstopinput(player, res)
|
||||
grawinput(player, res)
|
||||
return res
|
||||
|
||||
def add(self, other):
|
||||
@@ -3293,7 +3281,7 @@ class sfix(_fix):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret fixed-point input.
|
||||
|
||||
:param player: int """
|
||||
:param player: public (regint/cint/int) """
|
||||
v = cls.int_type()
|
||||
inputmixed('fix', v, cls.f, player)
|
||||
return cls._new(v)
|
||||
@@ -3674,7 +3662,7 @@ class sfloat(_number, _structure):
|
||||
def get_input_from(cls, player):
|
||||
""" Secret floating-point input.
|
||||
|
||||
:param player: int """
|
||||
:param player: public (regint/cint/int) """
|
||||
v = sint()
|
||||
p = sint()
|
||||
z = sint()
|
||||
@@ -4195,7 +4183,7 @@ class Array(object):
|
||||
def input_from(self, player, budget=None):
|
||||
""" Fill with inputs from player if supported by type.
|
||||
|
||||
:param player: compile-time (int) """
|
||||
:param player: public (regint/cint/int) """
|
||||
self.assign(self.value_type.get_input_from(player, size=len(self)))
|
||||
|
||||
def __add__(self, other):
|
||||
@@ -4351,7 +4339,7 @@ class SubMultiArray(object):
|
||||
def input_from(self, player, budget=None):
|
||||
""" Fill with inputs from player if supported by type.
|
||||
|
||||
:param player: compile-time (int) """
|
||||
:param player: public (regint/cint/int) """
|
||||
@library.for_range_opt(self.sizes[0], budget=budget)
|
||||
def _(i):
|
||||
self[i].input_from(player, budget=budget)
|
||||
@@ -4597,6 +4585,12 @@ class Matrix(MultiArray):
|
||||
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
|
||||
address=address)
|
||||
|
||||
def set_column(self, index, vector):
|
||||
assert self.value_type.n_elements() == 1
|
||||
addresses = regint.inc(self.sizes[0], self.address + index,
|
||||
self.sizes[1])
|
||||
vector.store_in_mem(addresses)
|
||||
|
||||
class VectorArray(object):
|
||||
def __init__(self, length, value_type, vector_size, address=None):
|
||||
self.array = Array(length * vector_size, value_type, address)
|
||||
|
||||
@@ -182,6 +182,12 @@ def expand(x, size):
|
||||
except AttributeError:
|
||||
return x
|
||||
|
||||
def mem_size(x):
|
||||
try:
|
||||
return x.mem_size()
|
||||
except AttributeError:
|
||||
return 1
|
||||
|
||||
class set_by_id(object):
|
||||
def __init__(self, init=[]):
|
||||
self.content = {}
|
||||
|
||||
@@ -1,4 +1,32 @@
|
||||
The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
|
||||
The ExternalIO directory contains an example of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
|
||||
|
||||
## Working Examples
|
||||
|
||||
[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a
|
||||
client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc)
|
||||
and demonstrates sending input and receiving output as described by
|
||||
[Damgård et al.](https://eprint.iacr.org/2015/1006) The computation
|
||||
allows up to eight clients to input a number and computes the client
|
||||
with the largest input. You can run it as follows from the main
|
||||
directory:
|
||||
```
|
||||
make bankers-bonus-client.x
|
||||
./compile.py bankers_bonus 1
|
||||
Scripts/setup-ssl.sh <nparties>
|
||||
Scripts/setup-clients.sh 3
|
||||
Scripts/<protocol>.sh &
|
||||
./bankers-bonus-client.x 0 <nparties 100 0 &
|
||||
./bankers-bonus-client.x 1 <nparties> 200 0 &
|
||||
./bankers-bonus-client.x 2 <nparties> 50 1
|
||||
```
|
||||
This should output that the winning id is 1. Note that the ids have to
|
||||
be incremental, and the client with the highest id has to input 1 as
|
||||
the last argument while the others have to input 0 there. Furthermore,
|
||||
`<nparties>` refers to the number of parties running the computation
|
||||
not the number of clients, and `<protocol>` can be the name of
|
||||
protocol script. The setup scripts generate the necessary SSL
|
||||
certificates and keys. Therefore, if you run the computation on
|
||||
different hosts, you will have to distribute the `*.pem` files.
|
||||
|
||||
## I/O MPC Instructions
|
||||
|
||||
@@ -55,49 +83,3 @@ Receive shares of private inputs from a client, blocking on client send. This is
|
||||
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
|
||||
|
||||
*[inputs]* - returned list of shares of private input.
|
||||
|
||||
|
||||
## Securing communications
|
||||
|
||||
Two cryptographic protocols have been implemented for use in particular applications and are included here for completeness:
|
||||
|
||||
1. Communication security using a Station to Station key agreement and libsodium Secret Box using a nonce counter for message ordering.
|
||||
2. Authenticated Diffie-Hellman without message ordering.
|
||||
|
||||
Please note these are **NOT** required to allow external client I/O. Your mileage may vary, for example in a web setting TLS may be sufficient to secure communications between processes.
|
||||
|
||||
[client-setup.cpp](../client-setup.cpp) is a utility which is run to generate the key material for both the external clients and SPDZ parties for both protocols.
|
||||
|
||||
#### MPC instructions
|
||||
|
||||
**regint.init_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*)
|
||||
|
||||
STS protocol initiator. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec.
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*public_signing_key* - client public key supplied as list of 8 32-bit ints.
|
||||
|
||||
**regint.resp_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*)
|
||||
|
||||
STS protocol responder. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec.
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*public_signing_key* - client public key supplied as list of 8 32-bit ints.
|
||||
|
||||
*[regint public_key]* **regint.read_client_public_key**(*regint client_socket_id*)
|
||||
|
||||
Instruction to read the client public key and run setup for the authenticated Diffie-Hellman encryption. All subsequent write_socket instructions are encrypted. Only the sint.read_from_socket instruction is encrypted.
|
||||
|
||||
*client_socket_id* - an identifier used to refer to the client socket.
|
||||
|
||||
*public_key* - client public key made available to mpc programs as list of 8 32-bit ints.
|
||||
|
||||
## Working Examples
|
||||
|
||||
See [bankers-bonus-client.cpp](./bankers-bonus-client.cpp) which acts as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output with no communications security.
|
||||
|
||||
See [bankers-bonus-commsec-client.cpp](./bankers-bonus-commsec-client.cpp) which acts as a client to [bankers_bonus_commsec.mpc](../Programs/Source/bankers_bonus_commsec.mpc) which runs the same algorithm but includes both the available crypto protocols.
|
||||
|
||||
More instructions on how to run these are provided in the *-client files.
|
||||
|
||||
@@ -16,11 +16,10 @@
|
||||
* - share of random value [r]
|
||||
* - share of winning unique id * random value [w]
|
||||
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
|
||||
*
|
||||
* No communications security is used.
|
||||
*
|
||||
* To run with 2 parties / SPDZ engines:
|
||||
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
|
||||
* ./Scripts/setup-clients.sh to create SSL keys and certificates for clients
|
||||
* ./compile.py bankers_bonus
|
||||
* ./Scripts/run-online.sh bankers_bonus to run the engines.
|
||||
*
|
||||
@@ -34,6 +33,7 @@
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Networking/sockets.h"
|
||||
#include "Networking/ssl_sockets.h"
|
||||
#include "Tools/int.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
@@ -46,12 +46,13 @@
|
||||
// Send the private inputs masked with a random value.
|
||||
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
|
||||
// Add the private input value to triple[0] and send to each spdz engine.
|
||||
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties)
|
||||
template<class T>
|
||||
void send_private_inputs(const vector<T>& values, vector<ssl_socket*>& sockets, int nparties)
|
||||
{
|
||||
int num_inputs = values.size();
|
||||
octetStream os;
|
||||
vector< vector<gfp> > triples(num_inputs, vector<gfp>(3));
|
||||
vector<gfp> triple_shares(3);
|
||||
vector< vector<T> > triples(num_inputs, vector<T>(3));
|
||||
vector<T> triple_shares(3);
|
||||
|
||||
// Receive num_inputs triples from SPDZ
|
||||
for (int j = 0; j < nparties; j++)
|
||||
@@ -59,6 +60,10 @@ void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int np
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[j]);
|
||||
|
||||
#ifdef VERBOSE_COMM
|
||||
cerr << "received " << os.get_length() << " from " << j << endl;
|
||||
#endif
|
||||
|
||||
for (int j = 0; j < num_inputs; j++)
|
||||
{
|
||||
for (int k = 0; k < 3; k++)
|
||||
@@ -72,49 +77,30 @@ void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int np
|
||||
// Check triple relations (is a party cheating?)
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
if (triples[i][0] * triples[i][1] != triples[i][2])
|
||||
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
|
||||
{
|
||||
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
exit(1);
|
||||
throw mac_fail();
|
||||
}
|
||||
}
|
||||
// Send inputs + triple[0], so SPDZ can compute shares of each value
|
||||
os.reset_write_head();
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
gfp y = values[i] + triples[i][0];
|
||||
T y = values[i] + triples[i][0];
|
||||
y.pack(os);
|
||||
}
|
||||
for (int j = 0; j < nparties; j++)
|
||||
os.Send(sockets[j]);
|
||||
}
|
||||
|
||||
// Assumes that Scripts/setup-online.sh has been run to compute prime
|
||||
void initialise_fields(const string& dir_prefix)
|
||||
{
|
||||
int lg2;
|
||||
bigint p;
|
||||
|
||||
string filename = dir_prefix + "Params-Data";
|
||||
cout << "loading params from: " << filename << endl;
|
||||
|
||||
ifstream inpf(filename.c_str());
|
||||
if (inpf.fail()) { throw file_error(filename.c_str()); }
|
||||
inpf >> p;
|
||||
inpf >> lg2;
|
||||
|
||||
inpf.close();
|
||||
|
||||
gfp::init_field(p);
|
||||
gf2n::init_field(lg2);
|
||||
}
|
||||
|
||||
|
||||
// Receive shares of the result and sum together.
|
||||
// Also receive authenticating values.
|
||||
gfp receive_result(vector<int>& sockets, int nparties)
|
||||
template<class T>
|
||||
T receive_result(vector<ssl_socket*>& sockets, int nparties)
|
||||
{
|
||||
vector<gfp> output_values(3);
|
||||
vector<T> output_values(3);
|
||||
octetStream os;
|
||||
for (int i = 0; i < nparties; i++)
|
||||
{
|
||||
@@ -122,20 +108,32 @@ gfp receive_result(vector<int>& sockets, int nparties)
|
||||
os.Receive(sockets[i]);
|
||||
for (unsigned int j = 0; j < 3; j++)
|
||||
{
|
||||
gfp value;
|
||||
T value;
|
||||
value.unpack(os);
|
||||
output_values[j] += value;
|
||||
}
|
||||
}
|
||||
|
||||
if (output_values[0] * output_values[1] != output_values[2])
|
||||
if (T(output_values[0] * output_values[1]) != output_values[2])
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
exit(1);
|
||||
throw mac_fail();
|
||||
}
|
||||
return output_values[0];
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void run(int salary_value, vector<ssl_socket*>& sockets, int nparties)
|
||||
{
|
||||
// Run the computation
|
||||
send_private_inputs<T>({salary_value}, sockets, nparties);
|
||||
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
|
||||
|
||||
// Get the result back (client_id of winning client)
|
||||
T result = receive_result<T>(sockets, nparties);
|
||||
|
||||
cout << "Winning client id is : " << result << endl;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
@@ -162,34 +160,65 @@ int main(int argc, char** argv)
|
||||
if (argc > 6)
|
||||
port_base = atoi(argv[6]);
|
||||
|
||||
// init static gfp
|
||||
string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree());
|
||||
initialise_fields(prep_data_prefix);
|
||||
bigint::init_thread();
|
||||
|
||||
// Setup connections from this client to each party socket
|
||||
vector<int> sockets(nparties);
|
||||
vector<int> plain_sockets(nparties);
|
||||
vector<ssl_socket*> sockets(nparties);
|
||||
ssl_ctx ctx("C" + to_string(my_client_id));
|
||||
ssl_service io_service;
|
||||
octetStream specification;
|
||||
for (int i = 0; i < nparties; i++)
|
||||
{
|
||||
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
|
||||
send(sockets[i], (octet*) &my_client_id, sizeof(int));
|
||||
set_up_client_socket(plain_sockets[i], host_name.c_str(), port_base + i);
|
||||
send(plain_sockets[i], (octet*) &my_client_id, sizeof(int));
|
||||
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
|
||||
"P" + to_string(i), "C" + to_string(my_client_id), true);
|
||||
if (i == 0)
|
||||
specification.Receive(sockets[0]);
|
||||
octetStream os;
|
||||
os.store(finish);
|
||||
os.Send(sockets[i]);
|
||||
}
|
||||
cout << "Finish setup socket connections to SPDZ engines." << endl;
|
||||
|
||||
// Run the commputation
|
||||
send_private_inputs({salary_value}, sockets, nparties);
|
||||
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
|
||||
int type = specification.get<int>();
|
||||
switch (type)
|
||||
{
|
||||
case 'p':
|
||||
{
|
||||
gfp::init_field(specification.get<bigint>());
|
||||
cerr << "using prime " << gfp::pr() << endl;
|
||||
run<gfp>(salary_value, sockets, nparties);
|
||||
break;
|
||||
}
|
||||
case 'R':
|
||||
{
|
||||
int R = specification.get<int>();
|
||||
switch (R)
|
||||
{
|
||||
case 64:
|
||||
run<Z2<64>>(salary_value, sockets, nparties);
|
||||
break;
|
||||
case 104:
|
||||
run<Z2<104>>(salary_value, sockets, nparties);
|
||||
break;
|
||||
case 128:
|
||||
run<Z2<128>>(salary_value, sockets, nparties);
|
||||
break;
|
||||
default:
|
||||
cerr << R << "-bit ring not implemented";
|
||||
exit(1);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
cerr << "Type " << type << " not implemented";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// Get the result back (client_id of winning client)
|
||||
gfp result = receive_result(sockets, nparties);
|
||||
|
||||
cout << "Winning client id is : " << result << endl;
|
||||
|
||||
for (unsigned int i = 0; i < sockets.size(); i++)
|
||||
close_client_socket(sockets[i]);
|
||||
for (int i = 0; i < nparties; i++)
|
||||
delete sockets[i];
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1,407 +0,0 @@
|
||||
/*
|
||||
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
|
||||
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.
|
||||
* Uses SPDZ implemented encryption for external client communication, see bankers-bonus-client.cpp
|
||||
* for a simpler client with no crypto.
|
||||
*
|
||||
* Provides a client to bankers_bonus_commsec.mpc program to calculate which banker pays for lunch based on
|
||||
* the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running
|
||||
* the bankers_bonus.mpc program.
|
||||
*
|
||||
* Each connecting client:
|
||||
* - sends an increasing id to identify the client, starting with 0
|
||||
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
|
||||
* - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security.
|
||||
* - sends an integer input (bonus value to compare)
|
||||
*
|
||||
* The result is returned authenticated with a share of a random value:
|
||||
* - share of winning unique id [y]
|
||||
* - share of random value [r]
|
||||
* - share of winning unique id * random value [w]
|
||||
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
|
||||
*
|
||||
* To run with 2 parties (SPDZ engines) and 3 external clients:
|
||||
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
|
||||
* ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients.
|
||||
* ./compile.py bankers_bonus_commsec
|
||||
* ./Scripts/run-online.sh bankers_bonus_commsec to run the engines.
|
||||
*
|
||||
* ./bankers-bonus-commsec-client.x 0 2 100 0
|
||||
* ./bankers-bonus-commsec-client.x 1 2 200 0
|
||||
* ./bankers-bonus-commsec-client.x 2 2 50 1
|
||||
*
|
||||
* Expect winner to be second client with id 1.
|
||||
* Note here client id must match id used in generating client key material, Client-Keys-C<client id>.
|
||||
*/
|
||||
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Networking/sockets.h"
|
||||
#include "Networking/STS.h"
|
||||
#include "Tools/int.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
|
||||
#include <sodium.h>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
typedef pair< vector<octet>, vector<octet> > keypair_t; // A pair of send/recv keys for talking to SPDZ
|
||||
typedef vector< keypair_t > commsec_t; // A database of send/recv keys indexed by server
|
||||
typedef struct {
|
||||
unsigned char client_secretkey[crypto_sign_SECRETKEYBYTES];
|
||||
unsigned char client_publickey[crypto_sign_PUBLICKEYBYTES];
|
||||
vector<int> client_publickey_ints;
|
||||
vector< vector<unsigned char> >server_publickey;
|
||||
} sign_key_container_t;
|
||||
|
||||
keypair_t sts_response_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
|
||||
{
|
||||
STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
|
||||
sts_msg1_t m1;
|
||||
sts_msg2_t m2;
|
||||
sts_msg3_t m3;
|
||||
octetStream os;
|
||||
|
||||
os.Receive(sockets[server_id]);
|
||||
os.consume(m1.bytes, sizeof m1.bytes);
|
||||
m2 = ke.recv_msg1(m1);
|
||||
os.reset_write_head();
|
||||
os.append(m2.pubkey, sizeof m2.pubkey);
|
||||
os.append(m2.sig, sizeof m2.sig);
|
||||
os.Send(sockets[server_id]);
|
||||
os.Receive(sockets[server_id]);
|
||||
os.consume(m3.bytes, sizeof m3.bytes);
|
||||
ke.recv_msg3(m3);
|
||||
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
return make_pair(sendKey,recvKey);
|
||||
}
|
||||
|
||||
keypair_t sts_initiator_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
|
||||
{
|
||||
STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
|
||||
sts_msg1_t m1;
|
||||
sts_msg2_t m2;
|
||||
sts_msg3_t m3;
|
||||
octetStream os;
|
||||
|
||||
m1 = ke.send_msg1();
|
||||
cout << "m1: ";
|
||||
for (unsigned int j = 0; j < 32; j++)
|
||||
cout << setfill('0') << setw(2) << hex << (int) m1.bytes[j];
|
||||
cout << dec << endl;
|
||||
os.reset_write_head();
|
||||
os.append(m1.bytes, sizeof m1.bytes);
|
||||
os.Send(sockets[server_id]);
|
||||
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[server_id]);
|
||||
os.consume(m2.pubkey, sizeof m2.pubkey);
|
||||
os.consume(m2.sig, sizeof m2.sig);
|
||||
m3 = ke.recv_msg2(m2);
|
||||
|
||||
os.reset_write_head();
|
||||
os.append(m3.bytes, sizeof m3.bytes);
|
||||
os.Send(sockets[server_id]);
|
||||
|
||||
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
return make_pair(sendKey,recvKey);
|
||||
}
|
||||
|
||||
pair< vector<octet>, vector<octet> > sts_response_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
|
||||
{
|
||||
pair< vector<octet>, vector<octet> > res;
|
||||
try {
|
||||
res = sts_response_role_exceptions(keys, sockets, server_id);
|
||||
} catch(char const *e) {
|
||||
cerr << "Error in STS: " << e << endl;
|
||||
exit(1);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
pair< vector<octet>, vector<octet> > sts_initiator_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
|
||||
{
|
||||
pair< vector<octet>, vector<octet> > res;
|
||||
try {
|
||||
res = sts_initiator_role_exceptions(keys, sockets, server_id);
|
||||
} catch(char const *e) {
|
||||
cerr << "Error in STS: " << e << endl;
|
||||
exit(1);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// Send the private inputs masked with a random value.
|
||||
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
|
||||
// Add the private input value to triple[0] and send to each spdz engine.
|
||||
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties,
|
||||
commsec_t commsec, vector<octet*>& keys)
|
||||
{
|
||||
int num_inputs = values.size();
|
||||
octetStream os;
|
||||
vector< vector<gfp> > triples(num_inputs, vector<gfp>(3));
|
||||
vector<gfp> triple_shares(3);
|
||||
|
||||
// Receive num_inputs triples from SPDZ
|
||||
for (int j = 0; j < nparties; j++)
|
||||
{
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[j]);
|
||||
os.decrypt_sequence(&commsec[j].second[0],0);
|
||||
os.decrypt(keys[j]);
|
||||
|
||||
for (int j = 0; j < num_inputs; j++)
|
||||
{
|
||||
for (int k = 0; k < 3; k++)
|
||||
{
|
||||
triple_shares[k].unpack(os);
|
||||
triples[j][k] += triple_shares[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check triple relations
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
if (triples[i][0] * triples[i][1] != triples[i][2])
|
||||
{
|
||||
cerr << "Incorrect triple at " << i << ", aborting\n";
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
// Send inputs + triple[0], so SPDZ can compute shares of each value
|
||||
os.reset_write_head();
|
||||
for (int i = 0; i < num_inputs; i++)
|
||||
{
|
||||
gfp y = values[i] + triples[i][0];
|
||||
y.pack(os);
|
||||
}
|
||||
for (int j = 0; j < nparties; j++)
|
||||
{
|
||||
octetStream temp = os;
|
||||
temp.encrypt_sequence(&commsec[j].first[0], 0);
|
||||
temp.Send(sockets[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// Send public key in clear to each SPDZ engine.
|
||||
void send_public_key(vector<int>& pubkey, int socket)
|
||||
{
|
||||
octetStream os;
|
||||
os.reset_write_head();
|
||||
|
||||
for (unsigned int i = 0; i < pubkey.size(); i++)
|
||||
{
|
||||
os.store(pubkey[i]);
|
||||
}
|
||||
|
||||
os.Send(socket);
|
||||
}
|
||||
|
||||
// Assumes that Scripts/setup-online.sh has been run to compute prime
|
||||
void initialise_fields(const string& dir_prefix)
|
||||
{
|
||||
int lg2;
|
||||
bigint p;
|
||||
|
||||
string filename = dir_prefix + "Params-Data";
|
||||
cout << "loading params from: " << filename << endl;
|
||||
|
||||
ifstream inpf(filename.c_str());
|
||||
if (inpf.fail()) { throw file_error(filename.c_str()); }
|
||||
inpf >> p;
|
||||
inpf >> lg2;
|
||||
|
||||
inpf.close();
|
||||
|
||||
gfp::init_field(p);
|
||||
gf2n::init_field(lg2);
|
||||
}
|
||||
|
||||
// Assumes that client-setup has been run to create key pairs for clients and parties
|
||||
void generate_symmetric_keys(vector<octet*>& keys, vector<int>& client_public_key_ints,
|
||||
sign_key_container_t *sts_key, const string& dir_prefix, int client_no)
|
||||
{
|
||||
unsigned char client_publickey[crypto_box_PUBLICKEYBYTES];
|
||||
unsigned char client_secretkey[crypto_box_SECRETKEYBYTES];
|
||||
unsigned char server_publickey[crypto_box_PUBLICKEYBYTES];
|
||||
unsigned char scalarmult_q[crypto_scalarmult_BYTES];
|
||||
crypto_generichash_state h;
|
||||
|
||||
// read client public/secret keys + SPDZ server public keys
|
||||
ifstream keyfile;
|
||||
stringstream client_filename;
|
||||
client_filename << dir_prefix << "Client-Keys-C" << client_no;
|
||||
keyfile.open(client_filename.str().c_str());
|
||||
if (keyfile.fail())
|
||||
throw file_error(client_filename.str());
|
||||
keyfile.read((char*)client_publickey, sizeof client_publickey);
|
||||
if (keyfile.eof())
|
||||
throw end_of_file(client_filename.str(), "client public key" );
|
||||
|
||||
// Convert client public key unsigned char to int, reverse endianness
|
||||
for(unsigned int j = 0; j < client_public_key_ints.size(); j++) {
|
||||
int keybyte = 0;
|
||||
for(unsigned int k = 0; k < 4; k++) {
|
||||
keybyte = keybyte + (((int)client_publickey[j*4+k]) << ((3-k) * 8));
|
||||
}
|
||||
client_public_key_ints[j] = keybyte;
|
||||
}
|
||||
|
||||
keyfile.read((char*)client_secretkey, sizeof client_secretkey);
|
||||
if (keyfile.eof()) {
|
||||
throw end_of_file(client_filename.str(), "client private key" );
|
||||
}
|
||||
|
||||
keyfile.read((char*)sts_key->client_publickey, crypto_sign_PUBLICKEYBYTES);
|
||||
keyfile.read((char*)sts_key->client_secretkey, crypto_sign_SECRETKEYBYTES);
|
||||
// Convert client public key unsigned char to int, reverse endianness
|
||||
sts_key->client_publickey_ints.resize(8);
|
||||
for(unsigned int j = 0; j < sts_key->client_publickey_ints.size(); j++) {
|
||||
int keybyte = 0;
|
||||
for(unsigned int k = 0; k < 4; k++) {
|
||||
keybyte = keybyte + (((int)sts_key->client_publickey[j*4+k]) << ((3-k) * 8));
|
||||
}
|
||||
sts_key->client_publickey_ints[j] = keybyte;
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < keys.size(); i++)
|
||||
{
|
||||
keys[i] = new octet[crypto_generichash_BYTES];
|
||||
keyfile.read((char*)server_publickey, crypto_box_PUBLICKEYBYTES);
|
||||
if (keyfile.eof())
|
||||
throw end_of_file(client_filename.str(), "server public key for party " + to_string(i));
|
||||
keyfile.read((char*)(&sts_key->server_publickey[i][0]), crypto_sign_PUBLICKEYBYTES);
|
||||
if (keyfile.eof())
|
||||
throw end_of_file(client_filename.str(), "server public signing key for party " + to_string(i));
|
||||
|
||||
// Derive a shared key from this server's secret key and the client's public key
|
||||
// shared key = h(q || client_secretkey || server_publickey)
|
||||
if (crypto_scalarmult(scalarmult_q, client_secretkey, server_publickey) != 0) {
|
||||
cerr << "Scalar mult failed\n";
|
||||
exit(1);
|
||||
}
|
||||
crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES);
|
||||
crypto_generichash_update(&h, scalarmult_q, sizeof scalarmult_q);
|
||||
crypto_generichash_update(&h, client_publickey, sizeof client_publickey);
|
||||
crypto_generichash_update(&h, server_publickey, sizeof server_publickey);
|
||||
crypto_generichash_final(&h, keys[i], crypto_generichash_BYTES);
|
||||
}
|
||||
keyfile.close();
|
||||
|
||||
cout << "My public key is: ";
|
||||
for (unsigned int j = 0; j < 32; j++)
|
||||
cout << setfill('0') << setw(2) << hex << (int) client_publickey[j];
|
||||
cout << dec << endl;
|
||||
}
|
||||
|
||||
|
||||
// Receive shares of the result and sum together.
|
||||
// Also receive authenticating values.
|
||||
gfp receive_result(vector<int>& sockets, int nparties, commsec_t commsec, vector<octet*>& keys)
|
||||
{
|
||||
vector<gfp> output_values(3);
|
||||
octetStream os;
|
||||
for (int i = 0; i < nparties; i++)
|
||||
{
|
||||
os.reset_write_head();
|
||||
os.Receive(sockets[i]);
|
||||
|
||||
os.decrypt_sequence(&commsec[i].second[0],1);
|
||||
os.decrypt(keys[i]);
|
||||
|
||||
for (unsigned int j = 0; j < 3; j++)
|
||||
{
|
||||
gfp value;
|
||||
value.unpack(os);
|
||||
output_values[j] += value;
|
||||
}
|
||||
}
|
||||
|
||||
if (output_values[0] * output_values[1] != output_values[2])
|
||||
{
|
||||
cerr << "Unable to authenticate output value as correct, aborting." << endl;
|
||||
exit(1);
|
||||
}
|
||||
return output_values[0];
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
int my_client_id;
|
||||
int nparties;
|
||||
int salary_value;
|
||||
int finish;
|
||||
int port_base = 14000;
|
||||
sign_key_container_t sts_key;
|
||||
string host_name = "localhost";
|
||||
|
||||
if (argc < 5) {
|
||||
cout << "Usage is external-client <client identifier> <number of spdz parties> "
|
||||
<< "<salary to compare> <finish (0 false, 1 true)> <optional host name, default localhost> "
|
||||
<< "<optional spdz party port base number, default 14000>" << endl;
|
||||
exit(0);
|
||||
}
|
||||
|
||||
my_client_id = atoi(argv[1]);
|
||||
nparties = atoi(argv[2]);
|
||||
salary_value = atoi(argv[3]);
|
||||
finish = atoi(argv[4]);
|
||||
if (argc > 5)
|
||||
host_name = argv[5];
|
||||
if (argc > 6)
|
||||
port_base = atoi(argv[6]);
|
||||
|
||||
sts_key.server_publickey.resize(nparties);
|
||||
for(int i = 0 ; i < nparties; i++) {
|
||||
sts_key.server_publickey[i].resize(crypto_sign_PUBLICKEYBYTES);
|
||||
}
|
||||
|
||||
// init static gfp
|
||||
string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree());
|
||||
initialise_fields(prep_data_prefix);
|
||||
bigint::init_thread();
|
||||
|
||||
// Generate session keys to decrypt data sent from each spdz engine (party)
|
||||
vector<octet*> session_keys(nparties);
|
||||
vector<int> client_public_key_ints(8);
|
||||
|
||||
generate_symmetric_keys(session_keys, client_public_key_ints, &sts_key, prep_data_prefix, my_client_id);
|
||||
|
||||
// Setup connections from this client to each party socket and send the client public keys
|
||||
vector<int> sockets(nparties);
|
||||
// vector< pair <vector<octet>, vector <octet> > > commseckey(nparties);
|
||||
commsec_t commseckey(nparties);
|
||||
for (int i = 0; i < nparties; i++)
|
||||
{
|
||||
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
|
||||
send(sockets[i], (octet*) &my_client_id, sizeof(int));
|
||||
octetStream os;
|
||||
os.store(finish);
|
||||
os.Send(sockets[i]);
|
||||
|
||||
send_public_key(sts_key.client_publickey_ints, sockets[i]);
|
||||
send_public_key(client_public_key_ints, sockets[i]);
|
||||
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
|
||||
}
|
||||
cout << "Finish setup socket connections to SPDZ engines." << endl;
|
||||
|
||||
// Send the inputs to the SPDZ Engines
|
||||
send_private_inputs({salary_value}, sockets, nparties, commseckey, session_keys);
|
||||
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
|
||||
|
||||
// Get the result back
|
||||
gfp result = receive_result(sockets, nparties, commseckey, session_keys);
|
||||
|
||||
cout << "Winning client id is : " << result << endl;
|
||||
|
||||
for (unsigned int i = 0; i < sockets.size(); i++)
|
||||
close_client_socket(sockets[i]);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -180,6 +180,13 @@ void FFT_Data::pack(octetStream& o) const
|
||||
{
|
||||
R.pack(o);
|
||||
prData.pack(o);
|
||||
o.store(root);
|
||||
o.store(twop);
|
||||
o.store(two_root);
|
||||
o.store(b);
|
||||
iphi.pack(o);
|
||||
o.store(powers);
|
||||
o.store(powers_i);
|
||||
}
|
||||
|
||||
|
||||
@@ -187,7 +194,13 @@ void FFT_Data::unpack(octetStream& o)
|
||||
{
|
||||
R.unpack(o);
|
||||
prData.unpack(o);
|
||||
init(R, prData);
|
||||
o.get(root);
|
||||
o.get(twop);
|
||||
o.get(two_root);
|
||||
o.get(b);
|
||||
iphi.unpack(o);
|
||||
o.get(powers);
|
||||
o.get(powers_i);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -345,19 +345,38 @@ void init(Ring& Rg,int m)
|
||||
Rg.pi.resize(Rg.phim); Rg.pi_inv.resize(Rg.mm);
|
||||
for (int i=0; i<Rg.mm; i++) { Rg.pi_inv[i]=-1; }
|
||||
|
||||
int k=0;
|
||||
for (int i=1; i<Rg.mm; i++)
|
||||
{ if (gcd(i,Rg.mm)==1)
|
||||
{ Rg.pi[k]=i;
|
||||
Rg.pi_inv[i]=k;
|
||||
k++;
|
||||
if (((m - 1) & m) == 0)
|
||||
{
|
||||
// m is power of two
|
||||
// no need to generate poly
|
||||
int k = 0;
|
||||
for (int i = 1; i < Rg.mm; i++)
|
||||
{
|
||||
// easy GCD
|
||||
if (i % 2 == 1)
|
||||
{
|
||||
Rg.pi[k] = i;
|
||||
Rg.pi_inv[i] = k;
|
||||
k++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int k=0;
|
||||
for (int i=1; i<Rg.mm; i++)
|
||||
{ if (gcd(i,Rg.mm)==1)
|
||||
{ Rg.pi[k]=i;
|
||||
Rg.pi_inv[i]=k;
|
||||
k++;
|
||||
}
|
||||
}
|
||||
|
||||
ZZX P=Cyclotomic(Rg.mm);
|
||||
Rg.poly.resize(Rg.phim+1);
|
||||
for (int i=0; i<Rg.phim+1; i++)
|
||||
{ Rg.poly[i]=to_int(coeff(P,i)); }
|
||||
ZZX P=Cyclotomic(Rg.mm);
|
||||
Rg.poly.resize(Rg.phim+1);
|
||||
for (int i=0; i<Rg.phim+1; i++)
|
||||
{ Rg.poly[i]=to_int(coeff(P,i)); }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
|
||||
{
|
||||
if (sigma <= 0)
|
||||
this->sigma = sigma = FHE_Params().get_R();
|
||||
#ifdef VERBOSE
|
||||
cerr << "Standard deviation: " << this->sigma << endl;
|
||||
#endif
|
||||
h += extra_h * sec;
|
||||
produce_epsilon_constants();
|
||||
|
||||
|
||||
24
FHE/Ring.cpp
24
FHE/Ring.cpp
@@ -29,19 +29,27 @@ istream& operator>>(istream& s,Ring& R)
|
||||
void Ring::pack(octetStream& o) const
|
||||
{
|
||||
o.store(mm);
|
||||
o.store(phim);
|
||||
o.store(pi);
|
||||
o.store(pi_inv);
|
||||
o.store(poly);
|
||||
if (((mm - 1) & mm) != 0)
|
||||
{
|
||||
o.store(phim);
|
||||
o.store(pi);
|
||||
o.store(pi_inv);
|
||||
o.store(poly);
|
||||
}
|
||||
}
|
||||
|
||||
void Ring::unpack(octetStream& o)
|
||||
{
|
||||
o.get(mm);
|
||||
o.get(phim);
|
||||
o.get(pi);
|
||||
o.get(pi_inv);
|
||||
o.get(poly);
|
||||
if (((mm - 1) & mm) != 0)
|
||||
{
|
||||
o.get(phim);
|
||||
o.get(pi);
|
||||
o.get(pi_inv);
|
||||
o.get(poly);
|
||||
}
|
||||
else
|
||||
init(*this, mm);
|
||||
}
|
||||
|
||||
bool Ring::operator !=(const Ring& other) const
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <assert.h>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/octetStream.h"
|
||||
@@ -31,7 +32,7 @@ class Ring
|
||||
|
||||
int p(int i) const { return pi.at(i); }
|
||||
int p_inv(int i) const { return pi_inv.at(i); }
|
||||
const vector<int>& Phi() const { return poly; }
|
||||
const vector<int>& Phi() const { assert(poly.size()); return poly; }
|
||||
|
||||
friend ostream& operator<<(ostream& s,const Ring& R);
|
||||
friend istream& operator>>(istream& s,Ring& R);
|
||||
|
||||
@@ -96,6 +96,8 @@ public:
|
||||
res += x.from == from;
|
||||
return res;
|
||||
}
|
||||
|
||||
int n_interactive_inputs_from_me(int my_num);
|
||||
};
|
||||
|
||||
#endif /* GC_ARGTUPLES_H_ */
|
||||
|
||||
@@ -76,14 +76,7 @@ unsigned Instruction::get_mem(RegType reg_type) const
|
||||
inline
|
||||
void Instruction::parse(istream& s, int pos)
|
||||
{
|
||||
n = 0;
|
||||
start.resize(0);
|
||||
::memset(r, 0, sizeof(r));
|
||||
|
||||
int file_pos = s.tellg();
|
||||
opcode = ::get_int(s);
|
||||
|
||||
parse_operands(s, pos, file_pos);
|
||||
BaseInstruction::parse(s, pos);
|
||||
|
||||
switch(opcode)
|
||||
{
|
||||
|
||||
@@ -42,6 +42,7 @@ MAYBE_INLINE bool Instruction::execute(Processor<T>& processor,
|
||||
cout << endl;
|
||||
#endif
|
||||
const Instruction& instruction = *this;
|
||||
auto& Ci = processor.I;
|
||||
switch (opcode)
|
||||
{
|
||||
#define X(NAME, CODE) case NAME: CODE; return true;
|
||||
|
||||
@@ -65,6 +65,8 @@ public:
|
||||
void reset(const U& program);
|
||||
|
||||
long long get_input(const int* params, bool interactive = false);
|
||||
bigint get_long_input(const int* params, ProcessorBase& input_proc,
|
||||
bool interactive = false);
|
||||
|
||||
void bitcoms(T& x, const vector<int>& regs) { x.bitcom(S, regs); }
|
||||
void bitdecs(const vector<int>& regs, const T& x) { x.bitdec(S, regs); }
|
||||
@@ -84,6 +86,10 @@ public:
|
||||
template<class U>
|
||||
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
|
||||
|
||||
template<class U>
|
||||
void mem_op(int n, Memory<U>& dest, const Memory<U>& source,
|
||||
Integer dest_address, Integer source_address);
|
||||
|
||||
void xors(const vector<int>& args);
|
||||
void andm(const ::BaseInstruction& instruction);
|
||||
void and_(const vector<int>& args, bool repeat);
|
||||
@@ -95,9 +101,9 @@ public:
|
||||
|
||||
void reveal(const ::BaseInstruction& instruction);
|
||||
|
||||
void print_reg(int reg, int n);
|
||||
void print_reg(int reg, int n, int size);
|
||||
void print_reg_plain(Clear& value);
|
||||
void print_reg_signed(unsigned n_bits, Clear& value);
|
||||
void print_reg_signed(unsigned n_bits, Integer value);
|
||||
void print_chr(int n);
|
||||
void print_str(int n);
|
||||
void print_float(const vector<int>& args);
|
||||
|
||||
@@ -67,11 +67,19 @@ void Processor<T>::reset(const U& program)
|
||||
template<class T>
|
||||
inline long long GC::Processor<T>::get_input(const int* params, bool interactive)
|
||||
{
|
||||
bigint res = ProcessorBase::get_input<FixInput>(interactive, ¶ms[1]).items[0];
|
||||
assert(params[0] <= 64);
|
||||
return get_long_input(params, *this, interactive).get_si();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
bigint GC::Processor<T>::get_long_input(const int* params,
|
||||
ProcessorBase& input_proc, bool interactive)
|
||||
{
|
||||
bigint res = input_proc.get_input<FixInput_<bigint>>(interactive,
|
||||
¶ms[1]).items[0];
|
||||
int n_bits = *params;
|
||||
check_input(res, n_bits);
|
||||
assert(n_bits <= 64);
|
||||
return res.get_si();
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -171,6 +179,17 @@ void GC::Processor<T>::store_clear_in_dynamic(const vector<int>& args,
|
||||
T::store_clear_in_dynamic(dynamic_memory, accesses);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
template<class U>
|
||||
void Processor<T>::mem_op(int n, Memory<U>& dest, const Memory<U>& source,
|
||||
Integer dest_address, Integer source_address)
|
||||
{
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
dest[dest_address + i] = source[source_address + i];
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::xors(const vector<int>& args)
|
||||
{
|
||||
@@ -234,12 +253,15 @@ void Processor<T>::reveal(const vector<int>& args)
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_reg(int reg, int n)
|
||||
void Processor<T>::print_reg(int reg, int n, int size)
|
||||
{
|
||||
#ifdef DEBUG_VALUES
|
||||
cout << "print_reg " << typeid(T).name() << " " << reg << " " << &C[reg] << endl;
|
||||
#endif
|
||||
T::out << "Reg[" << reg << "] = " << hex << showbase << C[reg] << dec << " # ";
|
||||
bigint output;
|
||||
for (int i = 0; i < size; i++)
|
||||
output += bigint((unsigned long)C[reg + i].get()) << (T::default_length * i);
|
||||
T::out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
|
||||
print_str(n);
|
||||
T::out << endl << flush;
|
||||
}
|
||||
@@ -251,14 +273,29 @@ void Processor<T>::print_reg_plain(Clear& value)
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_reg_signed(unsigned n_bits, Clear& value)
|
||||
void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
|
||||
{
|
||||
unsigned n_shift = 0;
|
||||
if (n_bits > 1)
|
||||
n_shift = sizeof(value.get()) * 8 - n_bits;
|
||||
if (n_shift > 63)
|
||||
n_shift = 0;
|
||||
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
|
||||
if (n_bits <= Clear::N_BITS)
|
||||
{
|
||||
auto value = C[reg];
|
||||
unsigned n_shift = 0;
|
||||
if (n_bits > 1)
|
||||
n_shift = sizeof(value.get()) * 8 - n_bits;
|
||||
if (n_shift > 63)
|
||||
n_shift = 0;
|
||||
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
|
||||
}
|
||||
else
|
||||
{
|
||||
bigint tmp = 0;
|
||||
for (int i = 0; i < DIV_CEIL(n_bits, Clear::N_BITS); i++)
|
||||
{
|
||||
tmp += bigint((unsigned long)C[reg + i].get()) << (i * Clear::N_BITS);
|
||||
}
|
||||
if (tmp >= bigint(1) << (n_bits - 1))
|
||||
tmp -= bigint(1) << n_bits;
|
||||
T::out << dec << tmp << flush;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -21,6 +21,8 @@
|
||||
|
||||
#include <fstream>
|
||||
|
||||
class ProcessorBase;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
@@ -116,6 +118,10 @@ public:
|
||||
static void inputb(Processor<U>& processor, const vector<int>& args)
|
||||
{ T::inputb(processor, args); }
|
||||
template<class U>
|
||||
static void inputb(Processor<U>& processor, ProcessorBase& input_proc,
|
||||
const vector<int>& args)
|
||||
{ T::inputb(processor, input_proc, args); }
|
||||
template<class U>
|
||||
static void reveal_inst(Processor<U>& processor, const vector<int>& args)
|
||||
{ processor.reveal(args); }
|
||||
|
||||
|
||||
@@ -290,12 +290,14 @@ void Secret<T>::trans(Processor<U>& processor, int n_outputs,
|
||||
const vector<int>& args)
|
||||
{
|
||||
int n_inputs = args.size() - n_outputs;
|
||||
int dl = U::default_length;
|
||||
for (int i = 0; i < n_outputs; i++)
|
||||
{
|
||||
processor.S[args[i]].resize_regs(n_inputs);
|
||||
for (int j = 0; j < DIV_CEIL(n_inputs, dl); j++)
|
||||
processor.S[args[i] + j].resize_regs(min(dl, n_inputs - j * dl));
|
||||
for (int j = 0; j < n_inputs; j++)
|
||||
processor.S[args[i]].registers[j] =
|
||||
processor.S[args[n_outputs + j]].registers[i];
|
||||
processor.S[args[i] + j / dl].registers[j % dl] =
|
||||
processor.S[args[n_outputs + j] + i / dl].registers[i % dl];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,12 +25,25 @@ SemiSecret::MC* SemiSecret::new_mc(mac_key_type)
|
||||
void SemiSecret::trans(Processor<SemiSecret>& processor, int n_outputs,
|
||||
const vector<int>& args)
|
||||
{
|
||||
square64 square;
|
||||
for (size_t i = n_outputs; i < args.size(); i++)
|
||||
square.rows[i - n_outputs] = processor.S[args[i]].get();
|
||||
square.transpose(args.size() - n_outputs, n_outputs);
|
||||
for (int i = 0; i < n_outputs; i++)
|
||||
processor.S[args[i]] = square.rows[i];
|
||||
int N_BITS = default_length;
|
||||
for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++)
|
||||
for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++)
|
||||
{
|
||||
square64 square;
|
||||
size_t input_base = n_outputs + l * N_BITS;
|
||||
for (size_t i = input_base;
|
||||
i < min(input_base + N_BITS, args.size()); i++)
|
||||
square.rows[i - input_base] = processor.S[args[i] + j].get();
|
||||
square.transpose(
|
||||
min(size_t(N_BITS), args.size() - n_outputs - l * N_BITS),
|
||||
min(N_BITS, n_outputs - j * N_BITS));
|
||||
int output_base = j * N_BITS;
|
||||
for (int i = output_base; i < min(n_outputs, output_base + N_BITS);
|
||||
i++)
|
||||
{
|
||||
processor.S[args[i] + l] = square.rows[i - output_base];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SemiSecret::load_clear(int n, const Integer& x)
|
||||
|
||||
@@ -20,6 +20,7 @@ using namespace std;
|
||||
#include "Protocols/Replicated.h"
|
||||
#include "Protocols/ReplicatedMC.h"
|
||||
#include "Processor/DummyProtocol.h"
|
||||
#include "Processor/ProcessorBase.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
@@ -58,7 +59,10 @@ public:
|
||||
{ and_(processor, args, false); }
|
||||
static void and_(Processor<U>& processor, const vector<int>& args, bool repeat);
|
||||
static void xors(Processor<U>& processor, const vector<int>& args);
|
||||
static void inputb(Processor<U>& processor, const vector<int>& args);
|
||||
static void inputb(Processor<U>& processor, const vector<int>& args)
|
||||
{ inputb(processor, processor, args); }
|
||||
static void inputb(Processor<U>& processor, ProcessorBase& input_processor,
|
||||
const vector<int>& args);
|
||||
static void reveal_inst(Processor<U>& processor, const vector<int>& args);
|
||||
|
||||
static void convcbit(Integer& dest, const Clear& source) { dest = source; }
|
||||
|
||||
@@ -26,7 +26,7 @@ namespace GC
|
||||
{
|
||||
|
||||
template<class U>
|
||||
const int VectorSecret<U>::default_length;
|
||||
const int ReplicatedSecret<U>::N_BITS;
|
||||
|
||||
template<class U>
|
||||
const int ReplicatedSecret<U>::default_length;
|
||||
@@ -92,6 +92,7 @@ void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
|
||||
|
||||
template<class U>
|
||||
void ShareSecret<U>::inputb(Processor<U>& processor,
|
||||
ProcessorBase& input_processor,
|
||||
const vector<int>& args)
|
||||
{
|
||||
auto& party = ShareThread<U>::s();
|
||||
@@ -99,16 +100,22 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
|
||||
input.reset_all(*party.P);
|
||||
|
||||
InputArgList a(args);
|
||||
bool interactive = Thread<U>::s().n_interactive_inputs_from_me(a) > 0;
|
||||
bool interactive = a.n_interactive_inputs_from_me(party.P->my_num()) > 0;
|
||||
int dl = U::default_length;
|
||||
|
||||
for (auto x : a)
|
||||
{
|
||||
if (x.from == party.P->my_num())
|
||||
{
|
||||
input.add_mine(processor.get_input(x.params, interactive), x.n_bits);
|
||||
bigint whole_input = processor.get_long_input(x.params,
|
||||
input_processor, interactive);
|
||||
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
|
||||
input.add_mine(bigint(whole_input >> (i * dl)).get_si(),
|
||||
min(dl, x.n_bits - i * dl));
|
||||
}
|
||||
else
|
||||
input.add_other(x.from);
|
||||
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
|
||||
input.add_other(x.from);
|
||||
}
|
||||
|
||||
if (interactive)
|
||||
@@ -120,8 +127,12 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
|
||||
{
|
||||
int from = x.from;
|
||||
int n_bits = x.n_bits;
|
||||
auto& res = processor.S[x.dest];
|
||||
res = input.finalize(from, n_bits).mask(n_bits);
|
||||
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
|
||||
{
|
||||
auto& res = processor.S[x.dest + i];
|
||||
int n = min(dl, n_bits - i * dl);
|
||||
res = input.finalize(from, n).mask(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,7 +150,11 @@ void ShareSecret<U>::reveal_inst(Processor<U>& processor,
|
||||
if (n > max(U::default_length, Clear::N_BITS))
|
||||
assert(U::default_length == Clear::N_BITS);
|
||||
for (int j = 0; j < DIV_CEIL(n, U::default_length); j++)
|
||||
shares.push_back(processor.S[r1 + j].mask(n));
|
||||
{
|
||||
shares.push_back(
|
||||
processor.S[r1 + j].mask(
|
||||
min(U::default_length, n - j * U::default_length)));
|
||||
}
|
||||
}
|
||||
assert(party.MC);
|
||||
PointerVector<typename U::open_type> opened;
|
||||
@@ -149,7 +164,10 @@ void ShareSecret<U>::reveal_inst(Processor<U>& processor,
|
||||
int n = args[i];
|
||||
int r0 = args[i + 1];
|
||||
for (int j = 0; j < DIV_CEIL(n, U::default_length); j++)
|
||||
processor.C[r0 + j] = opened.next().mask(n);
|
||||
{
|
||||
processor.C[r0 + j] = opened.next().mask(
|
||||
min(U::default_length, n - j * U::default_length));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,12 +198,22 @@ void ReplicatedSecret<U>::trans(Processor<U>& processor,
|
||||
assert(length == 2);
|
||||
for (int k = 0; k < 2; k++)
|
||||
{
|
||||
square64 square;
|
||||
for (size_t i = n_outputs; i < args.size(); i++)
|
||||
square.rows[i - n_outputs] = processor.S[args[i]][k].get();
|
||||
square.transpose(args.size() - n_outputs, n_outputs);
|
||||
for (int i = 0; i < n_outputs; i++)
|
||||
processor.S[args[i]][k] = square.rows[i];
|
||||
for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++)
|
||||
for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++)
|
||||
{
|
||||
square64 square;
|
||||
size_t input_base = n_outputs + l * N_BITS;
|
||||
for (size_t i = input_base; i < min(input_base + N_BITS, args.size()); i++)
|
||||
square.rows[i - input_base] = processor.S[args[i] + j][k].get();
|
||||
square.transpose(
|
||||
min(size_t(N_BITS), args.size() - n_outputs - l * N_BITS),
|
||||
min(N_BITS, n_outputs - j * N_BITS));
|
||||
int output_base = j * N_BITS;
|
||||
for (int i = output_base; i < min(n_outputs, output_base + N_BITS); i++)
|
||||
{
|
||||
processor.S[args[i] + l][k] = square.rows[i - output_base];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -75,7 +75,9 @@ void ShareThread<T>::post_run()
|
||||
{
|
||||
MC->Check(*this->P);
|
||||
#ifndef INSECURE
|
||||
#ifdef VERBOSE
|
||||
cerr << "Removing used pre-processed data" << endl;
|
||||
#endif
|
||||
DataF.prune();
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -90,18 +90,23 @@ void Thread<T>::finish()
|
||||
pthread_join(thread, 0);
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
int GC::Thread<T>::n_interactive_inputs_from_me(InputArgList& args)
|
||||
int Thread<T>::n_interactive_inputs_from_me(InputArgList& args)
|
||||
{
|
||||
return args.n_interactive_inputs_from_me(P->my_num());
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
|
||||
inline int InputArgList::n_interactive_inputs_from_me(int my_num)
|
||||
{
|
||||
int res = 0;
|
||||
if (thread_num == 0 and master.opts.interactive)
|
||||
res = args.n_inputs_from(P->my_num());
|
||||
if (ArithmeticProcessor().use_stdin())
|
||||
res = n_inputs_from(my_num);
|
||||
if (res > 0)
|
||||
cout << "Please enter " << res << " numbers:" << endl;
|
||||
return res;
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -218,6 +218,9 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template<class U>
|
||||
const int VectorSecret<U>::default_length;
|
||||
|
||||
template<class T>
|
||||
inline VectorSecret<T> operator*(const BitVec& clear, const VectorSecret<T>& share)
|
||||
{
|
||||
|
||||
@@ -30,12 +30,11 @@
|
||||
|
||||
#define IMM instruction.get_n()
|
||||
#define EXTRA instruction.get_start()
|
||||
#define SIZE instruction.get_size()
|
||||
|
||||
#define MSD processor.memories.MS[IMM]
|
||||
#define MMC processor.memories.MC[IMM]
|
||||
#define MMS processor.memories.MS
|
||||
#define MMC processor.memories.MC
|
||||
#define MID MACH->MI[IMM]
|
||||
|
||||
#define MSI processor.memories.MS[PI1.get()]
|
||||
#define MII MACH->MI[PI1.get()]
|
||||
|
||||
#define BIT_INSTRUCTIONS \
|
||||
@@ -44,7 +43,6 @@
|
||||
X(XORCBI, C0.xor_(PC1, IMM)) \
|
||||
X(ANDRS, T::andrs(PROC, EXTRA)) \
|
||||
X(ANDS, T::ands(PROC, EXTRA)) \
|
||||
X(INPUTB, T::inputb(PROC, EXTRA)) \
|
||||
X(ADDCB, C0 = PC1 + PC2) \
|
||||
X(ADDCBI, C0 = PC1 + IMM) \
|
||||
X(MULCBI, C0 = PC1 * IMM) \
|
||||
@@ -54,24 +52,25 @@
|
||||
X(SHRCBI, C0 = PC1 >> IMM) \
|
||||
X(SHLCBI, C0 = PC1 << IMM) \
|
||||
X(LDBITS, S0.load_clear(REG1, IMM)) \
|
||||
X(LDMSB, S0 = MSD) \
|
||||
X(STMSB, MSD = S0) \
|
||||
X(LDMCB, C0 = MMC) \
|
||||
X(STMCB, MMC = C0) \
|
||||
X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \
|
||||
X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \
|
||||
X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \
|
||||
X(STMCB, PROC.mem_op(SIZE, MMC, PROC.C, IMM, R0)) \
|
||||
X(LDMSBI, PROC.mem_op(SIZE, PROC.S, MMS, R0, Ci[REG1])) \
|
||||
X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \
|
||||
X(MOVSB, S0 = PS1) \
|
||||
X(TRANS, T::trans(PROC, IMM, EXTRA)) \
|
||||
X(BITB, PROC.random_bit(S0)) \
|
||||
X(REVEAL, T::reveal_inst(PROC, EXTRA)) \
|
||||
X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, C0)) \
|
||||
X(PRINTREGB, PROC.print_reg(R0, IMM)) \
|
||||
X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, R0)) \
|
||||
X(PRINTREGB, PROC.print_reg(R0, IMM, SIZE)) \
|
||||
X(PRINTREGPLAINB, PROC.print_reg_plain(C0)) \
|
||||
X(PRINTFLOATPLAINB, PROC.print_float(EXTRA)) \
|
||||
X(CONDPRINTSTRB, if(C0.get()) PROC.print_str(IMM)) \
|
||||
|
||||
#define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \
|
||||
X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \
|
||||
X(ANDM, processor.andm(instruction)) \
|
||||
X(LDMSBI, S0 = processor.memories.MS[Proc.read_Ci(REG1)]) \
|
||||
X(STMSBI, processor.memories.MS[Proc.read_Ci(REG1)] = S0) \
|
||||
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
|
||||
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
|
||||
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \
|
||||
@@ -84,8 +83,7 @@
|
||||
X(SPLIT, Proc.split(INST)) \
|
||||
|
||||
#define GC_INSTRUCTIONS \
|
||||
X(LDMSBI, S0 = MSI) \
|
||||
X(STMSBI, MSI = S0) \
|
||||
X(INPUTB, T::inputb(PROC, EXTRA)) \
|
||||
X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \
|
||||
X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \
|
||||
X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "Tools/cpu_support.h"
|
||||
#include <stdexcept>
|
||||
#include <iostream>
|
||||
#include <assert.h>
|
||||
using namespace std;
|
||||
|
||||
union matrix32x8
|
||||
@@ -98,6 +99,9 @@ void square64::transpose(int n_rows, int n_cols)
|
||||
print();
|
||||
#endif
|
||||
|
||||
assert(n_rows <= 64);
|
||||
assert(n_cols <= 64);
|
||||
|
||||
square64 tmp = *this;
|
||||
*this = {};
|
||||
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/Share.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Tools/Config.h"
|
||||
#include "Networking/Server.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
|
||||
8
Makefile
8
Makefile
@@ -59,7 +59,7 @@ offline: $(OT_EXE) Check-Offline.x
|
||||
|
||||
gen_input: gen_input_f2n.x gen_input_fp.x
|
||||
|
||||
externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x
|
||||
externalIO: bankers-bonus-client.x
|
||||
|
||||
bmr: bmr-program-party.x bmr-program-tparty.x
|
||||
|
||||
@@ -134,9 +134,6 @@ bmr-clean:
|
||||
bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON)
|
||||
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
|
||||
|
||||
simple-offline.x: $(FHEOFFLINE)
|
||||
pairwise-offline.x: $(FHEOFFLINE)
|
||||
cnc-offline.x: $(FHEOFFLINE)
|
||||
@@ -195,6 +192,9 @@ OT/BaseOT.o: SimpleOT/Makefile
|
||||
SimpleOT/Makefile:
|
||||
git submodule update --init SimpleOT
|
||||
|
||||
Programs/Circuits:
|
||||
git submodule update --init Programs/Circuits
|
||||
|
||||
.PHONY: mpir-setup mpir-global mpir
|
||||
mpir-setup:
|
||||
git submodule update --init mpir
|
||||
|
||||
10
Math/Z2k.h
10
Math/Z2k.h
@@ -73,6 +73,8 @@ public:
|
||||
static void reqbl(int n);
|
||||
static bool allows(Dtype dtype);
|
||||
|
||||
static void specification(octetStream& os);
|
||||
|
||||
typedef Z2 next;
|
||||
typedef Z2 Scalar;
|
||||
|
||||
@@ -254,12 +256,18 @@ public:
|
||||
}
|
||||
|
||||
template<int L>
|
||||
SignedZ2<K + L> operator*(const Z2<L>& other) const
|
||||
SignedZ2<K + L> operator*(const SignedZ2<L>& other) const
|
||||
{
|
||||
assert((K % 64 == 0) and (L % 64 == 0));
|
||||
return Z2<K>::operator*(other);
|
||||
}
|
||||
|
||||
template<int L>
|
||||
Z2<K + L> operator*(const Z2<L>& other) const
|
||||
{
|
||||
return Z2<K>::operator*(other);
|
||||
}
|
||||
|
||||
SignedZ2<K> operator*(int other) const
|
||||
{
|
||||
return operator*(SignedZ2<64>(other));
|
||||
|
||||
@@ -36,6 +36,12 @@ bool Z2<K>::allows(Dtype dtype)
|
||||
return Integer::allows(dtype);
|
||||
}
|
||||
|
||||
template<int K>
|
||||
void Z2<K>::specification(octetStream& os)
|
||||
{
|
||||
os.store(K);
|
||||
}
|
||||
|
||||
template<int K>
|
||||
Z2<K>::Z2(const bigint& x) : Z2()
|
||||
{
|
||||
|
||||
@@ -205,6 +205,12 @@ bool gfp_<X, L>::allows(Dtype type)
|
||||
}
|
||||
}
|
||||
|
||||
template<int X, int L>
|
||||
void gfp_<X, L>::specification(octetStream& os)
|
||||
{
|
||||
os.store(pr());
|
||||
}
|
||||
|
||||
void to_signed_bigint(bigint& ans, const gfp& x)
|
||||
{
|
||||
to_bigint(ans, x);
|
||||
|
||||
@@ -92,6 +92,8 @@ class gfp_
|
||||
|
||||
static bool allows(Dtype type);
|
||||
|
||||
static void specification(octetStream& os);
|
||||
|
||||
static const bool invertible = true;
|
||||
|
||||
static gfp_ Mul(gfp_ a, gfp_ b) { return a * b; }
|
||||
|
||||
@@ -61,6 +61,9 @@ class modp_
|
||||
void pack(octetStream& o,const Zp_Data& ZpD) const;
|
||||
void unpack(octetStream& o,const Zp_Data& ZpD);
|
||||
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
|
||||
bool operator==(const modp_& other) const { return 0 == mpn_cmp(x, other.x, L); }
|
||||
bool operator!=(const modp_& other) const { return not (*this == other); }
|
||||
|
||||
|
||||
@@ -20,6 +20,17 @@ void modp_<L>::unpack(octetStream& o,const Zp_Data& ZpD)
|
||||
o.consume((octet*) x,ZpD.t*sizeof(mp_limb_t));
|
||||
}
|
||||
|
||||
template<int L>
|
||||
void modp_<L>::unpack(octetStream& o)
|
||||
{
|
||||
o.consume((octet*) x,L*sizeof(mp_limb_t));
|
||||
}
|
||||
|
||||
template<int L>
|
||||
void modp_<L>::pack(octetStream& o) const
|
||||
{
|
||||
o.append((octet*) x,L*sizeof(mp_limb_t));
|
||||
}
|
||||
|
||||
template<int L>
|
||||
void Negate(modp_<L>& ans,const modp_<L>& x,const Zp_Data& ZpD)
|
||||
|
||||
@@ -14,29 +14,19 @@ void check_ssl_file(string filename)
|
||||
"You can use `Scripts/setup-ssl.sh <nparties>`.");
|
||||
}
|
||||
|
||||
void ssl_error(string side, string pronoun, int other, int server)
|
||||
void ssl_error(string side, string pronoun, string other, string server)
|
||||
{
|
||||
cerr << side << "-side handshake with party " << other
|
||||
cerr << side << "-side handshake with " << other
|
||||
<< " failed. Make sure " << pronoun
|
||||
<< " have the necessary certificate (" << PREP_DIR "P" << server
|
||||
<< " have the necessary certificate (" << PREP_DIR << server
|
||||
<< ".pem in the default configuration),"
|
||||
<< " and run `c_rehash <directory>` on its location." << endl;
|
||||
}
|
||||
|
||||
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
|
||||
MultiPlayer<ssl_socket*>(Nms, id_base), plaintext_player(Nms, id_base),
|
||||
ctx(boost::asio::ssl::context::tlsv12)
|
||||
ctx("P" + to_string(my_num()))
|
||||
{
|
||||
string prefix = PREP_DIR "P" + to_string(my_num());
|
||||
string cert_file = prefix + ".pem";
|
||||
string key_file = prefix + ".key";
|
||||
check_ssl_file(cert_file);
|
||||
check_ssl_file(key_file);
|
||||
|
||||
ctx.use_certificate_file(cert_file, ctx.pem);
|
||||
ctx.use_private_key_file(key_file, ctx.pem);
|
||||
ctx.add_verify_path("Player-Data");
|
||||
|
||||
sockets.resize(num_players());
|
||||
senders.resize(num_players());
|
||||
|
||||
@@ -49,30 +39,8 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
|
||||
continue;
|
||||
}
|
||||
|
||||
sockets[i] = new ssl_socket(io_service, ctx);
|
||||
sockets[i]->lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_player.socket(i));
|
||||
sockets[i]->set_verify_mode(boost::asio::ssl::verify_peer);
|
||||
sockets[i]->set_verify_callback(boost::asio::ssl::rfc2818_verification("P" + to_string(i)));
|
||||
if (i < my_num())
|
||||
try
|
||||
{
|
||||
sockets[i]->handshake(ssl_socket::client);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
ssl_error("Client", "we", i, i);
|
||||
throw;
|
||||
}
|
||||
if (i > my_num())
|
||||
try
|
||||
{
|
||||
sockets[i]->handshake(ssl_socket::server);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
ssl_error("Server", "they", i, my_num());
|
||||
throw;
|
||||
}
|
||||
sockets[i] = new ssl_socket(io_service, ctx, plaintext_player.socket(i),
|
||||
"P" + to_string(i), "P" + to_string(my_num()), i < my_num());
|
||||
|
||||
senders[i] = new Sender<ssl_socket*>(sockets[i]);
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
class CryptoPlayer : public MultiPlayer<ssl_socket*>
|
||||
{
|
||||
PlainPlayer plaintext_player;
|
||||
boost::asio::ssl::context ctx;
|
||||
ssl_ctx ctx;
|
||||
boost::asio::io_service io_service;
|
||||
|
||||
vector<Sender<ssl_socket*>*> senders;
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
#include "Player.h"
|
||||
#include "ssl_sockets.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
#include "Networking/STS.h"
|
||||
#include "Tools/int.h"
|
||||
#include "Tools/NetworkOptions.h"
|
||||
#include "Networking/Server.h"
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
#include "Networking/STS.h"
|
||||
#include <sodium.h>
|
||||
#include <string>
|
||||
#include <string.h>
|
||||
#include <unistd.h>
|
||||
#include <stdio.h>
|
||||
#include <iomanip>
|
||||
#include <fcntl.h>
|
||||
|
||||
void STS::kdf_block(unsigned char *block)
|
||||
{
|
||||
crypto_hash_sha512_state state;
|
||||
crypto_hash_sha512_init(&state);
|
||||
unsigned char ctrbytes[sizeof kdf_counter];
|
||||
kdf_counter++;
|
||||
|
||||
// Little endian serialization
|
||||
for(size_t i=0; i<sizeof(kdf_counter); i++) {
|
||||
ctrbytes[i] = (unsigned char)((kdf_counter >> i*8) & 0xFF);
|
||||
}
|
||||
crypto_hash_sha512_update(&state,ctrbytes,sizeof ctrbytes);
|
||||
crypto_hash_sha512_update(&state,raw_secret,crypto_hash_sha512_BYTES);
|
||||
crypto_hash_sha512_final(&state, block);
|
||||
}
|
||||
|
||||
vector<unsigned char> STS::unsafe_derive_secret(size_t sz)
|
||||
{
|
||||
// KDF ~ H(cnt || raw_secret)
|
||||
vector<unsigned char> resultSecret(sz + crypto_hash_sha512_BYTES - (sz % crypto_hash_sha512_BYTES));
|
||||
size_t total=0;
|
||||
while(total < sz) {
|
||||
unsigned char *block = &resultSecret[total];
|
||||
kdf_block(block);
|
||||
total += crypto_hash_sha512_BYTES;
|
||||
}
|
||||
return resultSecret;
|
||||
}
|
||||
|
||||
STS::STS()
|
||||
{
|
||||
phase = UNDEFINED;
|
||||
}
|
||||
|
||||
void STS::init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES])
|
||||
{
|
||||
phase = UNKNOWN;
|
||||
memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES);
|
||||
memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES);
|
||||
memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES);
|
||||
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
|
||||
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
|
||||
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
|
||||
kdf_counter = 0;
|
||||
}
|
||||
|
||||
STS::STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES])
|
||||
{
|
||||
phase = UNKNOWN;
|
||||
memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES);
|
||||
memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES);
|
||||
memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES);
|
||||
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
|
||||
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
|
||||
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
|
||||
kdf_counter = 0;
|
||||
}
|
||||
|
||||
STS::~STS()
|
||||
{
|
||||
memset(their_public_sign_key, 0, crypto_sign_PUBLICKEYBYTES);
|
||||
memset(my_private_sign_key, 0, crypto_sign_SECRETKEYBYTES);
|
||||
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
|
||||
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
|
||||
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
|
||||
memset(raw_secret, 0, crypto_hash_sha512_BYTES);
|
||||
kdf_counter = 0;
|
||||
phase = UNKNOWN;
|
||||
}
|
||||
|
||||
sts_msg1_t STS::send_msg1()
|
||||
{
|
||||
sts_msg1_t m;
|
||||
if(UNKNOWN != phase) {
|
||||
throw "STS BAD PHASE";
|
||||
}
|
||||
|
||||
crypto_box_keypair(ephemeral_public_key, ephemeral_private_key);
|
||||
memcpy(m.bytes,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
|
||||
phase = SENT1;
|
||||
return m;
|
||||
}
|
||||
|
||||
// If the incoming signature is valid, compute:
|
||||
// shared secret = H(DH(pubB,privA) || pubA || pubB)
|
||||
// msg = Sign_{privED-A} (pubA || pubB )
|
||||
//
|
||||
sts_msg3_t STS::recv_msg2(sts_msg2_t msg2)
|
||||
{
|
||||
unsigned char *theirPublicKey = msg2.pubkey;
|
||||
unsigned char *theirSig = msg2.sig;
|
||||
unsigned char theirSigDec[crypto_sign_BYTES];
|
||||
unsigned char scalar_result[crypto_scalarmult_SCALARBYTES];
|
||||
const unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
|
||||
int ret;
|
||||
crypto_hash_sha512_state state;
|
||||
sts_msg3_t msg;
|
||||
|
||||
if(SENT1 != phase) {
|
||||
throw "STS BAD PHASE";
|
||||
}
|
||||
ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey);
|
||||
if(0 != ret) {
|
||||
throw "crypto_scalarmult failed";
|
||||
}
|
||||
|
||||
crypto_hash_sha512_init(&state);
|
||||
crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES);
|
||||
crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
|
||||
crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES);
|
||||
crypto_hash_sha512_final(&state,raw_secret);
|
||||
|
||||
vector<unsigned char> keKey = unsafe_derive_secret(crypto_stream_KEYBYTES);
|
||||
vector<unsigned char> expectedMessage;
|
||||
expectedMessage.insert(expectedMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
|
||||
expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
|
||||
|
||||
crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey[0]);
|
||||
|
||||
int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key);
|
||||
|
||||
if(badSig) {
|
||||
throw "Bad signature received in message 2.";
|
||||
} else {
|
||||
unsigned char *mySigEnc = msg.bytes;
|
||||
unsigned char mySig[crypto_sign_BYTES];
|
||||
vector<unsigned char> signMessage;
|
||||
signMessage.insert(signMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
|
||||
signMessage.insert(signMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
|
||||
if(0 != crypto_sign_detached(mySig, NULL, &signMessage[0], signMessage.size(), my_private_sign_key)) {
|
||||
throw "Signing failed.";
|
||||
}
|
||||
vector<unsigned char> keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES);
|
||||
crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey2[0]);
|
||||
|
||||
phase = FINISHED;
|
||||
return msg;
|
||||
}
|
||||
}
|
||||
|
||||
sts_msg2_t STS::recv_msg1(sts_msg1_t msg1)
|
||||
{
|
||||
unsigned char *theirPublicKey = msg1.bytes;
|
||||
unsigned char scalar_result[crypto_scalarmult_SCALARBYTES];
|
||||
crypto_hash_sha512_state state;
|
||||
sts_msg2_t m;
|
||||
int ret;
|
||||
|
||||
if(UNKNOWN != phase) {
|
||||
throw "recv_msg1 called on non-unknown phase";
|
||||
}
|
||||
|
||||
memcpy(their_ephemeral_public_key, theirPublicKey, crypto_box_PUBLICKEYBYTES);
|
||||
|
||||
crypto_box_keypair(ephemeral_public_key, ephemeral_private_key);
|
||||
memcpy(m.pubkey,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
|
||||
ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey);
|
||||
if(0 != ret) {
|
||||
throw "crypto_scalarmult failed when processing message 1";
|
||||
}
|
||||
|
||||
crypto_hash_sha512_init(&state);
|
||||
crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES);
|
||||
crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES);
|
||||
crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
|
||||
crypto_hash_sha512_final(&state,raw_secret);
|
||||
|
||||
vector<unsigned char> livenessProof;
|
||||
livenessProof.insert(livenessProof.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
|
||||
livenessProof.insert(livenessProof.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
|
||||
unsigned char mySig[crypto_sign_BYTES];
|
||||
unsigned char *mySigEnc = m.sig;
|
||||
vector<unsigned char> keKey = unsafe_derive_secret(crypto_stream_KEYBYTES);
|
||||
|
||||
unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
|
||||
if(0 != crypto_sign_detached(mySig, NULL, &livenessProof[0], livenessProof.size(), my_private_sign_key)) {
|
||||
throw "Signing failed.";
|
||||
}
|
||||
crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey[0]);
|
||||
|
||||
phase = SENT2;
|
||||
return m;
|
||||
}
|
||||
|
||||
void STS::recv_msg3(sts_msg3_t msg3)
|
||||
{
|
||||
unsigned char *theirSig=msg3.bytes;
|
||||
unsigned char theirSigDec[crypto_sign_BYTES];
|
||||
vector<unsigned char> expectedMessage;
|
||||
if(SENT2 != phase) {
|
||||
throw "recv_msg3 called out of order";
|
||||
}
|
||||
|
||||
expectedMessage.insert(expectedMessage.end(), their_ephemeral_public_key , their_ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
|
||||
expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
|
||||
unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
|
||||
vector<unsigned char> keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES);
|
||||
|
||||
crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey2[0]);
|
||||
int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key);
|
||||
|
||||
if(badSig) {
|
||||
throw "Bad signature received in message 3.";
|
||||
} else {
|
||||
phase = FINISHED;
|
||||
}
|
||||
}
|
||||
|
||||
vector<unsigned char> STS::derive_secret(size_t sz)
|
||||
{
|
||||
if(phase != FINISHED) {
|
||||
throw "Can not derive secrets till the key exchange has completed.";
|
||||
}
|
||||
return unsafe_derive_secret(sz);
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
#ifndef _NETWORK_STS
|
||||
#define _NETWORK_STS
|
||||
|
||||
/* The Station to Station protocol
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <sodium.h>
|
||||
|
||||
using namespace std;
|
||||
|
||||
typedef enum
|
||||
{ UNKNOWN // Have not started the interaction or have cleared the memory
|
||||
, SENT1 // Sent initial message
|
||||
, SENT2 // Received 1, sent 2
|
||||
, FINISHED // Done (received msg 2 & sent 3 or received msg 3)
|
||||
, UNDEFINED // For arrays/vectors/etc of STS classes that are initialized later.
|
||||
} phase_t;
|
||||
|
||||
struct msg1_st {
|
||||
unsigned char bytes[crypto_box_PUBLICKEYBYTES];
|
||||
};
|
||||
typedef struct msg1_st sts_msg1_t;
|
||||
struct msg2_st {
|
||||
unsigned char pubkey[crypto_box_PUBLICKEYBYTES];
|
||||
unsigned char sig[crypto_sign_BYTES];
|
||||
};
|
||||
typedef struct msg2_st sts_msg2_t;
|
||||
struct msg3_st {
|
||||
unsigned char bytes[crypto_sign_BYTES];
|
||||
};
|
||||
typedef struct msg3_st sts_msg3_t;
|
||||
|
||||
class STS
|
||||
{
|
||||
phase_t phase;
|
||||
unsigned char their_public_sign_key[crypto_sign_PUBLICKEYBYTES];
|
||||
unsigned char my_public_sign_key[crypto_sign_PUBLICKEYBYTES];
|
||||
unsigned char my_private_sign_key[crypto_sign_SECRETKEYBYTES];
|
||||
unsigned char ephemeral_private_key[crypto_box_SECRETKEYBYTES];
|
||||
unsigned char ephemeral_public_key[crypto_box_PUBLICKEYBYTES];
|
||||
unsigned char their_ephemeral_public_key[crypto_box_PUBLICKEYBYTES];
|
||||
unsigned char raw_secret[crypto_hash_sha512_BYTES];
|
||||
uint64_t kdf_counter;
|
||||
public:
|
||||
STS();
|
||||
STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]);
|
||||
~STS();
|
||||
|
||||
void init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
|
||||
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]);
|
||||
|
||||
sts_msg1_t send_msg1();
|
||||
sts_msg3_t recv_msg2(sts_msg2_t msg2);
|
||||
|
||||
sts_msg2_t recv_msg1(sts_msg1_t msg1);
|
||||
void recv_msg3(sts_msg3_t msg3);
|
||||
|
||||
vector<unsigned char> derive_secret(size_t);
|
||||
private:
|
||||
vector<unsigned char> unsafe_derive_secret(size_t);
|
||||
void kdf_block(unsigned char *block);
|
||||
};
|
||||
|
||||
#endif /* _NETWORK_STS */
|
||||
@@ -74,13 +74,61 @@ void ServerSocket::init()
|
||||
pthread_create(&thread, 0, accept_thread, this);
|
||||
}
|
||||
|
||||
class ServerJob
|
||||
{
|
||||
ServerSocket& server;
|
||||
int socket;
|
||||
sockaddr dest;
|
||||
|
||||
public:
|
||||
pthread_t thread;
|
||||
|
||||
ServerJob(ServerSocket& server, int socket, sockaddr dest) :
|
||||
server(server), socket(socket), dest(dest), thread(0)
|
||||
{
|
||||
}
|
||||
|
||||
static void* run(void* job)
|
||||
{
|
||||
auto& server_job = *(ServerJob*)(job);
|
||||
server_job.server.wait_for_client_id(server_job.socket, server_job.dest);
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
ServerSocket::~ServerSocket()
|
||||
{
|
||||
for (auto& job : jobs)
|
||||
{
|
||||
pthread_cancel(job->thread);
|
||||
pthread_join(job->thread, 0);
|
||||
delete job;
|
||||
}
|
||||
|
||||
pthread_cancel(thread);
|
||||
pthread_join(thread, 0);
|
||||
if (close(main_socket)) { error("close(main_socket"); };
|
||||
}
|
||||
|
||||
void ServerSocket::wait_for_client_id(int socket, sockaddr dest)
|
||||
{
|
||||
(void) dest;
|
||||
int client_id;
|
||||
try
|
||||
{
|
||||
receive(socket, (unsigned char*) &client_id, sizeof(client_id));
|
||||
process_connection(socket, client_id);
|
||||
}
|
||||
catch (closed_connection&)
|
||||
{
|
||||
#ifdef DEBUG_NETWORKING
|
||||
auto& conn = *(sockaddr_in*) &dest;
|
||||
fprintf(stderr, "client on %s:%d left without identification\n",
|
||||
inet_ntoa(conn.sin_addr), ntohs(conn.sin_port));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void ServerSocket::accept_clients()
|
||||
{
|
||||
while (true)
|
||||
@@ -92,25 +140,19 @@ void ServerSocket::accept_clients()
|
||||
if (consocket<0) { error("set_up_socket:accept"); }
|
||||
|
||||
int client_id;
|
||||
try
|
||||
{
|
||||
receive(consocket, (unsigned char*)&client_id, sizeof(client_id));
|
||||
}
|
||||
catch (closed_connection&)
|
||||
{
|
||||
if (receive_all_or_nothing(consocket, (unsigned char*)&client_id, sizeof(client_id)))
|
||||
process_connection(consocket, client_id);
|
||||
else
|
||||
{
|
||||
#ifdef DEBUG_NETWORKING
|
||||
auto& conn = *(sockaddr_in*)&dest;
|
||||
cerr << "client on " << inet_ntoa(conn.sin_addr) << ":"
|
||||
<< ntohs(conn.sin_port) << " left without identification"
|
||||
<< endl;
|
||||
auto& conn = *(sockaddr_in*) &dest;
|
||||
fprintf(stderr, "deferring client on %s:%d to thread\n",
|
||||
inet_ntoa(conn.sin_addr), ntohs(conn.sin_port));
|
||||
#endif
|
||||
}
|
||||
|
||||
data_signal.lock();
|
||||
process_client(client_id);
|
||||
clients[client_id] = consocket;
|
||||
data_signal.broadcast();
|
||||
data_signal.unlock();
|
||||
// defer to thread
|
||||
jobs.push_back(new ServerJob(*this, consocket, dest));
|
||||
pthread_create(&jobs.back()->thread, 0, ServerJob::run, jobs.back());
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
int flags = fcntl(consocket, F_GETFL, 0);
|
||||
@@ -121,15 +163,19 @@ void ServerSocket::accept_clients()
|
||||
}
|
||||
}
|
||||
|
||||
int ServerSocket::get_connection_count()
|
||||
void ServerSocket::process_connection(int consocket, int client_id)
|
||||
{
|
||||
data_signal.lock();
|
||||
int connection_count = clients.size();
|
||||
#ifdef DEBUG_NETWORKING
|
||||
cerr << "client " << hex << client_id << " is on socket " << dec << consocket
|
||||
<< endl;
|
||||
#endif
|
||||
process_client(client_id);
|
||||
clients[client_id] = consocket;
|
||||
data_signal.broadcast();
|
||||
data_signal.unlock();
|
||||
return connection_count;
|
||||
}
|
||||
|
||||
|
||||
int ServerSocket::get_connection_socket(int id)
|
||||
{
|
||||
data_signal.lock();
|
||||
@@ -163,16 +209,10 @@ void AnonymousServerSocket::init()
|
||||
pthread_create(&thread, 0, anonymous_accept_thread, this);
|
||||
}
|
||||
|
||||
int AnonymousServerSocket::get_connection_count()
|
||||
{
|
||||
return num_accepted_clients;
|
||||
}
|
||||
|
||||
void AnonymousServerSocket::process_client(int client_id)
|
||||
{
|
||||
if (clients.find(client_id) != clients.end())
|
||||
close_client_socket(clients[client_id]);
|
||||
num_accepted_clients++;
|
||||
client_connection_queue.push(client_id);
|
||||
}
|
||||
|
||||
|
||||
@@ -12,10 +12,13 @@
|
||||
using namespace std;
|
||||
|
||||
#include <pthread.h>
|
||||
#include <netinet/tcp.h>
|
||||
|
||||
#include "Tools/WaitQueue.h"
|
||||
#include "Tools/Signal.h"
|
||||
|
||||
class ServerJob;
|
||||
|
||||
class ServerSocket
|
||||
{
|
||||
protected:
|
||||
@@ -25,9 +28,13 @@ protected:
|
||||
Signal data_signal;
|
||||
pthread_t thread;
|
||||
|
||||
vector<ServerJob*> jobs;
|
||||
|
||||
// disable copying
|
||||
ServerSocket(const ServerSocket& other);
|
||||
|
||||
void process_connection(int socket, int client_id);
|
||||
|
||||
virtual void process_client(int) {}
|
||||
|
||||
public:
|
||||
@@ -38,14 +45,11 @@ public:
|
||||
|
||||
virtual void accept_clients();
|
||||
|
||||
void wait_for_client_id(int socket, sockaddr dest);
|
||||
|
||||
// This depends on clients sending their id as int.
|
||||
// Has to be thread-safe.
|
||||
int get_connection_socket(int number);
|
||||
|
||||
// How many client connections have been made.
|
||||
virtual int get_connection_count();
|
||||
|
||||
void close_socket();
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -55,18 +59,15 @@ class AnonymousServerSocket : public ServerSocket
|
||||
{
|
||||
private:
|
||||
// No. of accepted connections in this instance
|
||||
int num_accepted_clients;
|
||||
queue<int> client_connection_queue;
|
||||
|
||||
void process_client(int client_id);
|
||||
|
||||
public:
|
||||
AnonymousServerSocket(int Portnum) :
|
||||
ServerSocket(Portnum), num_accepted_clients(0) { };
|
||||
ServerSocket(Portnum) { };
|
||||
void init();
|
||||
|
||||
virtual int get_connection_count();
|
||||
|
||||
// Get socket and id for the last client who connected
|
||||
int get_connection_socket(int& client_id);
|
||||
};
|
||||
|
||||
@@ -12,7 +12,65 @@
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/asio/ssl.hpp>
|
||||
|
||||
typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket;
|
||||
typedef boost::asio::io_service ssl_service;
|
||||
|
||||
void check_ssl_file(string filename);
|
||||
void ssl_error(string side, string pronoun, string other, string server);
|
||||
|
||||
class ssl_ctx : public boost::asio::ssl::context
|
||||
{
|
||||
public:
|
||||
ssl_ctx(string me) :
|
||||
boost::asio::ssl::context(boost::asio::ssl::context::tlsv12)
|
||||
{
|
||||
string prefix = PREP_DIR + me;
|
||||
string cert_file = prefix + ".pem";
|
||||
string key_file = prefix + ".key";
|
||||
check_ssl_file(cert_file);
|
||||
check_ssl_file(key_file);
|
||||
|
||||
use_certificate_file(cert_file, pem);
|
||||
use_private_key_file(key_file, pem);
|
||||
add_verify_path(PREP_DIR);
|
||||
}
|
||||
};
|
||||
|
||||
class ssl_socket : public boost::asio::ssl::stream<boost::asio::ip::tcp::socket>
|
||||
{
|
||||
typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> parent;
|
||||
|
||||
public:
|
||||
ssl_socket(boost::asio::io_service& io_service,
|
||||
boost::asio::ssl::context& ctx, int plaintext_socket, string other,
|
||||
string me, bool client) :
|
||||
parent(io_service, ctx)
|
||||
{
|
||||
lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket);
|
||||
set_verify_mode(boost::asio::ssl::verify_peer);
|
||||
set_verify_callback(boost::asio::ssl::rfc2818_verification(other));
|
||||
if (client)
|
||||
try
|
||||
{
|
||||
handshake(ssl_socket::client);
|
||||
} catch (...)
|
||||
{
|
||||
ssl_error("Client", "we", other, other);
|
||||
throw;
|
||||
}
|
||||
else
|
||||
{
|
||||
try
|
||||
{
|
||||
handshake(ssl_socket::server);
|
||||
} catch (...)
|
||||
{
|
||||
ssl_error("Server", "they", other, me);
|
||||
throw;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
inline size_t send_non_blocking(ssl_socket* socket, octet* data, size_t length)
|
||||
{
|
||||
|
||||
@@ -5,41 +5,24 @@
|
||||
#include <thread>
|
||||
|
||||
ExternalClients::ExternalClients(int party_num, const string& prep_data_dir):
|
||||
party_num(party_num), prep_data_dir(prep_data_dir), server_connection_count(-1)
|
||||
party_num(party_num), prep_data_dir(prep_data_dir),
|
||||
ctx(0)
|
||||
{
|
||||
}
|
||||
|
||||
ExternalClients::~ExternalClients()
|
||||
{
|
||||
// close client sockets
|
||||
for (map<int,int>::iterator it = external_client_sockets.begin();
|
||||
for (auto it = external_client_sockets.begin();
|
||||
it != external_client_sockets.end(); it++)
|
||||
{
|
||||
if (close(it->second))
|
||||
{
|
||||
error("failed to close external client connection socket)");
|
||||
}
|
||||
delete it->second;
|
||||
}
|
||||
for (map<int,AnonymousServerSocket*>::iterator it = client_connection_servers.begin();
|
||||
it != client_connection_servers.end(); it++)
|
||||
{
|
||||
delete it->second;
|
||||
}
|
||||
for (map<int,octet*>::iterator it = symmetric_client_keys.begin();
|
||||
it != symmetric_client_keys.end(); it++)
|
||||
{
|
||||
delete[] it->second;
|
||||
}
|
||||
for (map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = symmetric_client_commsec_send_keys.begin();
|
||||
it_cs != symmetric_client_commsec_send_keys.end(); it_cs++)
|
||||
{
|
||||
memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size());
|
||||
}
|
||||
for (map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = symmetric_client_commsec_recv_keys.begin();
|
||||
it_cs != symmetric_client_commsec_recv_keys.end(); it_cs++)
|
||||
{
|
||||
memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size());
|
||||
}
|
||||
}
|
||||
|
||||
void ExternalClients::start_listening(int portnum_base)
|
||||
@@ -62,125 +45,21 @@ int ExternalClients::get_client_connection(int portnum_base)
|
||||
cerr << "Thread " << this_thread::get_id() << " found server." << endl;
|
||||
int client_id, socket;
|
||||
socket = client_connection_servers[portnum_base]->get_connection_socket(client_id);
|
||||
external_client_sockets[client_id] = socket;
|
||||
if (symmetric_client_keys.find(client_id) != symmetric_client_keys.end())
|
||||
delete symmetric_client_keys[client_id];
|
||||
symmetric_client_commsec_send_keys.erase(client_id);
|
||||
symmetric_client_commsec_recv_keys.erase(client_id);
|
||||
if (ctx == 0)
|
||||
ctx = new ssl_ctx("P" + to_string(get_party_num()));
|
||||
external_client_sockets[client_id] = new ssl_socket(io_service, *ctx, socket,
|
||||
"C" + to_string(client_id), "P" + to_string(get_party_num()), false);
|
||||
cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl;
|
||||
return client_id;
|
||||
}
|
||||
|
||||
int ExternalClients::connect_to_server(int portnum_base, int ipv4_address)
|
||||
{
|
||||
struct in_addr addr = { (unsigned int)ipv4_address };
|
||||
int csocket;
|
||||
const char* address_str = inet_ntoa(addr);
|
||||
cerr << "Party " << get_party_num() << " connecting to server at " << address_str << " on port " << portnum_base + get_party_num() << endl;
|
||||
set_up_client_socket(csocket, address_str, portnum_base + get_party_num());
|
||||
cerr << "Party " << get_party_num() << " connected to server at " << address_str << " on port " << portnum_base + get_party_num() << endl;
|
||||
int server_id = server_connection_count;
|
||||
// server identifiers are -1, -2, ... to avoid conflict with client identifiers
|
||||
server_connection_count--;
|
||||
external_client_sockets[server_id] = csocket;
|
||||
return server_id;
|
||||
}
|
||||
|
||||
void ExternalClients::curve25519_ints_to_bytes(unsigned char *bytes, const vector<int>& key_ints)
|
||||
{
|
||||
for(unsigned int j = 0; j < key_ints.size(); j++) {
|
||||
for(unsigned int k = 0; k < 4; k++) {
|
||||
bytes[j*sizeof(int) + k] = (key_ints[j] >> ((3-k)*8)) & 0xFF;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate sesssion key for a newly connected client, store in symmetric_client_keys
|
||||
// public_key is expected to be size 8 and contain integer values of public key bytes.
|
||||
// Assumes load_server_keys has been run.
|
||||
void ExternalClients::generate_session_key_for_client(int client_id, const vector<int>& public_key)
|
||||
{
|
||||
assert(public_key.size() * sizeof(int) == crypto_box_PUBLICKEYBYTES);
|
||||
|
||||
load_server_keys_once();
|
||||
|
||||
unsigned char client_publickey[crypto_box_PUBLICKEYBYTES];
|
||||
|
||||
curve25519_ints_to_bytes(client_publickey, public_key);
|
||||
|
||||
cerr << "Recevied client public key for client " << dec << client_id << " :";
|
||||
for (unsigned int j = 0; j < crypto_box_PUBLICKEYBYTES; j++)
|
||||
cerr << hex << (int) client_publickey[j] << " ";
|
||||
cerr << dec << endl;
|
||||
|
||||
unsigned char scalarmult_q_by_server[crypto_scalarmult_BYTES];
|
||||
crypto_generichash_state h;
|
||||
|
||||
symmetric_client_keys[client_id] = new octet[crypto_generichash_BYTES];
|
||||
|
||||
// Derive a shared key from this server's secret key and the client's public key
|
||||
// shared key = h(q || server_secretkey || client_publickey)
|
||||
if (crypto_scalarmult(scalarmult_q_by_server, server_secretkey, client_publickey) != 0) {
|
||||
cerr << "Scalar mult failed\n";
|
||||
exit(1);
|
||||
}
|
||||
crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES);
|
||||
crypto_generichash_update(&h, scalarmult_q_by_server, sizeof scalarmult_q_by_server);
|
||||
crypto_generichash_update(&h, client_publickey, sizeof client_publickey);
|
||||
crypto_generichash_update(&h, server_publickey, sizeof server_publickey);
|
||||
crypto_generichash_final(&h, symmetric_client_keys[client_id], crypto_generichash_BYTES);
|
||||
}
|
||||
|
||||
// Read pre-computed server keys from client-setup for this SPDZ engine.
|
||||
// Only needs to be done once per run, but is only necessary if an external connection
|
||||
// is being requested.
|
||||
void ExternalClients::load_server_keys_once()
|
||||
{
|
||||
if (server_keys_loaded) {
|
||||
return;
|
||||
}
|
||||
|
||||
ifstream keyfile;
|
||||
stringstream filename;
|
||||
filename << prep_data_dir << "Player-SPDZ-Keys-P" << get_party_num();
|
||||
keyfile.open(filename.str().c_str());
|
||||
if (keyfile.fail())
|
||||
throw file_error(filename.str().c_str());
|
||||
|
||||
keyfile.read((char*)server_publickey, sizeof server_publickey);
|
||||
if (keyfile.eof())
|
||||
throw end_of_file(filename.str(), "server public key" );
|
||||
keyfile.read((char*)server_secretkey, sizeof server_secretkey);
|
||||
if (keyfile.eof())
|
||||
throw end_of_file(filename.str(), "server private key" );
|
||||
|
||||
bool loaded_ed25519 = true;
|
||||
|
||||
keyfile.read((char*)server_publickey_ed25519, sizeof server_publickey_ed25519);
|
||||
if (keyfile.eof() || keyfile.bad())
|
||||
loaded_ed25519 = false;
|
||||
keyfile.read((char*)server_secretkey_ed25519, sizeof server_secretkey_ed25519);
|
||||
if (keyfile.eof() || keyfile.bad())
|
||||
loaded_ed25519 = false;
|
||||
|
||||
keyfile.close();
|
||||
|
||||
ed25519_keys_loaded = loaded_ed25519;
|
||||
server_keys_loaded = true;
|
||||
}
|
||||
|
||||
void ExternalClients::require_ed25519_keys()
|
||||
{
|
||||
if (!ed25519_keys_loaded)
|
||||
throw "Ed25519 keys required but not found in player key files";
|
||||
}
|
||||
|
||||
int ExternalClients::get_party_num()
|
||||
{
|
||||
return party_num;
|
||||
}
|
||||
|
||||
int ExternalClients::get_socket(int id)
|
||||
ssl_socket* ExternalClients::get_socket(int id)
|
||||
{
|
||||
if (external_client_sockets.find(id) == external_client_sockets.end())
|
||||
throw runtime_error("external connection not found for id " + to_string(id));
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define _ExternalClients
|
||||
|
||||
#include "Networking/sockets.h"
|
||||
#include "Networking/ssl_sockets.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
@@ -23,24 +24,15 @@ class ExternalClients
|
||||
|
||||
int party_num;
|
||||
const string prep_data_dir;
|
||||
int server_connection_count;
|
||||
unsigned char server_publickey[crypto_box_PUBLICKEYBYTES];
|
||||
unsigned char server_secretkey[crypto_box_SECRETKEYBYTES];
|
||||
bool server_keys_loaded = false;
|
||||
bool ed25519_keys_loaded = false;
|
||||
|
||||
// Maps holding per client values (indexed by unique 32-bit id)
|
||||
std::map<int,int> external_client_sockets;
|
||||
std::map<int,ssl_socket*> external_client_sockets;
|
||||
|
||||
ssl_service io_service;
|
||||
ssl_ctx* ctx;
|
||||
|
||||
public:
|
||||
|
||||
unsigned char server_publickey_ed25519[crypto_sign_ed25519_PUBLICKEYBYTES];
|
||||
unsigned char server_secretkey_ed25519[crypto_sign_ed25519_SECRETKEYBYTES];
|
||||
|
||||
std::map<int,octet*> symmetric_client_keys;
|
||||
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_send_keys;
|
||||
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_recv_keys;
|
||||
|
||||
ExternalClients(int party_num, const string& prep_data_dir);
|
||||
~ExternalClients();
|
||||
|
||||
@@ -48,18 +40,10 @@ class ExternalClients
|
||||
|
||||
int get_client_connection(int portnum_base);
|
||||
|
||||
int connect_to_server(int portnum_base, int ipv4_address);
|
||||
|
||||
// return the socket for a given client or server identifier
|
||||
int get_socket(int socket_id);
|
||||
|
||||
void curve25519_ints_to_bytes(unsigned char bytes[crypto_box_PUBLICKEYBYTES], const vector<int>& key_ints);
|
||||
void generate_session_key_for_client(int client_id, const vector<int>& public_key);
|
||||
|
||||
void load_server_keys_once();
|
||||
ssl_socket* get_socket(int socket_id);
|
||||
|
||||
int get_party_num();
|
||||
void require_ed25519_keys();
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,17 +5,18 @@
|
||||
|
||||
#include "FixInput.h"
|
||||
|
||||
const char* FixInput::NAME = "real number";
|
||||
|
||||
void FixInput::read(std::istream& in, const int* params)
|
||||
template<>
|
||||
void FixInput_<Integer>::read(std::istream& in, const int* params)
|
||||
{
|
||||
#ifdef LOW_PREC_INPUT
|
||||
double x;
|
||||
in >> x;
|
||||
items[0] = x * (1 << *params);
|
||||
#else
|
||||
}
|
||||
|
||||
template<>
|
||||
void FixInput_<bigint>::read(std::istream& in, const int* params)
|
||||
{
|
||||
mpf_class x;
|
||||
in >> x;
|
||||
items[0] = x << *params;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -11,7 +11,8 @@
|
||||
#include "Math/bigint.h"
|
||||
#include "Math/Integer.h"
|
||||
|
||||
class FixInput
|
||||
template<class T>
|
||||
class FixInput_
|
||||
{
|
||||
public:
|
||||
const static int N_DEST = 1;
|
||||
@@ -20,13 +21,18 @@ public:
|
||||
|
||||
const static int TYPE = 1;
|
||||
|
||||
#ifdef LOW_PREC_INPUT
|
||||
Integer items[N_DEST];
|
||||
#else
|
||||
bigint items[N_DEST];
|
||||
#endif
|
||||
T items[N_DEST];
|
||||
|
||||
void read(std::istream& in, const int* params);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
const char* FixInput_<T>::NAME = "real number";
|
||||
|
||||
#ifdef LOW_PREC_INPUT
|
||||
typedef FixInput_<Integer> FixInput;
|
||||
#else
|
||||
typedef FixInput_<bigint> FixInput;
|
||||
#endif
|
||||
|
||||
#endif /* PROCESSOR_FIXINPUT_H_ */
|
||||
|
||||
@@ -57,6 +57,8 @@ public:
|
||||
virtual T finalize_mine() = 0;
|
||||
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
|
||||
T finalize(int player, int n_bits = -1);
|
||||
|
||||
void raw_input(SubProcessor<T>& proc, const vector<int>& args);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
@@ -88,9 +90,6 @@ public:
|
||||
|
||||
T finalize_mine();
|
||||
void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);
|
||||
|
||||
void start(int player, int n_inputs);
|
||||
void stop(int player, const vector<int>& targets);
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_INPUT_H_ */
|
||||
|
||||
@@ -100,7 +100,7 @@ template<class T>
|
||||
void Input<T>::add_other(int player)
|
||||
{
|
||||
open_type t;
|
||||
shares[player].push_back({});
|
||||
shares.at(player).push_back({});
|
||||
prep.get_input(shares[player].back(), t, player);
|
||||
}
|
||||
|
||||
@@ -131,12 +131,16 @@ void InputBase<T>::exchange()
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Input<T>::start(int player, int n_inputs)
|
||||
void InputBase<T>::raw_input(SubProcessor<T>& proc, const vector<int>& args)
|
||||
{
|
||||
reset(player);
|
||||
if (player == P.my_num())
|
||||
auto& P = proc.P;
|
||||
reset_all(P);
|
||||
|
||||
for (auto it = args.begin(); it != args.end();)
|
||||
{
|
||||
for (int i = 0; i < n_inputs; i++)
|
||||
int player = *it++;
|
||||
it++;
|
||||
if (player == P.my_num())
|
||||
{
|
||||
clear t;
|
||||
try
|
||||
@@ -149,33 +153,21 @@ void Input<T>::start(int player, int n_inputs)
|
||||
}
|
||||
add_mine(t);
|
||||
}
|
||||
send_mine();
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < n_inputs; i++)
|
||||
add_other(player);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Input<T>::stop(int player, const vector<int>& targets)
|
||||
{
|
||||
assert(proc != 0);
|
||||
if (P.my_num() == player)
|
||||
for (unsigned int i = 0; i < targets.size(); i++)
|
||||
proc->get_S_ref(targets[i]) = finalize_mine();
|
||||
else
|
||||
{
|
||||
octetStream o;
|
||||
this->timer.start();
|
||||
P.receive_player(player, o, true);
|
||||
this->timer.stop();
|
||||
for (unsigned int i = 0; i < targets.size(); i++)
|
||||
else
|
||||
{
|
||||
finalize_other(player, proc->get_S_ref(targets[i]), o);
|
||||
add_other(player);
|
||||
}
|
||||
}
|
||||
|
||||
timer.start();
|
||||
exchange();
|
||||
timer.stop();
|
||||
|
||||
for (auto it = args.begin(); it != args.end();)
|
||||
{
|
||||
int player = *it++;
|
||||
proc.get_S_ref(*it++) = finalize(player);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -115,6 +115,7 @@ enum
|
||||
INPUTFLOAT = 0xF1,
|
||||
INPUTMIXED = 0xF2,
|
||||
INPUTMIXEDREG = 0xF3,
|
||||
RAWINPUT = 0xF4,
|
||||
STARTINPUT = 0x61,
|
||||
STOPINPUT = 0x62,
|
||||
READSOCKETC = 0x63,
|
||||
@@ -248,6 +249,7 @@ enum
|
||||
GSTOPINPUT = 0x162,
|
||||
GREADSOCKETS = 0x164,
|
||||
GWRITESOCKETS = 0x166,
|
||||
GRAWINPUT = 0x1F4,
|
||||
// Bitwise logic
|
||||
GANDC = 0x170,
|
||||
GXORC = 0x171,
|
||||
@@ -328,6 +330,8 @@ public:
|
||||
int get_opcode() const { return opcode; }
|
||||
int get_size() const { return size; }
|
||||
|
||||
// Reads a single instruction from the istream
|
||||
void parse(istream& s, int inst_pos);
|
||||
void parse_operands(istream& s, int pos, int file_pos);
|
||||
|
||||
bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); }
|
||||
@@ -347,9 +351,6 @@ class DataPositions;
|
||||
class Instruction : public BaseInstruction
|
||||
{
|
||||
public:
|
||||
// Reads a single instruction from the istream
|
||||
void parse(istream& s, int inst_pos);
|
||||
|
||||
// Return whether usage is known
|
||||
bool get_offline_data_usage(DataPositions& usage);
|
||||
|
||||
@@ -361,6 +362,5 @@ public:
|
||||
void execute(Processor<sint, sgf2n>& Proc) const;
|
||||
};
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
#include "Tools/callgrind.h"
|
||||
|
||||
inline
|
||||
void Instruction::parse(istream& s, int inst_pos)
|
||||
void BaseInstruction::parse(istream& s, int inst_pos)
|
||||
{
|
||||
n=0; start.resize(0);
|
||||
r[0]=0; r[1]=0; r[2]=0; r[3]=0;
|
||||
@@ -224,7 +224,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case STARTPRIVATEOUTPUT:
|
||||
case GSTARTPRIVATEOUTPUT:
|
||||
case DIGESTC:
|
||||
case CONNECTIPV4: // write socket handle, read IPv4 address, portnum
|
||||
r[0]=get_int(s);
|
||||
r[1]=get_int(s);
|
||||
n = get_int(s);
|
||||
@@ -254,8 +253,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case PRINTREGB:
|
||||
case GPRINTREG:
|
||||
case LDINT:
|
||||
case STARTINPUT:
|
||||
case GSTARTINPUT:
|
||||
case STOPPRIVATEOUTPUT:
|
||||
case GSTOPPRIVATEOUTPUT:
|
||||
case INPUTMASK:
|
||||
@@ -310,6 +307,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case INPUTFLOAT:
|
||||
case INPUTMIXED:
|
||||
case INPUTMIXEDREG:
|
||||
case RAWINPUT:
|
||||
case GRAWINPUT:
|
||||
case TRUNC_PR:
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
@@ -336,7 +335,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
case READSOCKETC:
|
||||
case READSOCKETS:
|
||||
case READSOCKETINT:
|
||||
case READCLIENTPUBLICKEY:
|
||||
num_var_args = get_int(s) - 1;
|
||||
r[0] = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
@@ -352,20 +350,18 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
r[1] = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case CONNECTIPV4:
|
||||
throw runtime_error("parties as clients not supported any more");
|
||||
case READCLIENTPUBLICKEY:
|
||||
case INITSECURESOCKET:
|
||||
case RESPSECURESOCKET:
|
||||
num_var_args = get_int(s) - 1;
|
||||
r[0] = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
throw runtime_error("VM-controlled encryption not supported any more");
|
||||
// raw input
|
||||
case STARTINPUT:
|
||||
case GSTARTINPUT:
|
||||
case STOPINPUT:
|
||||
case GSTOPINPUT:
|
||||
// subtract player number argument
|
||||
num_var_args = get_int(s) - 1;
|
||||
n = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
throw runtime_error("two-stage input not supported any more");
|
||||
case GBITDEC:
|
||||
case GBITCOM:
|
||||
num_var_args = get_int(s) - 2;
|
||||
@@ -621,6 +617,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
case INPUTB:
|
||||
skip = 4;
|
||||
offset = 3;
|
||||
size_offset = -2;
|
||||
break;
|
||||
case ANDM:
|
||||
size = DIV_CEIL(n, 64);
|
||||
@@ -733,6 +730,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
typedef typename sint::bit_type T;
|
||||
auto& processor = Proc.Procb;
|
||||
auto& instruction = *this;
|
||||
auto& Ci = Proc.get_Ci();
|
||||
|
||||
// optimize some instructions
|
||||
switch (opcode)
|
||||
@@ -1292,17 +1290,11 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case INPUTMIXEDREG:
|
||||
sint::Input::input_mixed(Proc.Procp, start, size, true);
|
||||
return;
|
||||
case STARTINPUT:
|
||||
Proc.Procp.input.start(r[0],n);
|
||||
case RAWINPUT:
|
||||
Proc.Procp.input.raw_input(Proc.Procp, start);
|
||||
break;
|
||||
case GSTARTINPUT:
|
||||
Proc.Proc2.input.start(r[0],n);
|
||||
break;
|
||||
case STOPINPUT:
|
||||
Proc.Procp.input.stop(n,start);
|
||||
break;
|
||||
case GSTOPINPUT:
|
||||
Proc.Proc2.input.stop(n,start);
|
||||
case GRAWINPUT:
|
||||
Proc.Proc2.input.raw_input(Proc.Proc2, start);
|
||||
break;
|
||||
case ANDC:
|
||||
Proc.get_Cp_ref(r[0]).AND(Proc.read_Cp(r[1]),Proc.read_Cp(r[2]));
|
||||
@@ -1670,26 +1662,16 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
ss << "No connection on port " << r[0] << endl;
|
||||
throw Processor_Error(ss.str());
|
||||
}
|
||||
if (Proc.P.my_num() == 0)
|
||||
{
|
||||
octetStream os;
|
||||
os.store(int(sint::open_type::type_char()));
|
||||
sint::open_type::specification(os);
|
||||
os.Send(Proc.external_clients.get_socket(client_handle));
|
||||
}
|
||||
Proc.write_Ci(r[0], client_handle);
|
||||
break;
|
||||
}
|
||||
case CONNECTIPV4:
|
||||
{
|
||||
// connect to server at port n + my_num()
|
||||
int ipv4 = Proc.read_Ci(r[1]);
|
||||
int server_handle = Proc.external_clients.connect_to_server(n, ipv4);
|
||||
Proc.write_Ci(r[0], server_handle);
|
||||
break;
|
||||
}
|
||||
case READCLIENTPUBLICKEY:
|
||||
Proc.read_client_public_key(Proc.read_Ci(r[0]), start);
|
||||
break;
|
||||
case INITSECURESOCKET:
|
||||
Proc.init_secure_socket(Proc.read_Ci(r[i]), start);
|
||||
break;
|
||||
case RESPSECURESOCKET:
|
||||
Proc.resp_secure_socket(Proc.read_Ci(r[i]), start);
|
||||
break;
|
||||
case READSOCKETINT:
|
||||
Proc.read_socket_ints(Proc.read_Ci(r[0]), start);
|
||||
break;
|
||||
|
||||
@@ -35,7 +35,6 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
|
||||
// Set up the fields
|
||||
prep_dir_prefix = get_prep_dir(N.num_players(), opts.lgp, lg2);
|
||||
char filename[2048];
|
||||
bool read_mac_keys = false;
|
||||
|
||||
sgf2n::clear::init_field(lg2);
|
||||
@@ -96,16 +95,7 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
sint::clear::next::template init<typename sint::clear>(false);
|
||||
|
||||
// Initialize the global memory
|
||||
if (memtype.compare("new")==0)
|
||||
{sprintf(filename, PREP_DIR "Player-Memory-P%d", my_number);
|
||||
ifstream memfile(filename);
|
||||
if (memfile.fail()) { throw file_error(filename); }
|
||||
M2.Load_Memory(memfile);
|
||||
Mp.Load_Memory(memfile);
|
||||
Mi.Load_Memory(memfile);
|
||||
memfile.close();
|
||||
}
|
||||
else if (memtype.compare("old")==0)
|
||||
if (memtype.compare("old")==0)
|
||||
{
|
||||
inpf.open(memory_filename(), ios::in | ios::binary);
|
||||
if (inpf.fail()) { throw file_error(memory_filename()); }
|
||||
|
||||
@@ -77,21 +77,6 @@ class Memory
|
||||
|
||||
friend ostream& operator<< <>(ostream& s,const Memory<T>& M);
|
||||
friend istream& operator>> <>(istream& s,Memory<T>& M);
|
||||
|
||||
/* This function loads a un-shared global memory from disk and
|
||||
* produces the memory
|
||||
*
|
||||
* The global unshared memory is of the form
|
||||
* sz <- Size
|
||||
* n val <- Clear values
|
||||
* n val <- Clear values
|
||||
* -1 -1 <- End of clear values
|
||||
* n val <- Shared values
|
||||
* n val <- Shared values
|
||||
* -1 -1
|
||||
*/
|
||||
void Load_Memory(ifstream& inpf);
|
||||
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -106,40 +106,3 @@ istream& operator>>(istream& s,Memory<T>& M)
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void Memory<T>::Load_Memory(ifstream& inpf)
|
||||
{
|
||||
Memory<T>& M = *this;
|
||||
|
||||
int a;
|
||||
typename T::clear val;
|
||||
T S;
|
||||
|
||||
inpf >> a;
|
||||
M.resize_s(a);
|
||||
inpf >> a;
|
||||
M.resize_c(a);
|
||||
|
||||
cerr << "Reading Clear Memory" << endl;
|
||||
|
||||
// Read clear memory
|
||||
inpf >> a;
|
||||
val.input(inpf,true);
|
||||
while (a!=-1)
|
||||
{ M.write_C(a,val);
|
||||
inpf >> a;
|
||||
val.input(inpf,true);
|
||||
}
|
||||
cerr << "Reading Shared Memory" << endl;
|
||||
|
||||
// Read shared memory
|
||||
inpf >> a;
|
||||
S.input(inpf,true);
|
||||
while (a!=-1)
|
||||
{ M.write_S(a,S);
|
||||
inpf >> a;
|
||||
S.input(inpf,true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,6 +112,10 @@ public:
|
||||
|
||||
OnlineOptions opts;
|
||||
|
||||
ArithmeticProcessor() :
|
||||
ArithmeticProcessor(OnlineOptions::singleton, BaseMachine::thread_num)
|
||||
{
|
||||
}
|
||||
ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num),
|
||||
sent(0), rounds(0), opts(opts) {}
|
||||
|
||||
@@ -217,12 +221,6 @@ class Processor : public ArithmeticProcessor
|
||||
|
||||
// Access to external client sockets for reading clear/shared data
|
||||
void read_socket_ints(int client_id, const vector<int>& registers);
|
||||
// Setup client public key
|
||||
void read_client_public_key(int client_id, const vector<int>& registers);
|
||||
void init_secure_socket(int client_id, const vector<int>& registers);
|
||||
void init_secure_socket_internal(int client_id, const vector<int>& registers);
|
||||
void resp_secure_socket(int client_id, const vector<int>& registers);
|
||||
void resp_secure_socket_internal(int client_id, const vector<int>& registers);
|
||||
|
||||
void write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs,
|
||||
int socket_id, int message_type, const vector<int>& registers);
|
||||
@@ -239,8 +237,6 @@ class Processor : public ArithmeticProcessor
|
||||
friend ostream& operator<<(ostream& s,const Processor<T, U>& P);
|
||||
|
||||
private:
|
||||
void maybe_decrypt_sequence(int client_id);
|
||||
void maybe_encrypt_sequence(int client_id);
|
||||
|
||||
template<class T> friend class SPDZ;
|
||||
template<class T> friend class SubProcessor;
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
#include "Processor/Processor.h"
|
||||
#include "Processor/Program.h"
|
||||
#include "Networking/STS.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
#include "GC/square64.h"
|
||||
|
||||
@@ -68,7 +67,7 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
|
||||
Procb(machine.bit_memories),
|
||||
Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P),
|
||||
privateOutput2(Proc2),privateOutputp(Procp),
|
||||
external_clients(ExternalClients(P.my_num(), machine.prep_dir_prefix)),
|
||||
external_clients(P.my_num(), machine.prep_dir_prefix),
|
||||
binary_file_io(Binary_File_IO())
|
||||
{
|
||||
reset(program,0);
|
||||
@@ -222,7 +221,6 @@ void Processor<sint, sgf2n>::split(const Instruction& instruction)
|
||||
// RegType and SecrecyType determines how registers are read and the socket stream is packed.
|
||||
// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to
|
||||
// determine the data structure being sent in a message.
|
||||
// Encryption is enabled if key material (for DH Auth Encryption and/or STS protocol) has been already setup.
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs,
|
||||
int socket_id, int message_type, const vector<int>& registers)
|
||||
@@ -239,7 +237,11 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
|
||||
{
|
||||
if (reg_type == MODP && secrecy_type == SECRET) {
|
||||
// Send vector of secret shares and optionally macs
|
||||
get_Sp_ref(registers[i]).pack(socket_stream, send_macs);
|
||||
if (send_macs)
|
||||
get_Sp_ref(registers[i]).pack(socket_stream);
|
||||
else
|
||||
get_Sp_ref(registers[i]).pack(socket_stream,
|
||||
sint::get_rec_factor(P.my_num(), P.num_players()));
|
||||
}
|
||||
else if (reg_type == MODP && secrecy_type == CLEAR) {
|
||||
// Send vector of clear public field elements
|
||||
@@ -257,15 +259,7 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
|
||||
}
|
||||
}
|
||||
|
||||
// Apply DH Auth encryption if session keys have been created.
|
||||
map<int,octet*>::iterator it = external_clients.symmetric_client_keys.find(socket_id);
|
||||
if (it != external_clients.symmetric_client_keys.end()) {
|
||||
socket_stream.encrypt(it->second);
|
||||
}
|
||||
|
||||
// Apply STS commsec encryption if session keys have been created.
|
||||
try {
|
||||
maybe_encrypt_sequence(socket_id);
|
||||
socket_stream.Send(external_clients.get_socket(socket_id));
|
||||
}
|
||||
catch (bad_value& e) {
|
||||
@@ -282,7 +276,6 @@ void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>&
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.Receive(external_clients.get_socket(client_id));
|
||||
maybe_decrypt_sequence(client_id);
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
int val;
|
||||
@@ -298,7 +291,6 @@ void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.Receive(external_clients.get_socket(client_id));
|
||||
maybe_decrypt_sequence(client_id);
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
get_Cp_ref(registers[i]).unpack(socket_stream);
|
||||
@@ -312,146 +304,13 @@ void Processor<sint, sgf2n>::read_socket_private(int client_id, const vector<int
|
||||
int m = registers.size();
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.Receive(external_clients.get_socket(client_id));
|
||||
maybe_decrypt_sequence(client_id);
|
||||
|
||||
map<int,octet*>::iterator it = external_clients.symmetric_client_keys.find(client_id);
|
||||
if (it != external_clients.symmetric_client_keys.end())
|
||||
{
|
||||
socket_stream.decrypt(it->second);
|
||||
}
|
||||
for (int i = 0; i < m; i++)
|
||||
{
|
||||
get_Sp_ref(registers[i]).unpack(socket_stream, read_macs);
|
||||
}
|
||||
}
|
||||
|
||||
// Read socket for client public key as 8 ints, calculate session key for client.
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::read_client_public_key(int client_id, const vector<int>& registers) {
|
||||
|
||||
read_socket_ints(client_id, registers);
|
||||
|
||||
// After read into registers, need to extract values
|
||||
vector<int> client_public_key (registers.size(), 0);
|
||||
for(unsigned int i = 0; i < registers.size(); i++) {
|
||||
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
|
||||
}
|
||||
|
||||
external_clients.generate_session_key_for_client(client_id, client_public_key);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::init_secure_socket_internal(int client_id, const vector<int>& registers) {
|
||||
external_clients.symmetric_client_commsec_send_keys.erase(client_id);
|
||||
external_clients.symmetric_client_commsec_recv_keys.erase(client_id);
|
||||
unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES];
|
||||
sts_msg1_t m1;
|
||||
sts_msg2_t m2;
|
||||
sts_msg3_t m3;
|
||||
|
||||
external_clients.load_server_keys_once();
|
||||
external_clients.require_ed25519_keys();
|
||||
|
||||
// Validate inputs and state
|
||||
if(registers.size() != 8) {
|
||||
throw "Invalid call to init_secure_socket.";
|
||||
}
|
||||
|
||||
// Extract client long term public key into bytes
|
||||
vector<int> client_public_key (registers.size(), 0);
|
||||
for(unsigned int i = 0; i < registers.size(); i++) {
|
||||
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
|
||||
}
|
||||
external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key);
|
||||
|
||||
// Start Station to Station Protocol
|
||||
STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519);
|
||||
m1 = ke.send_msg1();
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.append(m1.bytes, sizeof m1.bytes);
|
||||
socket_stream.Send(external_clients.get_socket(client_id));
|
||||
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
|
||||
96);
|
||||
socket_stream.consume(m2.pubkey, sizeof m2.pubkey);
|
||||
socket_stream.consume(m2.sig, sizeof m2.sig);
|
||||
m3 = ke.recv_msg2(m2);
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.append(m3.bytes, sizeof m3.bytes);
|
||||
socket_stream.Send(external_clients.get_socket(client_id));
|
||||
|
||||
// Use results of STS to generate send and receive keys.
|
||||
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0);
|
||||
external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0);
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::init_secure_socket(int client_id, const vector<int>& registers) {
|
||||
|
||||
try {
|
||||
init_secure_socket_internal(client_id, registers);
|
||||
} catch (char const *e) {
|
||||
cerr << "STS initiator role failed with: " << e << endl;
|
||||
throw Processor_Error("STS initiator failed");
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::resp_secure_socket(int client_id, const vector<int>& registers) {
|
||||
try {
|
||||
resp_secure_socket_internal(client_id, registers);
|
||||
} catch (char const *e) {
|
||||
cerr << "STS responder role failed with: " << e << endl;
|
||||
throw Processor_Error("STS responder failed");
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::resp_secure_socket_internal(int client_id, const vector<int>& registers) {
|
||||
external_clients.symmetric_client_commsec_send_keys.erase(client_id);
|
||||
external_clients.symmetric_client_commsec_recv_keys.erase(client_id);
|
||||
unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES];
|
||||
sts_msg1_t m1;
|
||||
sts_msg2_t m2;
|
||||
sts_msg3_t m3;
|
||||
|
||||
external_clients.load_server_keys_once();
|
||||
external_clients.require_ed25519_keys();
|
||||
|
||||
// Validate inputs and state
|
||||
if(registers.size() != 8) {
|
||||
throw "Invalid call to init_secure_socket.";
|
||||
}
|
||||
vector<int> client_public_key (registers.size(), 0);
|
||||
for(unsigned int i = 0; i < registers.size(); i++) {
|
||||
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
|
||||
}
|
||||
external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key);
|
||||
|
||||
// Start Station to Station Protocol for the responder
|
||||
STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519);
|
||||
socket_stream.reset_read_head();
|
||||
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
|
||||
32);
|
||||
socket_stream.consume(m1.bytes, sizeof m1.bytes);
|
||||
m2 = ke.recv_msg1(m1);
|
||||
socket_stream.reset_write_head();
|
||||
socket_stream.append(m2.pubkey, sizeof m2.pubkey);
|
||||
socket_stream.append(m2.sig, sizeof m2.sig);
|
||||
socket_stream.Send(external_clients.get_socket(client_id));
|
||||
|
||||
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
|
||||
64);
|
||||
socket_stream.consume(m3.bytes, sizeof m3.bytes);
|
||||
ke.recv_msg3(m3);
|
||||
|
||||
// Use results of STS to generate send and receive keys.
|
||||
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
|
||||
external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0);
|
||||
external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0);
|
||||
}
|
||||
|
||||
// Read share data from a file starting at file_pos until registers filled.
|
||||
// file_pos_register is written with new file position (-1 is eof).
|
||||
@@ -722,26 +581,4 @@ ostream& operator<<(ostream& s,const Processor<sint, sgf2n>& P)
|
||||
return s;
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::maybe_decrypt_sequence(int client_id)
|
||||
{
|
||||
map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_recv_keys.find(client_id);
|
||||
if (it_cs != external_clients.symmetric_client_commsec_recv_keys.end())
|
||||
{
|
||||
socket_stream.decrypt_sequence(&it_cs->second.first[0], it_cs->second.second);
|
||||
it_cs->second.second++;
|
||||
}
|
||||
}
|
||||
|
||||
template<class sint, class sgf2n>
|
||||
void Processor<sint, sgf2n>::maybe_encrypt_sequence(int client_id)
|
||||
{
|
||||
map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_send_keys.find(client_id);
|
||||
if (it_cs != external_clients.symmetric_client_commsec_send_keys.end())
|
||||
{
|
||||
socket_stream.encrypt_sequence(&it_cs->second.first[0], it_cs->second.second);
|
||||
it_cs->second.second++;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
1
Programs/Circuits
Submodule
1
Programs/Circuits
Submodule
Submodule Programs/Circuits added at 82dfda9d12
8
Programs/Source/aes_circuit.mpc
Normal file
8
Programs/Source/aes_circuit.mpc
Normal file
@@ -0,0 +1,8 @@
|
||||
from circuit import Circuit
|
||||
sb128 = sbits.get_type(128)
|
||||
key = sb128(0x2b7e151628aed2a6abf7158809cf4f3c)
|
||||
plaintext = sb128(0x6bc1bee22e409f96e93d7e117393172a)
|
||||
n = 1000
|
||||
aes128 = Circuit('aes_128')
|
||||
ciphertexts = aes128(sbitvec([key] * n), sbitvec([plaintext] * n))
|
||||
ciphertexts.elements()[n - 1].reveal().print_reg()
|
||||
@@ -4,8 +4,6 @@
|
||||
to deduce the maximum value from a range of integer input.
|
||||
|
||||
Demonstrate clients external to computing parties supplying input and receiving an authenticated result. See bankers-bonus-client.cpp for client (and setup instructions).
|
||||
|
||||
For an implementation with communications security see bankers_bonus_commsec.mpc.
|
||||
|
||||
Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent
|
||||
before calculating the maximum.
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
# coding=latin1
|
||||
|
||||
"""
|
||||
Solve Bankers bonus, aka Millionaires problem.
|
||||
to deduce the maximum value from a range of integer input.
|
||||
|
||||
Demonstrate clients external to computing parties supplying input and receiving
|
||||
an authenticated result. See bankers-bonus-commsec-client.cpp for client (and setup instructions).
|
||||
|
||||
For an implementation without communications security see bankers_bonus.mpc.
|
||||
|
||||
Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent
|
||||
before calculating the maximum.
|
||||
|
||||
Note each client connects in a single thread and so is potentially blocked.
|
||||
|
||||
Each round / game will reset and so this runs indefinitiely.
|
||||
"""
|
||||
|
||||
from Compiler.types import sint, regint, Array, Matrix, MemValue
|
||||
from Compiler.instructions import listen, acceptclientconnection
|
||||
from Compiler.library import print_ln, do_while, if_e, else_, for_range
|
||||
from Compiler.util import if_else
|
||||
|
||||
PORTNUM = 14000
|
||||
MAX_NUM_CLIENTS = 8
|
||||
n_rounds = 0
|
||||
|
||||
if len(program.args) > 1:
|
||||
n_rounds = int(program.args[1])
|
||||
|
||||
def accept_client():
|
||||
client_socket_id = regint()
|
||||
acceptclientconnection(client_socket_id, PORTNUM)
|
||||
last = regint.read_from_socket(client_socket_id)
|
||||
|
||||
# Crypto setup
|
||||
public_signing_key = regint.read_from_socket(client_socket_id, 8)
|
||||
public_key = regint.read_client_public_key(client_socket_id)
|
||||
regint.resp_secure_socket(client_socket_id,*public_signing_key)
|
||||
|
||||
return client_socket_id, last
|
||||
|
||||
def client_input(client_socket_id):
|
||||
"""
|
||||
Send share of random value, receive input and deduce share.
|
||||
"""
|
||||
|
||||
client_inputs = sint.receive_from_client(1, client_socket_id)
|
||||
|
||||
return client_inputs[0]
|
||||
|
||||
|
||||
def determine_winner(number_clients, client_values, client_ids):
|
||||
"""Work out and return client_id which corresponds to max client_value"""
|
||||
max_value = Array(1, sint)
|
||||
max_value[0] = client_values[0]
|
||||
win_client_id = Array(1, sint)
|
||||
win_client_id[0] = client_ids[0]
|
||||
|
||||
@for_range(number_clients-1)
|
||||
def loop_body(i):
|
||||
# Is this client input a new maximum, will be sint(1) if true, else sint(0)
|
||||
is_new_max = max_value[0] < client_values[i+1]
|
||||
# Keep latest max_value
|
||||
max_value[0] = if_else(is_new_max, client_values[i+1], max_value[0])
|
||||
# Keep current winning client id
|
||||
win_client_id[0] = if_else(is_new_max, client_ids[i+1], win_client_id[0])
|
||||
|
||||
return win_client_id[0]
|
||||
|
||||
|
||||
def write_winner_to_clients(sockets, number_clients, winning_client_id):
|
||||
"""Send share of winning client id to all clients who joined game."""
|
||||
|
||||
# Setup authenticate result using share of random.
|
||||
# client can validate ∑ winning_client_id * ∑ rnd_from_triple = ∑ auth_result
|
||||
rnd_from_triple = sint.get_random_triple()[0]
|
||||
auth_result = winning_client_id * rnd_from_triple
|
||||
|
||||
@for_range(number_clients)
|
||||
def loop_body(i):
|
||||
sint.write_shares_to_socket(sockets[i], [winning_client_id, rnd_from_triple, auth_result])
|
||||
|
||||
|
||||
def main():
|
||||
"""Listen in while loop for players to join a game.
|
||||
Once maxiumum reached or have notified that round finished, run comparison and return result."""
|
||||
# Start listening for client socket connections
|
||||
listen(PORTNUM)
|
||||
print_ln('Listening for client connections on base port %s', PORTNUM)
|
||||
|
||||
def game_loop(_=None):
|
||||
print_ln('Starting a new round of the game.')
|
||||
|
||||
# Clients socket id (integer).
|
||||
client_sockets = Array(MAX_NUM_CLIENTS, regint)
|
||||
# Number of clients
|
||||
number_clients = MemValue(regint(0))
|
||||
# Clients secret input.
|
||||
client_values = Array(MAX_NUM_CLIENTS, sint)
|
||||
# Client ids to identity client
|
||||
client_ids = Array(MAX_NUM_CLIENTS, sint)
|
||||
# Keep track of received inputs
|
||||
seen = Array(MAX_NUM_CLIENTS, regint)
|
||||
seen.assign_all(0)
|
||||
|
||||
# Loop round waiting for each client to connect
|
||||
@do_while
|
||||
def client_connections():
|
||||
client_id, last = accept_client()
|
||||
@if_(client_id >= MAX_NUM_CLIENTS)
|
||||
def _():
|
||||
print_ln('client id too high')
|
||||
crash()
|
||||
client_sockets[client_id] = client_id
|
||||
client_ids[client_id] = client_id
|
||||
seen[client_id] = 1
|
||||
@if_(last == 1)
|
||||
def _():
|
||||
number_clients.write(client_id + 1)
|
||||
|
||||
return (sum(seen) < number_clients) + (number_clients == 0)
|
||||
|
||||
@for_range(number_clients)
|
||||
def _(client_id):
|
||||
client_values[client_id] = client_input(client_id)
|
||||
|
||||
winning_client_id = determine_winner(number_clients, client_values, client_ids)
|
||||
|
||||
print_ln('Found winner, index: %s.', winning_client_id.reveal())
|
||||
|
||||
write_winner_to_clients(client_sockets, number_clients, winning_client_id)
|
||||
|
||||
return True
|
||||
|
||||
if n_rounds > 0:
|
||||
print('run %d rounds' % n_rounds)
|
||||
for_range(n_rounds)(game_loop)
|
||||
else:
|
||||
print('run forever')
|
||||
do_while(game_loop)
|
||||
|
||||
main()
|
||||
@@ -1,5 +1,6 @@
|
||||
import ml
|
||||
import random
|
||||
import re
|
||||
|
||||
program.use_trunc_pr = True
|
||||
sfix.round_nearest = True
|
||||
@@ -10,6 +11,13 @@ cfix.set_precision(16, 31)
|
||||
N = int(program.args[1])
|
||||
n_features = int(program.args[2])
|
||||
|
||||
n_threads = None
|
||||
|
||||
for arg in program.args:
|
||||
m = re.match('n_threads=(.*)', arg)
|
||||
if m:
|
||||
n_threads = int(m.group(1))
|
||||
|
||||
program.allocated_mem['s'] = 1 + n_features
|
||||
|
||||
b = sfix.load_mem(0)
|
||||
@@ -24,13 +32,15 @@ dense.W.assign_vector(W)
|
||||
print_ln('b=%s W[-1]=%s', dense.b[0].reveal(),
|
||||
dense.W[n_features - 1][0][0].reveal())
|
||||
|
||||
@for_range_opt(n_features)
|
||||
@for_range_opt_multithread(n_threads, n_features)
|
||||
def _(i):
|
||||
@for_range_opt(N)
|
||||
def _(j):
|
||||
dense.X[j][0][i] = sfix.get_input_from(0)
|
||||
|
||||
dense.forward()
|
||||
batch = regint.Array(N)
|
||||
batch.assign(regint.inc(N))
|
||||
dense.forward(batch)
|
||||
|
||||
print_str('predictions: ')
|
||||
|
||||
|
||||
@@ -1,28 +1,52 @@
|
||||
import ml
|
||||
import random
|
||||
import re
|
||||
|
||||
program.use_trunc_pr = True
|
||||
sfix.round_nearest = True
|
||||
|
||||
sfix.set_precision(16, 31)
|
||||
cfix.set_precision(16, 31)
|
||||
sfloat.vlen = sfix.f
|
||||
|
||||
n_epochs = 200
|
||||
n_epochs = 100
|
||||
|
||||
n_normal = int(program.args[1])
|
||||
n_pos = int(program.args[2])
|
||||
n_features = int(program.args[3])
|
||||
|
||||
if 'approx' in program.args:
|
||||
approx = 3
|
||||
elif 'approx5' in program.args:
|
||||
approx = 5
|
||||
else:
|
||||
approx = False
|
||||
|
||||
if 'split' in program.args:
|
||||
program.use_split(3)
|
||||
|
||||
n_threads = None
|
||||
|
||||
for arg in program.args:
|
||||
m = re.match('n_threads=(.*)', arg)
|
||||
if m:
|
||||
n_threads = int(m.group(1))
|
||||
|
||||
debug = 'debug' in program.args
|
||||
|
||||
ml.set_n_threads(n_threads)
|
||||
|
||||
n_examples = n_normal + n_pos
|
||||
N = max(n_normal, n_pos) * 2
|
||||
|
||||
if 'mini' in program.args:
|
||||
batch_size = 32
|
||||
else:
|
||||
batch_size = N
|
||||
|
||||
X_normal = sfix.Matrix(n_normal, n_features)
|
||||
X_pos = sfix.Matrix(n_pos, n_features)
|
||||
|
||||
@for_range_opt(n_features)
|
||||
@for_range_opt_multithread(n_threads, n_features)
|
||||
def _(i):
|
||||
@for_range_opt(n_normal)
|
||||
def _(j):
|
||||
@@ -32,11 +56,11 @@ def _(i):
|
||||
X_pos[j][i] = sfix.get_input_from(0)
|
||||
|
||||
dense = ml.Dense(N, n_features, 1)
|
||||
layers = [dense, ml.Output(N)]
|
||||
layers = [dense, ml.Output(N, approx=approx)]
|
||||
|
||||
sgd = ml.SGD(layers, n_epochs, report_loss=debug)
|
||||
sgd.reset([X_normal, X_pos])
|
||||
sgd.run()
|
||||
sgd.run(batch_size)
|
||||
|
||||
if debug:
|
||||
@for_range(N)
|
||||
|
||||
@@ -32,8 +32,6 @@ sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False)
|
||||
sgd.reset([X_normal, X_pos])
|
||||
sgd.run(batch_size=batch)
|
||||
|
||||
ml.approx_sigmoid.special = False
|
||||
|
||||
# @for_range(1000)
|
||||
# def _(i):
|
||||
# sgd.backward()
|
||||
|
||||
@@ -82,7 +82,14 @@ if 'quant' in program.args:
|
||||
else:
|
||||
dense = ml.Dense(N, n_features, 1)
|
||||
|
||||
layers = [dense, ml.Output(N, debug=debug, approx='approx' in program.args)]
|
||||
if 'approx' in program.args:
|
||||
approx = 3
|
||||
elif 'approx5' in program.args:
|
||||
approx = 5
|
||||
else:
|
||||
approx = False
|
||||
|
||||
layers = [dense, ml.Output(N, debug=debug, approx=approx)]
|
||||
|
||||
Y = sfix.Array(n_examples)
|
||||
X = sfix.Matrix(n_examples, n_features)
|
||||
|
||||
@@ -97,8 +97,3 @@ test(c[0], 0)
|
||||
test(c[1], 1)
|
||||
test(c[2], 1)
|
||||
test(c[3], 0)
|
||||
|
||||
k = 41
|
||||
a = int(2.9142 * 2**k)
|
||||
alpha = sbitint.get_type(2 * k)(a)
|
||||
test(sbits.bit_compose((alpha >> 64).bit_decompose()[:64]), 0)
|
||||
|
||||
@@ -127,12 +127,14 @@ public:
|
||||
|
||||
void pack(octetStream& os, bool full = true) const
|
||||
{
|
||||
(void)full;
|
||||
FixedVec<T, 2>::pack(os);
|
||||
if (full)
|
||||
FixedVec<T, 2>::pack(os);
|
||||
else
|
||||
(*this)[0].pack(os);
|
||||
}
|
||||
void unpack(octetStream& os, bool full = true)
|
||||
{
|
||||
(void)full;
|
||||
assert(full);
|
||||
FixedVec<T, 2>::unpack(os);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -14,18 +14,14 @@ template <class T>
|
||||
class PrepLessInput : public InputBase<T>
|
||||
{
|
||||
protected:
|
||||
SubProcessor<T>* processor;
|
||||
vector<T> shares;
|
||||
size_t i_share;
|
||||
|
||||
public:
|
||||
PrepLessInput(SubProcessor<T>* proc) :
|
||||
InputBase<T>(proc ? proc->Proc : 0), processor(proc), i_share(0) {}
|
||||
InputBase<T>(proc ? proc->Proc : 0), i_share(0) {}
|
||||
virtual ~PrepLessInput() {}
|
||||
|
||||
void start(int player, int n_inputs);
|
||||
void stop(int player, vector<int> targets);
|
||||
|
||||
virtual void reset(int player) = 0;
|
||||
virtual void add_mine(const typename T::open_type& input,
|
||||
int n_bits = -1) = 0;
|
||||
|
||||
@@ -57,47 +57,6 @@ void ReplicatedInput<T>::exchange()
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void PrepLessInput<T>::start(int player, int n_inputs)
|
||||
{
|
||||
assert(processor != 0);
|
||||
auto& proc = *processor;
|
||||
reset(player);
|
||||
|
||||
if (player == proc.P.my_num())
|
||||
{
|
||||
for (int i = 0; i < n_inputs; i++)
|
||||
{
|
||||
typename T::clear t;
|
||||
this->buffer.input(t);
|
||||
add_mine(t);
|
||||
}
|
||||
|
||||
send_mine();
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void PrepLessInput<T>::stop(int player, vector<int> targets)
|
||||
{
|
||||
assert(processor != 0);
|
||||
auto& proc = *processor;
|
||||
if (proc.P.my_num() == player)
|
||||
{
|
||||
for (unsigned int i = 0; i < targets.size(); i++)
|
||||
proc.get_S_ref(targets[i]) = finalize_mine();
|
||||
}
|
||||
else
|
||||
{
|
||||
octetStream o;
|
||||
this->timer.start();
|
||||
proc.P.receive_player(player, o, true);
|
||||
this->timer.stop();
|
||||
for (unsigned int i = 0; i < targets.size(); i++)
|
||||
finalize_other(player, proc.get_S_ref(targets[i]), o);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline void ReplicatedInput<T>::finalize_other(int player, T& target,
|
||||
octetStream& o, int n_bits)
|
||||
|
||||
@@ -57,6 +57,11 @@ public:
|
||||
return ShamirMachine::s().threshold;
|
||||
}
|
||||
|
||||
static T get_rec_factor(int i, int n)
|
||||
{
|
||||
return Protocol::get_rec_factor(i, n);
|
||||
}
|
||||
|
||||
static ShamirShare constant(T value, int my_num, const T& alphai = {})
|
||||
{
|
||||
return ShamirShare(value, my_num, alphai);
|
||||
@@ -135,14 +140,17 @@ public:
|
||||
throw runtime_error("never call this");
|
||||
}
|
||||
|
||||
void pack(octetStream& os, bool full = true) const
|
||||
void pack(octetStream& os, const T& rec_factor) const
|
||||
{
|
||||
(*this * rec_factor).pack(os);
|
||||
}
|
||||
void pack(octetStream& os) const
|
||||
{
|
||||
(void)full;
|
||||
T::pack(os);
|
||||
}
|
||||
void unpack(octetStream& os, bool full = true)
|
||||
{
|
||||
(void)full;
|
||||
assert(full);
|
||||
T::unpack(os);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -24,6 +24,8 @@ public:
|
||||
template<class T, class U>
|
||||
static void split(vector<U>, vector<int>, int, T*, int, Player&)
|
||||
{ throw runtime_error("split not implemented"); }
|
||||
|
||||
static bool get_rec_factor(int, int) { return false; }
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_SHAREINTERFACE_H_ */
|
||||
|
||||
@@ -25,7 +25,9 @@ public:
|
||||
}
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_squares();
|
||||
void buffer_inverses();
|
||||
void buffer_bits();
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_SOHOPREP_H_ */
|
||||
|
||||
@@ -70,9 +70,58 @@ void SohoPrep<T>::buffer_triples()
|
||||
ci.element(i)}});
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SohoPrep<T>::buffer_squares()
|
||||
{
|
||||
|
||||
auto& proc = this->proc;
|
||||
assert(proc != 0);
|
||||
lock.lock();
|
||||
if (not setup)
|
||||
{
|
||||
PlainPlayer P(proc->P.N, T::clear::type_char());
|
||||
basic_setup(P);
|
||||
}
|
||||
lock.unlock();
|
||||
|
||||
Plaintext_<FD> ai(setup->FieldD);
|
||||
SeededPRNG G;
|
||||
ai.randomize(G);
|
||||
Ciphertext Ca = setup->pk.encrypt(ai);
|
||||
octetStream os;
|
||||
Ca.pack(os);
|
||||
|
||||
for (int i = 1; i < proc->P.num_players(); i++)
|
||||
{
|
||||
proc->P.pass_around(os);
|
||||
Ca.add<0>(os);
|
||||
}
|
||||
|
||||
Ciphertext Cc = Ca.mul(setup->pk, Ca);
|
||||
Plaintext_<FD> ci(setup->FieldD);
|
||||
SimpleDistDecrypt<FD> dd(proc->P, *setup);
|
||||
EncCommitBase_<FD> EC;
|
||||
dd.reshare(ci, Cc, EC);
|
||||
|
||||
for (unsigned i = 0; i < ai.num_slots(); i++)
|
||||
this->squares.push_back({{ai.element(i), ci.element(i)}});
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SohoPrep<T>::buffer_inverses()
|
||||
{
|
||||
assert(this->proc != 0);
|
||||
::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P);
|
||||
}
|
||||
|
||||
template<>
|
||||
void SohoPrep<SohoShare<gfp>>::buffer_bits()
|
||||
{
|
||||
buffer_bits_from_squares(*this);
|
||||
}
|
||||
|
||||
template<>
|
||||
void SohoPrep<SohoShare<gf2n_short>>::buffer_bits()
|
||||
{
|
||||
buffer_bits_without_check();
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include "MascotPrep.h"
|
||||
#include "RingOnlyPrep.h"
|
||||
#include "Spdz2kShare.h"
|
||||
#include "GC/TinySecret.h"
|
||||
|
||||
template<class T, class U>
|
||||
void bits_from_square_in_ring(vector<T>& bits, int buffer_size, U* bit_prep);
|
||||
|
||||
25
README.md
25
README.md
@@ -147,7 +147,7 @@ compute the preprocessing time for a particular computation.
|
||||
required. This includes mainstream processors released 2014 or later.
|
||||
For older models you need to deactivate the respective
|
||||
extensions in the `ARCH` variable.
|
||||
- To benchmark online-only protocols or Overdrive, add the following line at the top: `MY_CFLAGS = -DINSECURE`
|
||||
- To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE`
|
||||
- `PREP_DIR` should point to should be a local, unversioned directory to store preprocessing data (default is `Player-Data` in the current directory).
|
||||
- For homomorphic encryption, set `USE_NTL = 1`.
|
||||
|
||||
@@ -240,6 +240,29 @@ al.](https://eprint.iacr.org/2020/338) You can activate them by using
|
||||
`-Y` instead of `-X`. Note that this also activates classic daBits
|
||||
when useful.
|
||||
|
||||
#### Bristol Fashion circuits
|
||||
|
||||
Bristol Fashion is the name of a description format of binary circuits
|
||||
used by
|
||||
[SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA). You can
|
||||
access such circuits from the high-level language if they are present
|
||||
in `Programs/Circuits`. To run the AES-128 circuit provided with
|
||||
SCALE-MAMBA, you can run the following:
|
||||
```
|
||||
make Programs/Circuits
|
||||
./compile.py aes_circuit
|
||||
Scripts/semi.sh aes_circuit
|
||||
```
|
||||
This downloads the circuit, compiles it to MP-SPDZ bytecode, and runs
|
||||
it as semi-honest two-party computation 1000 times in parallel. It
|
||||
should then output the AES test vector
|
||||
`0x3ad77bb40d7a3660a89ecaf32466ef97`. You can run it with any other
|
||||
protocol as well.
|
||||
|
||||
See the
|
||||
[documentation](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.circuit)
|
||||
for further examples.
|
||||
|
||||
#### Compiling and running programs from external directories
|
||||
|
||||
Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example:
|
||||
|
||||
13
Scripts/setup-clients.sh
Executable file
13
Scripts/setup-clients.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
n=$1
|
||||
|
||||
test -e Player-Data || mkdir Player-Data
|
||||
|
||||
echo Setting up SSL for $n parties
|
||||
|
||||
for i in `seq 0 $[n-1]`; do
|
||||
openssl req -newkey rsa -nodes -x509 -out Player-Data/C$i.pem -keyout Player-Data/C$i.key -subj "/CN=C$i"
|
||||
done
|
||||
|
||||
c_rehash Player-Data
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
make -j4 ecdsa Fake-ECDSA.x
|
||||
make -j4 ecdsa Fake-ECDSA.x secure.x
|
||||
|
||||
run()
|
||||
{
|
||||
@@ -19,8 +19,11 @@ for i in rep mal-rep shamir mal-shamir; do
|
||||
run $i 2
|
||||
done
|
||||
|
||||
./Fake-ECDSA.x
|
||||
|
||||
for i in semi mascot fake-spdz; do
|
||||
for i in semi mascot; do
|
||||
run $i 1
|
||||
done
|
||||
|
||||
if ! ./secure.x; then
|
||||
./Fake-ECDSA.x
|
||||
run fake-spdz 1
|
||||
fi
|
||||
|
||||
107
Tools/Config.cpp
107
Tools/Config.cpp
@@ -1,107 +0,0 @@
|
||||
// Client key file format:
|
||||
// X25519 Public Key
|
||||
// X25519 Secret Key
|
||||
// Ed25519 Public Key
|
||||
// Ed25519 Secret Key
|
||||
// Server 1 X25519 Public Key
|
||||
// Server 1 Ed25519 Public Key
|
||||
// ...
|
||||
// Server N Public Key
|
||||
// Server N Ed25519 Public Key
|
||||
//
|
||||
// Player key file format:
|
||||
// X25519 Public Key
|
||||
// X25519 Secret Key
|
||||
// Ed25519 Public Key
|
||||
// Ed25519 Secret Key
|
||||
// Number of clients [64 bit little endian]
|
||||
// Client 1 X25519 Public Key
|
||||
// Client 1 Ed25519 Public Key
|
||||
// ...
|
||||
// Client N X25519 Public Key
|
||||
// Client N Ed25519 Public Key
|
||||
// Number of servers [64 bit little endian]
|
||||
// Server 1 X25519 Public Key
|
||||
// Server 1 Ed25519 Public Key
|
||||
// ...
|
||||
// Server N X25519 Public Key
|
||||
// Server N Ed25519 Public Key
|
||||
#include "Tools/octetStream.h"
|
||||
#include "Networking/Player.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "Config.h"
|
||||
#include <sodium.h>
|
||||
#include <vector>
|
||||
#include <iomanip>
|
||||
|
||||
namespace Config {
|
||||
static void output(const vector<octet> &vec, ofstream &of)
|
||||
{
|
||||
copy(vec.begin(), vec.end(), ostreambuf_iterator<char>(of));
|
||||
}
|
||||
|
||||
void putW64le(ofstream &outf, uint64_t nr)
|
||||
{
|
||||
char buf[8];
|
||||
for(int i=0;i<8;i++) {
|
||||
char byte = (uint8_t)(nr >> (i*8));
|
||||
buf[i] = (char)byte;
|
||||
}
|
||||
outf.write(buf,sizeof buf);
|
||||
}
|
||||
|
||||
void write_player_config_file(string config_dir
|
||||
,int player_number, public_key my_pub, secret_key my_priv
|
||||
, public_signing_key my_signing_pub, secret_signing_key my_signing_priv
|
||||
, vector<public_key> client_pubs, vector<public_signing_key> client_signing_pubs
|
||||
, vector<public_key> player_pubs, vector<public_signing_key> player_signing_pubs)
|
||||
{
|
||||
stringstream filename;
|
||||
filename << config_dir << "Player-SPDZ-Keys-P" << player_number;
|
||||
ofstream outf(filename.str().c_str(), ios::out | ios::binary);
|
||||
if (outf.fail())
|
||||
throw file_error(filename.str().c_str());
|
||||
if(crypto_box_PUBLICKEYBYTES != my_pub.size() ||
|
||||
crypto_box_SECRETKEYBYTES != my_priv.size() ||
|
||||
crypto_sign_PUBLICKEYBYTES != my_signing_pub.size() ||
|
||||
crypto_sign_SECRETKEYBYTES != my_signing_priv.size()) {
|
||||
throw "Invalid key sizes";
|
||||
} else if(client_pubs.size() != client_signing_pubs.size()) {
|
||||
throw "Incorrect number of client keys";
|
||||
} else if(player_pubs.size() != player_signing_pubs.size()) {
|
||||
throw "Incorrect number of player keys";
|
||||
} else {
|
||||
for(size_t i = 0; i < client_pubs.size(); i++) {
|
||||
if(crypto_box_PUBLICKEYBYTES != client_pubs[i].size() ||
|
||||
crypto_sign_PUBLICKEYBYTES != client_signing_pubs[i].size()) {
|
||||
throw "Incorrect size of client key.";
|
||||
}
|
||||
}
|
||||
for(size_t i = 0; i < player_pubs.size(); i++) {
|
||||
if(crypto_box_PUBLICKEYBYTES != player_pubs[i].size() ||
|
||||
crypto_sign_PUBLICKEYBYTES != player_signing_pubs[i].size()) {
|
||||
throw "Incorrect size of player key.";
|
||||
}
|
||||
}
|
||||
}
|
||||
// Write public and secret X25519 keys
|
||||
output(my_pub, outf);
|
||||
output(my_priv, outf);
|
||||
output(my_signing_pub, outf);
|
||||
output(my_signing_priv, outf);
|
||||
|
||||
putW64le(outf, (uint64_t)client_pubs.size());
|
||||
// Write all client public keys
|
||||
for (size_t j = 0; j < client_pubs.size(); j++) {
|
||||
output(client_pubs[j], outf);
|
||||
output(client_signing_pubs[j], outf);
|
||||
}
|
||||
putW64le(outf, (uint64_t)player_pubs.size());
|
||||
for (size_t j = 0; j < player_pubs.size(); j++) {
|
||||
output(player_pubs[j], outf);
|
||||
output(player_signing_pubs[j], outf);
|
||||
}
|
||||
outf.flush();
|
||||
outf.close();
|
||||
}
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
#include "Tools/octetStream.h"
|
||||
#include "Networking/Player.h"
|
||||
#include <sodium.h>
|
||||
namespace Config {
|
||||
typedef vector<octet> public_key;
|
||||
typedef vector<octet> public_signing_key;
|
||||
typedef vector<octet> secret_key;
|
||||
typedef vector<octet> secret_signing_key;
|
||||
void write_player_config_file(string config_dir
|
||||
,int player_number, public_key my_pub, secret_key my_priv
|
||||
, public_signing_key my_signing_pub, secret_signing_key my_signing_priv
|
||||
, vector<public_key> client_pubs, vector<public_signing_key> client_signing_pubs
|
||||
, vector<public_key> player_pubs, vector<public_signing_key> player_signing_pubs);
|
||||
void putW64le(ofstream &outf, uint64_t nr);
|
||||
}
|
||||
@@ -105,9 +105,7 @@ bigint octetStream::check_sum(int req_bytes) const
|
||||
bool octetStream::equals(const octetStream& a) const
|
||||
{
|
||||
if (len!=a.len) { return false; }
|
||||
for (size_t i=0; i<len; i++)
|
||||
{ if (data[i]!=a.data[i]) { return false; } }
|
||||
return true;
|
||||
return memcmp(data, a.data, len) == 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -267,106 +265,6 @@ void octetStream::exchange(T send_socket, T receive_socket, octetStream& receive
|
||||
}
|
||||
|
||||
|
||||
void octetStream::store(const vector<int>& v)
|
||||
{
|
||||
store(v.size());
|
||||
for (int x : v)
|
||||
store(x);
|
||||
}
|
||||
|
||||
|
||||
void octetStream::get(vector<int>& v)
|
||||
{
|
||||
size_t size;
|
||||
get(size);
|
||||
v.resize(size);
|
||||
for (int& x : v)
|
||||
get(x);
|
||||
}
|
||||
|
||||
|
||||
// Construct the ciphertext as `crypto_secretbox(pt, counter||random)`
|
||||
void octetStream::encrypt_sequence(const octet* key, uint64_t counter)
|
||||
{
|
||||
octet nonce[crypto_secretbox_NONCEBYTES];
|
||||
int i;
|
||||
int message_len_bytes = len;
|
||||
randombytes_buf(nonce, sizeof nonce);
|
||||
if(counter == UINT64_MAX) {
|
||||
throw Processor_Error("Encryption would overflow counter. Too many messages.");
|
||||
} else {
|
||||
counter++;
|
||||
}
|
||||
for(i=0; i<8; i++) {
|
||||
nonce[i] = uint8_t ((counter >> (8*i)) & 0xFF);
|
||||
}
|
||||
int ciphertext_len = message_len_bytes + crypto_secretbox_MACBYTES;
|
||||
octet ciphertext[ciphertext_len];
|
||||
|
||||
crypto_secretbox_easy(ciphertext, data, message_len_bytes, nonce, key);
|
||||
// append the ciphertext to an empty octet stream
|
||||
reset_read_head();
|
||||
reset_write_head();
|
||||
append(ciphertext, ciphertext_len*sizeof(octet));
|
||||
// append the nonce
|
||||
append(nonce, crypto_secretbox_NONCEBYTES * sizeof(octet));
|
||||
}
|
||||
|
||||
void octetStream::decrypt_sequence(const octet* key, uint64_t counter)
|
||||
{
|
||||
int ciphertext_len = len - crypto_box_NONCEBYTES;
|
||||
const octet *nonce = data + ciphertext_len;
|
||||
int i;
|
||||
uint64_t recvCounter=0;
|
||||
// Numbers are typically 24U + 16U so cast to int is safe.
|
||||
if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES))
|
||||
{
|
||||
throw Processor_Error("Cannot decrypt octetStream: ciphertext too short");
|
||||
}
|
||||
for(i=7; i>=0; i--) {
|
||||
recvCounter |= ((uint64_t) *(nonce + i)) << (i*8);
|
||||
}
|
||||
if(recvCounter != counter + 1) {
|
||||
throw Processor_Error("Incorrect counter on stream. Possible MITM.");
|
||||
}
|
||||
if (crypto_secretbox_open_easy(data, data, ciphertext_len, nonce, key) != 0)
|
||||
{
|
||||
throw Processor_Error("octetStream decryption failed!");
|
||||
}
|
||||
rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES);
|
||||
//prepare for unpack after decryption by resetting the read head
|
||||
reset_read_head();
|
||||
}
|
||||
|
||||
void octetStream::encrypt(const octet* key)
|
||||
{
|
||||
octet nonce[crypto_secretbox_NONCEBYTES];
|
||||
randombytes_buf(nonce, sizeof nonce);
|
||||
int message_len_bytes = len;
|
||||
resize(len + crypto_secretbox_MACBYTES + crypto_secretbox_NONCEBYTES);
|
||||
|
||||
// Encrypt data in-place
|
||||
crypto_secretbox_easy(data, data, message_len_bytes, nonce, key);
|
||||
// Adjust length to account for MAC, then append nonce
|
||||
len += crypto_secretbox_MACBYTES;
|
||||
append(nonce, sizeof nonce);
|
||||
}
|
||||
|
||||
void octetStream::decrypt(const octet* key)
|
||||
{
|
||||
int ciphertext_len = len - crypto_box_NONCEBYTES;
|
||||
// Numbers are typically 24U + 16U so cast to int is safe.
|
||||
if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES))
|
||||
{
|
||||
throw Processor_Error("Cannot decrypt octetStream: ciphertext too short");
|
||||
}
|
||||
if (crypto_secretbox_open_easy(data, data, ciphertext_len, data + ciphertext_len, key) != 0)
|
||||
{
|
||||
throw Processor_Error("octetStream decryption failed!");
|
||||
}
|
||||
rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES);
|
||||
}
|
||||
|
||||
void octetStream::input(istream& s)
|
||||
{
|
||||
size_t size;
|
||||
|
||||
@@ -127,6 +127,8 @@ class octetStream
|
||||
|
||||
template<class T>
|
||||
T get();
|
||||
template<class T>
|
||||
void get(T& ans);
|
||||
|
||||
// works for all statically allocated types
|
||||
template <class T>
|
||||
@@ -134,8 +136,10 @@ class octetStream
|
||||
template <class T>
|
||||
void unserialize(T& x) { consume((octet*)&x, sizeof(x)); }
|
||||
|
||||
void store(const vector<int>& v);
|
||||
void get(vector<int>& v);
|
||||
template <class T>
|
||||
void store(const vector<T>& v);
|
||||
template <class T>
|
||||
void get(vector<T>& v);
|
||||
|
||||
void consume(octetStream& s,size_t l)
|
||||
{ s.resize(l);
|
||||
@@ -147,20 +151,8 @@ class octetStream
|
||||
void Send(T socket_num) const;
|
||||
template<class T>
|
||||
void Receive(T socket_num);
|
||||
void ReceiveExpected(int socket_num, size_t expected);
|
||||
|
||||
// In-place authenticated encryption using sodium; key of length crypto_generichash_BYTES
|
||||
// ciphertext = Enc(message) | MAC | counter
|
||||
//
|
||||
// This is much like 'encrypt' but uses a deterministic counter for the nonce,
|
||||
// allowing enforcement of message order.
|
||||
void encrypt_sequence(const octet* key, uint64_t counter);
|
||||
void decrypt_sequence(const octet* key, uint64_t counter);
|
||||
|
||||
// In-place authenticated encryption using sodium; key of length crypto_secretbox_KEYBYTES
|
||||
// ciphertext = Enc(message) | MAC | nonce
|
||||
void encrypt(const octet* key);
|
||||
void decrypt(const octet* key);
|
||||
template<class T>
|
||||
void ReceiveExpected(T socket_num, size_t expected);
|
||||
|
||||
void input(istream& s);
|
||||
void output(ostream& s);
|
||||
@@ -278,7 +270,8 @@ inline void octetStream::Receive(T socket_num)
|
||||
reset_read_head();
|
||||
}
|
||||
|
||||
inline void octetStream::ReceiveExpected(int socket_num, size_t expected)
|
||||
template<class T>
|
||||
inline void octetStream::ReceiveExpected(T socket_num, size_t expected)
|
||||
{
|
||||
size_t nlen = 0;
|
||||
receive(socket_num, nlen, LENGTH_SIZE);
|
||||
@@ -310,11 +303,35 @@ T octetStream::get()
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void octetStream::get(T& res)
|
||||
{
|
||||
res.unpack(*this);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline int octetStream::get()
|
||||
{
|
||||
return get_int(sizeof(int));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void octetStream::store(const vector<T>& v)
|
||||
{
|
||||
store(v.size());
|
||||
for (auto& x : v)
|
||||
store(x);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void octetStream::get(vector<T>& v)
|
||||
{
|
||||
size_t size;
|
||||
get(size);
|
||||
v.resize(size);
|
||||
for (auto& x : v)
|
||||
get(x);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "Math/bigint.h"
|
||||
#include "Math/fixint.h"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Math/gfp.h"
|
||||
#include "Tools/Subroutines.h"
|
||||
#include <stdio.h>
|
||||
#include <sodium.h>
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
// Preprocessing stage to:
|
||||
// Create the public/private key pairs for each client
|
||||
// Create the public/private key pairs for each spdz engine
|
||||
// For each client store the client keys + all spdz engine public keys
|
||||
// in a file named Client-Keys-C<client id>
|
||||
// For each spdz engine store the spdz engine keys + all client public keys
|
||||
// in a file named Player-SPDZ-Keys-P<player id>
|
||||
//
|
||||
|
||||
#include <sodium.h>
|
||||
|
||||
#include "Math/gf2n.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Protocols/Share.h"
|
||||
#include "Math/Setup.h"
|
||||
#include "Protocols/fake-stuff.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
#include "Math/Setup.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Tools/mkpath.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
#include "Tools/Config.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
using namespace std;
|
||||
|
||||
static void output(const vector<octet> &vec, ofstream &of)
|
||||
{
|
||||
copy(vec.begin(), vec.end(), ostreambuf_iterator<char>(of));
|
||||
}
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
|
||||
opt.syntax = "./client-setup.x <nplayers> [OPTIONS]\n";
|
||||
|
||||
opt.add(
|
||||
"0", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Number of external clients (default: nplayers)", // Help description.
|
||||
"-nc", // Flag token.
|
||||
"--numclients" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"128", // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Bit length of GF(p) field (default: 128)", // Help description.
|
||||
"-lgp", // Flag token.
|
||||
"--lgp" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
to_string(gf2n::default_degree()).c_str(), // Default.
|
||||
0, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
("Bit length of GF(2^n) field (default: " + to_string(gf2n::default_degree()) + ")").c_str(), // Help description.
|
||||
"-lg2", // Flag token.
|
||||
"--lg2" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
|
||||
string prep_data_prefix;
|
||||
|
||||
string usage;
|
||||
|
||||
int nplayers;
|
||||
if (opt.firstArgs.size() == 2)
|
||||
{
|
||||
nplayers = atoi(opt.firstArgs[1]->c_str());
|
||||
}
|
||||
else if (opt.lastArgs.size() == 1)
|
||||
{
|
||||
nplayers = atoi(opt.lastArgs[0]->c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
cerr << "ERROR: invalid number of arguments\n";
|
||||
opt.getUsage(usage);
|
||||
cout << usage;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int lg2, lgp, nclients;
|
||||
opt.get("--numclients")->getInt(nclients);
|
||||
if (nclients <= 0)
|
||||
nclients = nplayers;
|
||||
opt.get("--lgp")->getInt(lgp);
|
||||
opt.get("--lg2")->getInt(lg2);
|
||||
|
||||
cout << "nplayers = " << nplayers << endl;
|
||||
cout << "nclients = " << nclients << endl;
|
||||
cout << "lgp = " << lgp << endl;
|
||||
cout << "lgp2 = " << lg2 << endl;
|
||||
|
||||
prep_data_prefix = get_prep_dir(nplayers, lgp, lg2);
|
||||
cout << "prep dir = " << prep_data_prefix << endl;
|
||||
|
||||
vector<Config::public_key> client_publickeys;
|
||||
vector<Config::secret_key> client_secretkeys;
|
||||
client_publickeys.resize(nclients);
|
||||
client_secretkeys.resize(nclients);
|
||||
for (int i = 0; i < nclients; i++) {
|
||||
client_secretkeys[i].resize(crypto_box_SECRETKEYBYTES);
|
||||
client_publickeys[i].resize(crypto_box_PUBLICKEYBYTES);
|
||||
randombytes_buf(&client_secretkeys[i][0], client_secretkeys[i].size());
|
||||
crypto_scalarmult_base(&client_publickeys[i][0], &client_secretkeys[i][0]);
|
||||
}
|
||||
|
||||
vector<Config::public_signing_key> client_signing_publickeys;
|
||||
vector<Config::secret_signing_key> client_signing_secretkeys;
|
||||
client_signing_publickeys.resize(nclients);
|
||||
client_signing_secretkeys.resize(nclients);
|
||||
for (int i = 0; i < nclients; i++) {
|
||||
client_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES);
|
||||
client_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES);
|
||||
crypto_sign_keypair(&client_signing_publickeys[i][0], &client_signing_secretkeys[i][0]);
|
||||
}
|
||||
|
||||
vector<Config::public_key> server_publickeys;
|
||||
vector<Config::secret_key> server_secretkeys;
|
||||
server_publickeys.resize(nplayers);
|
||||
server_secretkeys.resize(nplayers);
|
||||
for (int i = 0; i < nplayers; i++) {
|
||||
server_publickeys[i].resize(crypto_box_PUBLICKEYBYTES);
|
||||
server_secretkeys[i].resize(crypto_box_SECRETKEYBYTES);
|
||||
randombytes_buf(&server_secretkeys[i][0], server_secretkeys[i].size());
|
||||
crypto_scalarmult_base(&server_publickeys[i][0], &server_secretkeys[i][0]);
|
||||
}
|
||||
vector<Config::public_signing_key> server_signing_publickeys;
|
||||
vector<Config::secret_signing_key> server_signing_secretkeys;
|
||||
server_signing_publickeys.resize(nplayers);
|
||||
server_signing_secretkeys.resize(nplayers);
|
||||
for (int i = 0; i < nplayers; i++) {
|
||||
server_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES);
|
||||
server_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES);
|
||||
crypto_sign_keypair(&server_signing_publickeys[i][0], &server_signing_secretkeys[i][0]);
|
||||
}
|
||||
|
||||
/* Write client files */
|
||||
for (int i = 0; i < nclients; i++) {
|
||||
stringstream filename;
|
||||
filename << prep_data_prefix << "Client-Keys-C" << i;
|
||||
ofstream outf(filename.str().c_str());
|
||||
if (outf.fail())
|
||||
throw file_error(filename.str().c_str());
|
||||
// Write public key and secret key
|
||||
output(client_publickeys[i],outf);
|
||||
output(client_secretkeys[i],outf);
|
||||
output(client_signing_publickeys[i],outf);
|
||||
output(client_signing_secretkeys[i],outf);
|
||||
int keycount = 2;
|
||||
|
||||
// Write all spdz engine public keys
|
||||
for (int j = 0; j < nplayers; j++) {
|
||||
output(server_publickeys[j], outf);
|
||||
output(server_signing_publickeys[j], outf);
|
||||
keycount++;
|
||||
}
|
||||
outf.close();
|
||||
cout << "Wrote " << keycount << " keys to " << filename.str() << endl;
|
||||
}
|
||||
|
||||
/* Write spdz engine files */
|
||||
for (int i = 0; i < nplayers; i++) {
|
||||
Config::write_player_config_file( prep_data_prefix, i
|
||||
, server_publickeys[i], server_secretkeys[i]
|
||||
, server_signing_publickeys[i], server_signing_secretkeys[i]
|
||||
, client_publickeys, client_signing_publickeys
|
||||
, server_publickeys, server_signing_publickeys);
|
||||
}
|
||||
}
|
||||
24
azure-pipelines.yml
Normal file
24
azure-pipelines.yml
Normal file
@@ -0,0 +1,24 @@
|
||||
# C/C++ with GCC
|
||||
# Build your C/C++ project with GCC using make.
|
||||
# Add steps that publish test results, save build artifacts, deploy, and more:
|
||||
# https://docs.microsoft.com/azure/devops/pipelines/apps/c-cpp/gcc
|
||||
|
||||
trigger:
|
||||
- master
|
||||
|
||||
pool:
|
||||
vmImage: 'ubuntu-latest'
|
||||
|
||||
steps:
|
||||
- script: |
|
||||
bash -c "sudo apt-get install libsodium-dev libntl-dev yasm texinfo libboost-dev libboost-thread-dev python3-gmpy2 libcrypto++-dev python-networkx"
|
||||
- script: |
|
||||
make mpir
|
||||
- script:
|
||||
echo USE_NTL=1 >> CONFIG.mine
|
||||
- script: |
|
||||
make
|
||||
- script:
|
||||
Scripts/setup-ssl.sh
|
||||
- script:
|
||||
Scripts/test_tutorial.sh -C
|
||||
17
compile.py
17
compile.py
@@ -26,24 +26,22 @@ def main():
|
||||
help="specify output file")
|
||||
parser.add_option("-a", "--asm-output", dest="asmoutfile",
|
||||
help="asm output file for debugging")
|
||||
parser.add_option("-l", "--asm-input", action="store_true", dest="assemblymode",
|
||||
help="old-style asm input")
|
||||
parser.add_option("-p", "--primesize", dest="param", default=-1,
|
||||
help="bit length of modulus")
|
||||
parser.add_option("-g", "--galoissize", dest="galois", default=40,
|
||||
help="bit length of Galois field")
|
||||
parser.add_option("-d", "--debug", action="store_true", dest="debug",
|
||||
help="keep track of trace for debugging")
|
||||
parser.add_option("-e", "--emulate", action="store_true", dest="emulate", default=False,
|
||||
help="emulate register values for debugging")
|
||||
parser.add_option("-c", "--comparison", dest="comparison", default="log",
|
||||
help="comparison variant: log|plain|inv|sinv")
|
||||
parser.add_option("-r", "--noreorder", dest="reorder_between_opens",
|
||||
action="store_false", default=True,
|
||||
help="don't attempt to place instructions between start/stop opens")
|
||||
parser.add_option("-O", "--optimize-hard", action="store_false",
|
||||
parser.add_option("-M", "--preserve-mem-order", action="store_true",
|
||||
dest="preserve_mem_order", default=False,
|
||||
help="don't preserve order of memory instructions; possible loss of correctness")
|
||||
help="preserve order of memory instructions; possible efficiency loss")
|
||||
parser.add_option("-O", "--optimize-hard", action="store_true",
|
||||
dest="optimize_hard", help="currently not in use")
|
||||
parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate",
|
||||
default=False, help="don't reallocate")
|
||||
parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open",
|
||||
@@ -79,10 +77,13 @@ def main():
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
if options.optimize_hard:
|
||||
print('Note that -O/--optimize-hard currently has no effect')
|
||||
|
||||
def compilation():
|
||||
prog = Compiler.run(args, options, param=int(options.param),
|
||||
merge_opens=options.merge_opens, emulate=options.emulate,
|
||||
assemblymode=options.assemblymode, debug=options.debug)
|
||||
merge_opens=options.merge_opens,
|
||||
debug=options.debug)
|
||||
prog.write_bytes(options.outfile)
|
||||
|
||||
if options.asmoutfile:
|
||||
|
||||
@@ -56,3 +56,9 @@ Compiler.ml module
|
||||
-------------------------
|
||||
|
||||
.. automodule:: Compiler.ml
|
||||
|
||||
Compiler.circuit module
|
||||
-----------------------
|
||||
|
||||
.. automodule:: Compiler.circuit
|
||||
:members:
|
||||
|
||||
Reference in New Issue
Block a user