Bristol Fashion.

This commit is contained in:
Marcel Keller
2020-04-02 18:06:14 +11:00
parent cb8e46d2f3
commit 24926df83b
98 changed files with 1266 additions and 2657 deletions

3
.gitmodules vendored
View File

@@ -4,3 +4,6 @@
[submodule "mpir"]
path = mpir
url = git://github.com/wbhart/mpir.git
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion

View File

@@ -1,5 +1,12 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.1.6 (Apr 2, 2020)
- Bristol Fashion circuits
- Semi-honest computation with somewhat homomorphic encryption
- Use SSL for client connections
- Client facilities for all arithmetic protocols
## 0.1.5 (Mar 20, 2020)
- Faster conversion between arithmetic and binary secret sharing using [extended daBits](https://eprint.iacr.org/2020/338)

View File

@@ -67,19 +67,33 @@ class BinaryVectorInstruction(base.Instruction):
def copy(self, size, subs):
return type(self)(*self.get_new_args(size, subs))
class NonVectorInstruction(base.Instruction):
is_vec = lambda self: False
def __init__(self, *args, **kwargs):
assert(args[0].n <= args[0].unit)
super(NonVectorInstruction, self).__init__(*args, **kwargs)
class NonVectorInstruction1(base.Instruction):
is_vec = lambda self: False
def __init__(self, *args, **kwargs):
assert(args[1].n <= args[1].unit)
super(NonVectorInstruction1, self).__init__(*args, **kwargs)
class xors(BinaryVectorInstruction):
code = opcodes['XORS']
arg_format = tools.cycle(['int','sbw','sb','sb'])
class xorm(base.Instruction):
class xorm(NonVectorInstruction):
code = opcodes['XORM']
arg_format = ['int','sbw','sb','cb']
class xorcb(base.Instruction):
class xorcb(NonVectorInstruction):
code = opcodes['XORCB']
arg_format = ['cbw','cb','cb']
class xorcbi(base.Instruction):
class xorcbi(NonVectorInstruction):
code = opcodes['XORCBI']
arg_format = ['cbw','cb','int']
@@ -101,47 +115,48 @@ class andm(BinaryVectorInstruction):
code = opcodes['ANDM']
arg_format = ['int','sbw','sb','cb']
class addcb(base.Instruction):
class addcb(NonVectorInstruction):
code = opcodes['ADDCB']
arg_format = ['cbw','cb','cb']
class addcbi(base.Instruction):
class addcbi(NonVectorInstruction):
code = opcodes['ADDCBI']
arg_format = ['cbw','cb','int']
class mulcbi(base.Instruction):
class mulcbi(NonVectorInstruction):
code = opcodes['MULCBI']
arg_format = ['cbw','cb','int']
class bitdecs(base.VarArgsInstruction):
class bitdecs(NonVectorInstruction, base.VarArgsInstruction):
code = opcodes['BITDECS']
arg_format = tools.chain(['sb'], itertools.repeat('sbw'))
class bitcoms(base.VarArgsInstruction):
class bitcoms(NonVectorInstruction, base.VarArgsInstruction):
code = opcodes['BITCOMS']
arg_format = tools.chain(['sbw'], itertools.repeat('sb'))
class bitdecc(base.VarArgsInstruction):
class bitdecc(NonVectorInstruction, base.VarArgsInstruction):
code = opcodes['BITDECC']
arg_format = tools.chain(['cb'], itertools.repeat('cbw'))
class shrcbi(base.Instruction):
class shrcbi(NonVectorInstruction):
code = opcodes['SHRCBI']
arg_format = ['cbw','cb','int']
class shlcbi(base.Instruction):
class shlcbi(NonVectorInstruction):
code = opcodes['SHLCBI']
arg_format = ['cbw','cb','int']
class ldbits(base.Instruction):
class ldbits(NonVectorInstruction):
code = opcodes['LDBITS']
arg_format = ['sbw','i','i']
class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
class ldmsb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
base.VectorInstruction):
code = opcodes['LDMSB']
arg_format = ['sbw','int']
class stmsb(base.DirectMemoryWriteInstruction):
class stmsb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
code = opcodes['STMSB']
arg_format = ['sb','int']
# def __init__(self, *args, **kwargs):
@@ -149,19 +164,20 @@ class stmsb(base.DirectMemoryWriteInstruction):
# import inspect
# self.caller = [frame[1:] for frame in inspect.stack()[1:]]
class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
class ldmcb(base.DirectMemoryInstruction, base.ReadMemoryInstruction,
base.VectorInstruction):
code = opcodes['LDMCB']
arg_format = ['cbw','int']
class stmcb(base.DirectMemoryWriteInstruction):
class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
code = opcodes['STMCB']
arg_format = ['cb','int']
class ldmsbi(base.ReadMemoryInstruction):
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
code = opcodes['LDMSBI']
arg_format = ['sbw','ci']
class stmsbi(base.WriteMemoryInstruction):
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
code = opcodes['STMSBI']
arg_format = ['sb','ci']
@@ -185,15 +201,15 @@ class stmsdci(base.WriteMemoryInstruction):
code = opcodes['STMSDCI']
arg_format = tools.cycle(['cb','cb'])
class convsint(base.Instruction):
class convsint(NonVectorInstruction1):
code = opcodes['CONVSINT']
arg_format = ['int','sbw','ci']
class convcint(base.Instruction):
class convcint(NonVectorInstruction):
code = opcodes['CONVCINT']
arg_format = ['cbw','ci']
class convcbit(base.Instruction):
class convcbit(NonVectorInstruction1):
code = opcodes['CONVCBIT']
arg_format = ['ciw','cb']
@@ -222,18 +238,19 @@ class split(base.Instruction):
super(split_class, self).__init__(*args, **kwargs)
assert (len(args) - 2) % args[0] == 0
class movsb(base.Instruction):
class movsb(NonVectorInstruction):
code = opcodes['MOVSB']
arg_format = ['sbw','sb']
class trans(base.VarArgsInstruction):
code = opcodes['TRANS']
is_vec = lambda self: True
def __init__(self, *args):
self.arg_format = ['int'] + ['sbw'] * args[0] + \
['sb'] * (len(args) - 1 - args[0])
super(trans, self).__init__(*args)
class bitb(base.Instruction):
class bitb(NonVectorInstruction):
code = opcodes['BITB']
arg_format = ['sbw']
@@ -245,20 +262,22 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
__slots__ = []
code = opcodes['INPUTB']
arg_format = tools.cycle(['p','int','int','sbw'])
is_vec = lambda self: True
class print_regb(base.IOInstruction):
class print_regb(base.VectorInstruction, base.IOInstruction):
code = opcodes['PRINTREGB']
arg_format = ['cb','i']
def __init__(self, reg, comment=''):
super(print_regb, self).__init__(reg, self.str_to_int(comment))
class print_reg_plainb(base.IOInstruction):
class print_reg_plainb(NonVectorInstruction, base.IOInstruction):
code = opcodes['PRINTREGPLAINB']
arg_format = ['cb']
class print_reg_signed(base.IOInstruction):
code = opcodes['PRINTREGSIGNED']
arg_format = ['int','cb']
is_vec = lambda self: True
class print_float_plainb(base.IOInstruction):
__slots__ = []

View File

@@ -77,6 +77,9 @@ class bits(Tape.Register, _structure, _bit):
def n_elements():
return 1
@classmethod
def mem_size(cls):
return math.ceil(cls.n / cls.unit)
@classmethod
def load_mem(cls, address, mem_type=None, size=None):
if size not in (None, 1):
v = [cls.load_mem(address + i) for i in range(size)]
@@ -101,9 +104,8 @@ class bits(Tape.Register, _structure, _bit):
def copy(self):
return type(self)(n=instructions_base.get_global_vector_size())
def set_length(self, n):
if n > self.max_length:
print(self.max_length)
raise Exception('too long: %d' % n)
if n > self.n:
raise Exception('too long: %d/%d' % (n, self.n))
self.n = n
def set_size(self, size):
pass
@@ -135,7 +137,7 @@ class bits(Tape.Register, _structure, _bit):
if self.n != None:
suffix = '%d' % self.n
if type(self).n != None and type(self).n != self.n:
suffice += '/%d' % type(self).n
suffix += '/%d' % type(self).n
else:
suffix = 'undef'
return '%s(%s)' % (super(bits, self).__repr__(), suffix)
@@ -237,6 +239,7 @@ class sbits(bits):
bitdec = inst.bitdecs
bitcom = inst.bitcoms
conv_regint = inst.convsint
one_cache = {}
@classmethod
def conv_regint_by_bit(cls, n, res, other):
tmp = cbits.get_type(n)()
@@ -285,14 +288,12 @@ class sbits(bits):
% (value, self.n))
if self.n <= 32:
inst.ldbits(self, self.n, value)
elif self.n <= 64:
self.load_other(regint(value, size=1))
elif self.n <= 128:
lower = sbits.get_type(64)(value % 2**64)
upper = sbits.get_type(self.n - 64)(value >> 64)
self.mov(self, lower + (upper << 64))
else:
raise NotImplementedError('more than 128 bits wanted')
size = math.ceil(self.n / self.unit)
tmp = regint(size=size)
for i in range(size):
tmp[i].load_int((value >> (i * 64)) % 2**64)
self.load_other(tmp)
def load_other(self, other):
if isinstance(other, cbits) and self.n == other.n:
inst.convcbit2s(self.n, self, other)
@@ -393,11 +394,10 @@ class sbits(bits):
# res = type(self)(n=self.n)
# inst.nots(res, self)
# return res
if self.n == None or self.n > self.unit:
one = self.get_type(self.n)()
self.conv_regint_by_bit(self.n, one, regint(1, size=self.n))
else:
one = self.new(value=self.long_one(), n=self.n)
key = self.n, library.get_block()
if key not in self.one_cache:
self.one_cache[key] = self.new(value=self.long_one(), n=self.n)
one = self.one_cache[key]
return self + one
def __neg__(self):
return self
@@ -432,12 +432,12 @@ class sbits(bits):
@classmethod
def trans(cls, rows):
rows = list(rows)
if len(rows) == 1:
if len(rows) == 1 and rows[0].n <= rows[0].unit:
return rows[0].bit_decompose()
n_columns = rows[0].n
for row in rows:
assert(row.n == n_columns)
if n_columns == 1:
if n_columns == 1 and len(rows) <= cls.unit:
return [cls.bit_compose(rows)]
else:
res = [cls.new(n=len(rows)) for i in range(n_columns)]
@@ -452,6 +452,10 @@ class sbits(bits):
@staticmethod
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
def to_sint(self, n_bits):
bits = sbitvec.from_vec(sbitvec([self]).v[:n_bits]).elements()[0]
bits = sint(bits, size=n_bits)
return sint.bit_compose(bits)
class sbitvec(_vec):
@classmethod
@@ -524,6 +528,8 @@ class sbitvec(_vec):
return iter(self.v)
def __len__(self):
return len(self.v)
def __getitem__(self, index):
return self.v[index]
@classmethod
def conv(cls, other):
return cls.from_vec(other.v)

View File

@@ -210,85 +210,6 @@ class Merger:
max_depth_of[v] = min(max_depth_of[u], max_depth_of[v])
return max_depth_of
def merge_inputs(self):
merges = defaultdict(list)
remaining_input_nodes = []
def do_merge(nodes):
if len(nodes) > 1000:
print('Merging %d inputs...' % len(nodes))
self.do_merge(iter(nodes))
for n in self.input_nodes:
inst = self.instructions[n]
merge = merges[inst.args[0],inst.__class__]
if len(merge) == 0:
remaining_input_nodes.append(n)
merge.append(n)
if len(merge) >= self.max_parallel_open:
do_merge(merge)
merge[:] = []
for merge in reversed(sorted(merges.values())):
if merge:
do_merge(merge)
self.input_nodes = remaining_input_nodes
def compute_preorder(self, merges, rev_depth_of):
# find flexible nodes that can be on several levels
# and find sources on level 0
G = self.G
merge_nodes_set = self.open_nodes
depth_of = self.depths
instructions = self.instructions
flex_nodes = defaultdict(dict)
starters = []
for n in range(len(G)):
if n not in merge_nodes_set and \
depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction):
#print n, depth_of[n], rev_depth_of[n]
flex_nodes[depth_of[n]].setdefault(rev_depth_of[n], set()).add(n)
elif len(G.pred[n]) == 0 and \
not isinstance(self.instructions[n], RawInputInstruction):
starters.append(n)
if n % 10000000 == 0 and n > 0:
print("Processed %d nodes at" % n, time.asctime())
inputs = defaultdict(list)
for node in self.input_nodes:
player = self.instructions[node].args[0]
inputs[player].append(node)
first_inputs = [l[0] for l in inputs.values()]
other_inputs = []
i = 0
while True:
i += 1
found = False
for l in inputs.values():
if i < len(l):
other_inputs.append(l[i])
found = True
if not found:
break
other_inputs.reverse()
preorder = []
# magical preorder for topological search
max_depth = max(merges)
if max_depth > 10000:
print("Computing pre-ordering ...")
for i in range(max_depth, 0, -1):
preorder.append(G.get_attr(merges[i], 'stop'))
for j in flex_nodes[i-1].values():
preorder.extend(j)
preorder.extend(flex_nodes[0].get(i, []))
preorder.append(merges[i])
if i % 100000 == 0 and i > 0:
print("Done level %d at" % i, time.asctime())
preorder.extend(other_inputs)
preorder.extend(starters)
preorder.extend(first_inputs)
if max_depth > 10000:
print("Done at", time.asctime())
return preorder
def longest_paths_merge(self):
""" Attempt to merge instructions of type instruction_type (which are given in
merge_nodes) using longest paths algorithm.
@@ -301,7 +222,7 @@ class Merger:
instructions = self.instructions
merge_nodes = self.open_nodes
depths = self.depths
if not merge_nodes and not self.input_nodes:
if not merge_nodes:
return 0
# merge opens at same depth
@@ -321,8 +242,6 @@ class Merger:
(len(merge), t.__name__, i, len(merges)))
self.do_merge(merge)
self.merge_inputs()
preorder = None
if len(instructions) > 100000:
@@ -340,7 +259,6 @@ class Merger:
options = self.options
open_nodes = set()
self.open_nodes = open_nodes
self.input_nodes = []
colordict = defaultdict(lambda: 'gray', asm_open='red',\
ldi='lightblue', ldm='lightblue', stm='blue',\
mov='yellow', mulm='orange', mulc='orange',\
@@ -507,14 +425,7 @@ class Merger:
elif isinstance(instr, PublicFileIOInstruction):
keep_order(instr, n, instr.__class__)
elif isinstance(instr, RawInputInstruction):
keep_order(instr, n, instr.__class__, 0)
self.input_nodes.append(n)
G.add_node(n, merges=[])
player = instr.args[0]
if isinstance(instr, stopinput):
add_edge(last[startinput_class][player], n)
elif isinstance(instr, gstopinput):
add_edge(last[gstartinput][player], n)
keep_order(instr, n, instr.__class__)
elif isinstance(instr, startprivateoutput_class):
keep_order(instr, n, startprivateoutput_class, 2)
elif isinstance(instr, stopprivateoutput_class):
@@ -559,18 +470,14 @@ class Merger:
unused_result = not G.degree(i) and len(list(inst.get_def())) \
and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \
and not isinstance(inst, (DoNotEliminateInstruction))
stop_node = G.get_attr(i, 'stop')
unused_startopen = stop_node != -1 and instructions[stop_node] is None
def eliminate(i):
G.remove_node(i)
merge_nodes.discard(i)
stats[type(instructions[i]).__name__] += 1
instructions[i] = None
if unused_result or unused_startopen:
if unused_result:
eliminate(i)
count += 1
if unused_startopen:
open_count += len(inst.args)
# remove unnecessary stack instructions
# left by optimization with budget
if isinstance(inst, popint_class) and \

201
Compiler/circuit.py Normal file
View File

@@ -0,0 +1,201 @@
"""
This module contains functionality using circuits in the so-called
`Bristol Fashion`_ format. You can download a few examples including
the ones used below into ``Programs/Circuits`` as follows::
make Programs/Circuits
.. _`Bristol Fashion`: https://homes.esat.kuleuven.be/~nsmart/MPC
"""
from Compiler.GC.types import sbitvec, sbits
from Compiler.library import function_block
from Compiler import util
import itertools
class Circuit:
"""
Use a Bristol Fashion circuit in a high-level program. The
following example adds signed 64-bit inputs from two different
parties and prints the result::
from circuit import Circuit
sb64 = sbits.get_type(64)
adder = Circuit('adder64')
a, b = [sbitvec(sb64.get_input_from(i)) for i in (0, 1)]
print_ln('%s', adder(a, b).elements()[0].reveal())
Circuits can also be executed in parallel as the following example
shows::
from circuit import Circuit
sb128 = sbits.get_type(128)
key = sb128(0x2b7e151628aed2a6abf7158809cf4f3c)
plaintext = sb128(0x6bc1bee22e409f96e93d7e117393172a)
n = 1000
aes128 = Circuit('aes_128')
ciphertexts = aes128(sbitvec([key] * n), sbitvec([plaintext] * n))
ciphertexts.elements()[n - 1].reveal().print_reg()
This executes AES-128 1000 times in parallel and then outputs the
last result, which should be ``0x3ad77bb40d7a3660a89ecaf32466ef97``,
one of the test vectors for AES-128.
"""
def __init__(self, name):
self.filename = 'Programs/Circuits/%s.txt' % name
f = open(self.filename)
self.functions = {}
def __call__(self, *inputs):
return self.run(*inputs)
def run(self, *inputs):
n = inputs[0][0].n
if n not in self.functions:
self.functions[n] = function_block(lambda *args:
self.compile(*args))
flat_res = self.functions[n](*itertools.chain(*inputs))
res = []
i = 0
for l in self.n_output_wires:
v = []
for i in range(l):
v.append(flat_res[i])
i += 1
res.append(sbitvec.from_vec(v))
return util.untuplify(res)
def compile(self, *all_inputs):
f = open(self.filename)
lines = iter(f)
next_line = lambda: next(lines).split()
n_gates, n_wires = (int(x) for x in next_line())
self.n_wires = n_wires
input_line = [int(x) for x in next_line()]
n_inputs = input_line[0]
n_input_wires = input_line[1:]
assert(n_inputs == len(n_input_wires))
inputs = []
s = 0
for n in n_input_wires:
inputs.append(all_inputs[s:s + n])
s += n
output_line = [int(x) for x in next_line()]
n_outputs = output_line[0]
self.n_output_wires = output_line[1:]
assert(n_outputs == len(self.n_output_wires))
next(lines)
wires = [None] * n_wires
self.wires = wires
i_wire = 0
for input, input_wires in zip(inputs, n_input_wires):
assert(len(input) == input_wires)
for i, reg in enumerate(input):
wires[i_wire] = reg
i_wire += 1
for i in range(n_gates):
line = next_line()
t = line[-1]
if t in ('XOR', 'AND'):
assert line[0] == '2'
assert line[1] == '1'
assert len(line) == 6
ins = [wires[int(line[2 + i])] for i in range(2)]
if t == 'XOR':
wires[int(line[4])] = ins[0] ^ ins[1]
else:
wires[int(line[4])] = ins[0] & ins[1]
elif t == 'INV':
assert line[0] == '1'
assert line[1] == '1'
assert len(line) == 5
wires[int(line[3])] = ~wires[int(line[2])]
return self.wires[-sum(self.n_output_wires):]
Keccak_f = None
def sha3_256(x):
"""
This function implements SHA3-256 for inputs of up to 1080 bits::
from circuit import sha3_256
a = sbitvec.from_vec([])
b = sbitvec(sint(0xcc), 8)
for x in a, b:
sha3_256(x).elements()[0].reveal().print_reg()
This should output the first two test vectors of SHA3-256 in
byte-reversed order::
0x5375f6fb6aa989b0c287a923afe81e79ff875921cacc956666d71ebff8c6ffa7
0x17c7e0d65c285af8406d4f21c071851a312b739a8ecdf25c1270d31c39357067
Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
implemented for computation modulo a power of two.
"""
global Keccak_f
if Keccak_f is None:
# only one instance
Keccak_f = Circuit('Keccak_f')
# whole bytes
assert len(x.v) % 8 == 0
# only one block
r = 1088
assert len(x.v) < 1088
if x.v:
n = x.v[0].n
else:
n = 1
d = sbitvec([sbits.get_type(8)(0x06)] * n)
sbn = sbits.get_type(n)
padding = [sbn(0)] * (r - 8 - len(x.v))
P_flat = x.v + d.v + padding
assert len(P_flat) == r
P_flat[-1] = ~P_flat[-1]
w = 64
P1 = [P_flat[i * w:(i + 1) * w] for i in range(r // w)]
S = [[[sbn(0) for i in range(w)] for i in range(5)] for i in range(5)]
for x in range(5):
for y in range(5):
if x + 5 * y < r // w:
for i in range(w):
S[x][y][i] ^= P1[x + 5 * y][i]
def flatten(S):
res = [None] * 1600
for y in range(5):
for x in range(5):
for i in range(w):
j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
res[1600 - 1 - j] = S[x][y][i]
return res
def unflatten(S_flat):
res = [[[None] * w for j in range(5)] for i in range(5)]
for y in range(5):
for x in range(5):
for i in range(w):
j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
res[x][y][i] = S_flat[1600 - 1 -j]
return res
S = unflatten(Keccak_f(flatten(S)))
Z = []
while len(Z) <= 256:
for y in range(5):
for x in range(5):
if x + 5 * y < r // w:
Z += S[y][x]
if len(Z) <= 256:
S = unflatten(Keccak_f(flatten(S)))
return sbitvec.from_vec(Z[:256])

View File

@@ -262,6 +262,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
return r_dprime, r_prime, c, c_prime, u, t, c2k1
def MaskingBitsInRing(m, strict=False):
program.curr_tape.require_bit_length(1)
from Compiler.types import sint
if program.use_edabit():
return sint.get_edabit(m, strict)

View File

@@ -9,19 +9,18 @@ import time
import sys
def run(args, options, param=-1, merge_opens=True, emulate=True, \
reallocate=True, assemblymode=False, debug=False):
def run(args, options, param=-1, merge_opens=True,
reallocate=True, debug=False):
""" Compile a file and output a Program object.
If merge_opens is set to True, will attempt to merge any parallelisable open
instructions. """
prog = Program(args, options, param, assemblymode)
prog = Program(args, options, param)
instructions.program = prog
instructions_base.program = prog
types.program = prog
comparison.program = prog
prog.EMULATE = emulate
prog.DEBUG = debug
VARS['program'] = prog
if options.binary:
@@ -31,26 +30,9 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
print('Compiling file', prog.infile)
# no longer needed, but may want to support assembly in future (?)
if assemblymode:
prog.restart_main_thread()
for i in range(INIT_REG_MAX):
VARS['c%d'%i] = prog.curr_block.new_reg('c')
VARS['s%d'%i] = prog.curr_block.new_reg('s')
VARS['cg%d'%i] = prog.curr_block.new_reg('cg')
VARS['sg%d'%i] = prog.curr_block.new_reg('sg')
if i % 10000000 == 0 and i > 0:
print("Initialized %d register variables at" % i, time.asctime())
# first pass determines how many assembler registers are used
prog.FIRST_PASS = True
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
if instructions_base.Instruction.count != 0:
print('instructions count', instructions_base.Instruction.count)
instructions_base.Instruction.count = 0
prog.FIRST_PASS = False
prog.reset_values()
# make compiler modules directly accessible
sys.path.insert(0, 'Compiler')
# create the tapes
@@ -60,17 +42,14 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
for tape in prog.tapes:
tape.optimize(options)
# check program still does the same thing after optimizations
if emulate:
clearmem = list(prog.mem_c)
sharedmem = list(prog.mem_s)
prog.emulate()
if prog.mem_c != clearmem or prog.mem_s != sharedmem:
print('Warning: emulated memory values changed after compiler optimization')
# raise CompilerError('Compiler optimization caused incorrect memory write.')
if prog.main_thread_running:
prog.update_req(prog.curr_tape)
if prog.req_num:
print('Program requires:')
for x in prog.req_num.pretty():
print(x)
if prog.verbose:
print('Program requires:', repr(prog.req_num))
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())

View File

@@ -1,13 +1,7 @@
from collections import defaultdict
#INIT_REG_MAX = 655360
INIT_REG_MAX = 1310720
REG_MAX = 2 ** 32
USER_MEM = 8192
TMP_MEM = 8192
TMP_MEM_BASE = USER_MEM
TMP_REG = 3
TMP_REG_BASE = REG_MAX - TMP_REG
P_VALUES = { 32: 2147565569, \
64: 9223372036855103489, \

View File

@@ -17,7 +17,7 @@ class SparseDiGraph(object):
""" max_nodes: maximum no of nodes
default_attributes: dict of node attributes and default values """
if default_attributes is None:
default_attributes = { 'merges': None, 'stop': -1, 'start': -1 }
default_attributes = { 'merges': None }
self.default_attributes = default_attributes
self.attribute_pos = dict(list(zip(list(default_attributes.keys()), list(range(len(default_attributes))))))
self.n = max_nodes

View File

@@ -34,9 +34,6 @@ class ldi(base.Instruction):
__slots__ = []
code = base.opcodes['LDI']
arg_format = ['cw','i']
def execute(self):
self.args[0].value = self.args[1]
@base.gf2n
@base.vectorize
@@ -45,9 +42,6 @@ class ldsi(base.Instruction):
__slots__ = []
code = base.opcodes['LDSI']
arg_format = ['sw','i']
def execute(self):
self.args[0].value = self.args[1]
@base.gf2n
@base.vectorize
@@ -57,9 +51,6 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
code = base.opcodes['LDMC']
arg_format = ['cw','int']
def execute(self):
self.args[0].value = program.mem_c[self.args[1]]
@base.gf2n
@base.vectorize
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
@@ -68,9 +59,6 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
code = base.opcodes['LDMS']
arg_format = ['sw','int']
def execute(self):
self.args[0].value = program.mem_s[self.args[1]]
@base.gf2n
@base.vectorize
class stmc(base.DirectMemoryWriteInstruction):
@@ -79,9 +67,6 @@ class stmc(base.DirectMemoryWriteInstruction):
code = base.opcodes['STMC']
arg_format = ['c','int']
def execute(self):
program.mem_c[self.args[1]] = self.args[0].value
@base.gf2n
@base.vectorize
class stms(base.DirectMemoryWriteInstruction):
@@ -90,9 +75,6 @@ class stms(base.DirectMemoryWriteInstruction):
code = base.opcodes['STMS']
arg_format = ['s','int']
def execute(self):
program.mem_s[self.args[1]] = self.args[0].value
@base.vectorize
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """
@@ -100,9 +82,6 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
code = base.opcodes['LDMINT']
arg_format = ['ciw','int']
def execute(self):
self.args[0].value = program.mem_i[self.args[1]]
@base.vectorize
class stmint(base.DirectMemoryWriteInstruction):
r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """
@@ -110,18 +89,12 @@ class stmint(base.DirectMemoryWriteInstruction):
code = base.opcodes['STMINT']
arg_format = ['ci','int']
def execute(self):
program.mem_i[self.args[1]] = self.args[0].value
# must have seperate instructions because address is always modp
@base.vectorize
class ldmci(base.ReadMemoryInstruction):
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
code = base.opcodes['LDMCI']
arg_format = ['cw','ci']
def execute(self):
self.args[0].value = program.mem_c[self.args[1].value]
@base.vectorize
class ldmsi(base.ReadMemoryInstruction):
@@ -129,53 +102,35 @@ class ldmsi(base.ReadMemoryInstruction):
code = base.opcodes['LDMSI']
arg_format = ['sw','ci']
def execute(self):
self.args[0].value = program.mem_s[self.args[1].value]
@base.vectorize
class stmci(base.WriteMemoryInstruction):
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
code = base.opcodes['STMCI']
arg_format = ['c','ci']
def execute(self):
program.mem_c[self.args[1].value] = self.args[0].value
@base.vectorize
class stmsi(base.WriteMemoryInstruction):
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
code = base.opcodes['STMSI']
arg_format = ['s','ci']
def execute(self):
program.mem_s[self.args[1].value] = self.args[0].value
@base.vectorize
class ldminti(base.ReadMemoryInstruction):
r""" Assigns register $ci_i$ the value in memory \verb+Ci[cj]+. """
code = base.opcodes['LDMINTI']
arg_format = ['ciw','ci']
def execute(self):
self.args[0].value = program.mem_i[self.args[1].value]
@base.vectorize
class stminti(base.WriteMemoryInstruction):
r""" Sets \verb+Ci[cj]+ to be the value $ci_i$. """
code = base.opcodes['STMINTI']
arg_format = ['ci','ci']
def execute(self):
program.mem_i[self.args[1].value] = self.args[0].value
@base.vectorize
class gldmci(base.ReadMemoryInstruction):
r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
code = base.opcodes['LDMCI'] + 0x100
arg_format = ['cgw','ci']
def execute(self):
self.args[0].value = program.mem_c[self.args[1].value]
@base.vectorize
class gldmsi(base.ReadMemoryInstruction):
@@ -183,27 +138,18 @@ class gldmsi(base.ReadMemoryInstruction):
code = base.opcodes['LDMSI'] + 0x100
arg_format = ['sgw','ci']
def execute(self):
self.args[0].value = program.mem_s[self.args[1].value]
@base.vectorize
class gstmci(base.WriteMemoryInstruction):
r""" Sets \verb+C[cj]+ to be the value $c_i$. """
code = base.opcodes['STMCI'] + 0x100
arg_format = ['cg','ci']
def execute(self):
program.mem_c[self.args[1].value] = self.args[0].value
@base.vectorize
class gstmsi(base.WriteMemoryInstruction):
r""" Sets \verb+S[cj]+ to be the value $s_i$. """
code = base.opcodes['STMSI'] + 0x100
arg_format = ['sg','ci']
def execute(self):
program.mem_s[self.args[1].value] = self.args[0].value
@base.gf2n
@base.vectorize
class protectmems(base.Instruction):
@@ -233,9 +179,6 @@ class movc(base.Instruction):
code = base.opcodes['MOVC']
arg_format = ['cw','c']
def execute(self):
self.args[0].value = self.args[1].value
@base.gf2n
@base.vectorize
class movs(base.Instruction):
@@ -244,9 +187,6 @@ class movs(base.Instruction):
code = base.opcodes['MOVS']
arg_format = ['sw','s']
def execute(self):
self.args[0].value = self.args[1].value
@base.vectorize
class movint(base.Instruction):
r""" Assigns register $ci_i$ the value in the register $ci_j$. """
@@ -452,9 +392,6 @@ class divc(base.InvertInstruction):
__slots__ = []
code = base.opcodes['DIVC']
arg_format = ['cw','c','c']
def execute(self):
self.args[0].value = self.args[1].value * pow(self.args[2].value, program.P-2, program.P) % program.P
@base.gf2n
@base.vectorize
@@ -464,9 +401,6 @@ class modc(base.Instruction):
code = base.opcodes['MODC']
arg_format = ['cw','c','c']
def execute(self):
self.args[0].value = self.args[1].value % self.args[2].value
@base.vectorize
class inv2m(base.InvertInstruction):
__slots__ = []
@@ -498,9 +432,6 @@ class andc(base.Instruction):
__slots__ = []
code = base.opcodes['ANDC']
arg_format = ['cw','c','c']
def execute(self):
self.args[0].value = (self.args[1].value & self.args[2].value) % program.P
@base.gf2n
@base.vectorize
@@ -509,9 +440,6 @@ class orc(base.Instruction):
__slots__ = []
code = base.opcodes['ORC']
arg_format = ['cw','c','c']
def execute(self):
self.args[0].value = (self.args[1].value | self.args[2].value) % program.P
@base.gf2n
@base.vectorize
@@ -520,9 +448,6 @@ class xorc(base.Instruction):
__slots__ = []
code = base.opcodes['XORC']
arg_format = ['cw','c','c']
def execute(self):
self.args[0].value = (self.args[1].value ^ self.args[2].value) % program.P
@base.vectorize
class notc(base.Instruction):
@@ -530,9 +455,6 @@ class notc(base.Instruction):
__slots__ = []
code = base.opcodes['NOTC']
arg_format = ['cw','c', 'int']
def execute(self):
self.args[0].value = (~self.args[1].value + 2 ** self.args[2]) % program.P
@base.vectorize
class gnotc(base.Instruction):
@@ -544,9 +466,6 @@ class gnotc(base.Instruction):
def is_gf2n(self):
return True
def execute(self):
self.args[0].value = ~self.args[1].value
@base.vectorize
class gbitdec(base.Instruction):
r""" Store every $n$-th bit of $cg_i$ in $cg_j, \dots$. """
@@ -672,8 +591,6 @@ class divci(base.InvertInstruction, base.ClearImmediate):
r""" Clear division by immediate value $c_i=c_j/n$. """
__slots__ = []
code = base.opcodes['DIVCI']
def execute(self):
self.args[0].value = self.args[1].value * pow(self.args[2], program.P-2, program.P) % program.P
@base.gf2n
@base.vectorize
@@ -719,9 +636,6 @@ class shlc(base.Instruction):
__slots__ = []
code = base.opcodes['SHLC']
arg_format = ['cw','c','c']
def execute(self):
self.args[0].value = (self.args[1].value << self.args[2].value) % program.P
@base.gf2n
@base.vectorize
@@ -730,9 +644,6 @@ class shrc(base.Instruction):
__slots__ = []
code = base.opcodes['SHRC']
arg_format = ['cw','c','c']
def execute(self):
self.args[0].value = (self.args[1].value >> self.args[2].value) % program.P
@base.gf2n
@base.vectorize
@@ -764,11 +675,6 @@ class triple(base.DataInstruction):
code = base.opcodes['TRIPLE']
arg_format = ['sw','sw','sw']
data_type = 'triple'
def execute(self):
self.args[0].value = randint(0,program.P)
self.args[1].value = randint(0,program.P)
self.args[2].value = (self.args[0].value * self.args[1].value) % program.P
@base.vectorize
class gbittriple(base.DataInstruction):
@@ -804,9 +710,6 @@ class bit(base.DataInstruction):
code = base.opcodes['BIT']
arg_format = ['sw']
data_type = 'bit'
def execute(self):
self.args[0].value = randint(0,1)
@base.vectorize
class dabit(base.DataInstruction):
@@ -848,10 +751,6 @@ class square(base.DataInstruction):
code = base.opcodes['SQUARE']
arg_format = ['sw','sw']
data_type = 'square'
def execute(self):
self.args[0].value = randint(0,program.P)
self.args[1].value = (self.args[0].value * self.args[0].value) % program.P
@base.gf2n
@base.vectorize
@@ -868,11 +767,6 @@ class inverse(base.DataInstruction):
raise CompilerError('random inverse in ring not implemented')
base.DataInstruction.__init__(self, *args, **kwargs)
def execute(self):
self.args[0].value = randint(0,program.P)
import gmpy
self.args[1].value = int(gmpy.invert(self.args[0].value, program.P))
@base.gf2n
@base.vectorize
class inputmask(base.Instruction):
@@ -920,8 +814,6 @@ class asm_input(base.TextInputInstruction):
for player in self.args[1::2]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())
def execute(self):
self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P
@base.vectorize
class inputfix(base.TextInputInstruction):
@@ -1006,46 +898,18 @@ class inputmixedreg(inputmixed_base):
req_node.increment((self.field_type, 'input', 0), float('inf'))
@base.gf2n
class startinput(base.RawInputInstruction):
class rawinput(base.RawInputInstruction, base.Mergeable):
r""" Receive inputs from player $p$. """
__slots__ = []
code = base.opcodes['STARTINPUT']
arg_format = ['p', 'int']
code = base.opcodes['RAWINPUT']
arg_format = tools.cycle(['p','sw'])
field_type = 'modp'
def add_usage(self, req_node):
req_node.increment((self.field_type, 'input', self.args[0]), \
self.args[1])
def merge(self, other):
self.args[1] += other.args[1]
class StopInputInstruction(base.RawInputInstruction):
__slots__ = []
def merge(self, other):
if self.get_size() != other.get_size():
raise NotImplemented()
else:
self.args += other.args[1:]
class stopinput(StopInputInstruction):
r""" Receive inputs from player $p$ and put in registers. """
__slots__ = []
code = base.opcodes['STOPINPUT']
arg_format = tools.chain(['p'], itertools.repeat('sw'))
def has_var_args(self):
return True
class gstopinput(StopInputInstruction):
r""" Receive inputs from player $p$ and put in registers. """
__slots__ = []
code = 0x100 + base.opcodes['STOPINPUT']
arg_format = tools.chain(['p'], itertools.repeat('sgw'))
def has_var_args(self):
return True
for i in range(0, len(self.args), 2):
player = self.args[i]
req_node.increment((self.field_type, 'input', player), \
self.get_size())
@base.gf2n
@base.vectorize
@@ -1054,9 +918,6 @@ class print_mem(base.IOInstruction):
__slots__ = []
code = base.opcodes['PRINTMEM']
arg_format = ['c']
def execute(self):
pass
@base.gf2n
@base.vectorize
@@ -1069,9 +930,6 @@ class print_reg(base.IOInstruction):
def __init__(self, reg, comment=''):
super(print_reg_class, self).__init__(reg, self.str_to_int(comment))
def execute(self):
pass
@base.gf2n
@base.vectorize
class print_reg_plain(base.IOInstruction):
@@ -1238,41 +1096,6 @@ class acceptclientconnection(base.IOInstruction):
code = base.opcodes['ACCEPTCLIENTCONNECTION']
arg_format = ['ciw', 'int']
class connectipv4(base.IOInstruction):
"""Connect to server at IPv4 address in register \verb|cj| at given port. Write socket handle to register \verb|ci|"""
__slots__ = []
code = base.opcodes['CONNECTIPV4']
arg_format = ['ciw', 'ci', 'int']
class readclientpublickey(base.IOInstruction):
"""Read a client public key as 8 32-bit ints for a specified client id"""
__slots__ = []
code = base.opcodes['READCLIENTPUBLICKEY']
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
def has_var_args(self):
return True
class initsecuresocket(base.IOInstruction):
"""Read a client public key as 8 32-bit ints for a specified client id,
negotiate a shared key via STS and use it for replay resistant comms"""
__slots__ = []
code = base.opcodes['INITSECURESOCKET']
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
def has_var_args(self):
return True
class respsecuresocket(base.IOInstruction):
"""Read a client public key as 8 32-bit ints for a specified client id,
negotiate a shared key via STS and use it for replay resistant comms"""
__slots__ = []
code = base.opcodes['RESPSECURESOCKET']
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
def has_var_args(self):
return True
class writesharestofile(base.IOInstruction):
"""Write shares to a file"""
__slots__ = []
@@ -1392,12 +1215,6 @@ class eqzc(base.UnaryComparisonInstruction):
r""" Clear comparison $c_i = (c_j \stackrel{?}{==} 0)$. """
__slots__ = []
code = base.opcodes['EQZC']
def execute(self):
if self.args[1].value == 0:
self.args[0].value = 1
else:
self.args[0].value = 0
@base.vectorize
class ltzc(base.UnaryComparisonInstruction):
@@ -1435,9 +1252,6 @@ class jmp(base.JumpInstruction):
arg_format = ['int']
jump_arg = 0
def execute(self):
pass
class jmpi(base.JumpInstruction):
""" Unconditional relative jump of $c_i+1$ instructions. """
__slots__ = []
@@ -1457,9 +1271,6 @@ class jmpnz(base.JumpInstruction):
code = base.opcodes['JMPNZ']
arg_format = ['ci', 'int']
jump_arg = 1
def execute(self):
pass
class jmpeqz(base.JumpInstruction):
r""" Jump $n+1$ instructions if $c_i == 0$. """
@@ -1467,9 +1278,6 @@ class jmpeqz(base.JumpInstruction):
code = base.opcodes['JMPEQZ']
arg_format = ['ci', 'int']
jump_arg = 1
def execute(self):
pass
###
### Conversions

View File

@@ -115,6 +115,7 @@ opcodes = dict(
INPUTFLOAT = 0xF1,
INPUTMIXED = 0xF2,
INPUTMIXEDREG = 0xF3,
RAWINPUT = 0xF4,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
@@ -688,15 +689,12 @@ class Instruction(object):
self.args = list(args)
if not kwargs.get('copying', False):
self.check_args()
if not program.FIRST_PASS:
if kwargs.get('add_to_prog', True):
program.curr_block.instructions.append(self)
if program.DEBUG:
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
else:
self.caller = None
if program.EMULATE:
self.execute()
if kwargs.get('add_to_prog', True):
program.curr_block.instructions.append(self)
if program.DEBUG:
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
else:
self.caller = None
Instruction.count += 1
if Instruction.count % 100000 == 0:
@@ -717,10 +715,6 @@ class Instruction(object):
def get_bytes(self):
return bytearray(self.get_encoding())
def execute(self):
""" Emulate execution of this instruction """
raise NotImplementedError('execute method must be implemented')
def check_args(self):
""" Check the args match up with that specified in arg_format """
try:
@@ -839,21 +833,12 @@ class VectorInstruction(Instruction):
class AddBase(Instruction):
__slots__ = []
def execute(self):
self.args[0].value = (self.args[1].value + self.args[2].value) % program.P
class SubBase(Instruction):
__slots__ = []
def execute(self):
self.args[0].value = (self.args[1].value - self.args[2].value) % program.P
class MulBase(Instruction):
__slots__ = []
def execute(self):
self.args[0].value = (self.args[1].value * self.args[2].value) % program.P
###
### Basic arithmetic with immediate values
###
@@ -861,9 +846,6 @@ class MulBase(Instruction):
class ImmediateBase(Instruction):
__slots__ = ['op']
def execute(self):
exec('self.args[0].value = self.args[1].value.%s(self.args[2]) %% program.P' % self.op)
class SharedImmediate(ImmediateBase):
__slots__ = []
arg_format = ['sw', 's', 'i']
@@ -1023,10 +1005,7 @@ class CISC(Instruction):
def __init__(self, *args):
self.args = args
self.check_args()
#if EMULATE:
# self.expand()
if not program.FIRST_PASS:
self.expand()
self.expand()
def expand(self):
""" Expand this into a sequence of RISC instructions. """

View File

@@ -342,8 +342,10 @@ class Function:
x.reg_type)))
runtime_args = [None] * len(args)
for t in sorted(type_args, key=lambda x: x.reg_type):
for i,i_arg in enumerate(type_args[t]):
i = 0
for i_arg in type_args[t]:
runtime_args[i_arg] = t.load_mem(bases[t] + i)
i += util.mem_size(t)
return self.function(*(list(compile_args) + runtime_args))
self.on_first_call(wrapped_function)
self.type_args[len(args)] = type_args
@@ -354,10 +356,12 @@ class Function:
for i,reg_type in enumerate(sorted(type_args,
key=lambda x: x.reg_type)):
store_in_mem(bases[reg_type], base + i)
for j,i_arg in enumerate(type_args[reg_type]):
j = 0
for i_arg in type_args[reg_type]:
if get_reg_type(args[i_arg]) != reg_type:
raise CompilerError('type mismatch')
store_in_mem(args[i_arg], bases[reg_type] + j)
j += util.mem_size(reg_type)
return self.on_call(base, bases)
class FunctionTape(Function):

View File

@@ -60,27 +60,21 @@ def sigmoid_prime(x):
return sx * (1 - sx)
@vectorize
def approx_sigmoid(x):
if approx_sigmoid.special and \
get_program().options.ring and get_program().use_edabit():
l = int(get_program().options.ring)
r, r_bits = sint.get_edabit(x.k, False)
c = ((x.v - r) << (l - x.k)).reveal() >> (l - x.k)
c_bits = c.bit_decompose(x.k)
lower_overflow = CarryOutRawLE(c_bits[:x.f - 1], r_bits[:x.f - 1])
higher_bits = sbitint.bit_adder(c_bits[x.f - 1:], r_bits[x.f - 1:],
lower_overflow)
sign = higher_bits[-1]
higher_bits.pop(-1)
aa = sign & ~util.tree_reduce(operator.and_, higher_bits)
bb = ~sign & ~util.tree_reduce(operator.and_, [~x for x in higher_bits])
a, b = (sint.conv(x) for x in (aa, bb))
def approx_sigmoid(x, n=3):
if n == 5:
cuts = [-5, -2.5, 2.5, 5]
le = [0] + [x <= cut for cut in cuts] + [1]
select = [le[i + 1] - le[i] for i in range(5)]
outputs = [cfix(10 ** -4),
0.02776 * x + 0.145,
0.17 * x + 0.5,
0.02776 * x + 0.85498,
cfix(1 - 10 ** -4)]
return sum(a * b for a, b in zip(select, outputs))
else:
a = x < -0.5
b = x > 0.5
return a.if_else(0, b.if_else(1, 0.5 + x))
approx_sigmoid.special = False
return a.if_else(0, b.if_else(1, 0.5 + x))
def lse_0_from_e_x(x, e_x):
return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)
@@ -144,7 +138,7 @@ class Output(Layer):
def eval(self, size, base=0):
if self.approx:
return approx_sigmoid(self.X.get_vector(base, size))
return approx_sigmoid(self.X.get_vector(base, size), self.approx)
else:
return sigmoid_from_e_x(self.X.get_vector(base, size),
self.e_x.get_vector(base, size))
@@ -531,7 +525,7 @@ class QuantConv2d(QuantConvBase):
_, inputs_h, inputs_w, n_channels_in = self.input_shape
return weights_h * weights_w * n_channels_in
def forward(self, batch):
def forward(self, batch=[None]):
assert len(batch) == 1
assert(self.weight_shape[0] == self.output_shape[-1])

View File

@@ -39,11 +39,11 @@ class Program(object):
These are created by executing a file containing appropriate instructions
and threads. """
def __init__(self, args, options, param=-1, assemblymode=False):
def __init__(self, args, options, param=-1):
self.options = options
self.verbose = options.verbose
self.args = args
self.init_names(args, assemblymode)
self.init_names(args)
self.P = P_VALUES[param]
self.param = param
if (param != -1) + sum(x != 0 for x in(options.ring, options.field,
@@ -65,8 +65,6 @@ class Program(object):
self.tape_counter = 0
self.tapes = []
self._curr_tape = None
self.EMULATE = True # defaults
self.FIRST_PASS = False
self.DEBUG = False
self.main_thread_running = False
self.allocated_mem = RegType.create_dict(lambda: USER_MEM)
@@ -102,9 +100,8 @@ class Program(object):
self.use_dabit = options.mixed
self._edabit = options.edabit
self._split = False
self._square = False
Program.prog = self
self.reset_values()
def get_args(self):
return self.args
@@ -131,7 +128,7 @@ class Program(object):
res = max(res, sum(running.values()))
return res
def init_names(self, args, assemblymode):
def init_names(self, args):
# ignore path to file - source must be in Programs/Source
if 'Programs' in os.listdir(os.getcwd()):
# compile prog in ./Programs/Source directory
@@ -153,8 +150,6 @@ class Program(object):
if os.path.exists(args[0]):
self.infile = args[0]
elif assemblymode:
self.infile = self.programs_dir + '/Source/' + progname + '.asm'
else:
self.infile = self.programs_dir + '/Source/' + progname + '.mpc'
"""
@@ -234,40 +229,6 @@ class Program(object):
else:
self.req_num += tape.req_num
def read_memory(self, filename):
""" Read the clear and shared memory from a file """
f = open(filename)
n = int(next(f))
self.mem_c = [0]*n
self.mem_s = [0]*n
mem = self.mem_c
done_c = False
for line in f:
line = line.split(' ')
a = int(line[0])
b = int(line[1])
if a != -1:
mem[a] = b
elif done_c:
break
else:
mem = self.mem_s
done_c = True
def get_memory(self, mem_type, i):
if mem_type == 'c':
return self.mem_c[i]
elif mem_type == 's':
return self.mem_s[i]
raise CompilerError('Invalid memory type')
def reset_values(self):
""" Reset register and memory values. """
for tape in self.tapes:
tape.reset_registers()
self.mem_c = list(range(USER_MEM + TMP_MEM))
self.mem_s = list(range(USER_MEM + TMP_MEM))
def write_bytes(self, outfile=None):
""" Write all non-empty threads and schedule to files. """
# runtime doesn't support 'new-style' parallelism yet
@@ -329,17 +290,6 @@ class Program(object):
tape.write_str(self.options.asmoutfile + '-' + tape.name)
tape.purge()
def emulate(self):
""" Emulate execution of entire program. """
self.reset_values()
for sch in self.schedule:
if sch[0] == 'start':
for tape in sch[1]:
self._curr_tape = tape
for block in tape.basicblocks:
for line in block.instructions:
line.execute()
def restart_main_thread(self):
if self.main_thread_running:
# wait for main thread to finish
@@ -375,6 +325,10 @@ class Program(object):
if size == 0:
return
if isinstance(mem_type, type):
try:
size *= math.ceil(mem_type.n / mem_type.unit)
except AttributeError:
pass
self.types[mem_type.reg_type] = mem_type
mem_type = mem_type.reg_type
elif reg_type is not None:
@@ -447,6 +401,12 @@ class Program(object):
assert change in (2, 3)
self._split = change
def use_square(self, change=None):
if change is None:
return self._square
else:
self._square = change
class Tape:
""" A tape contains a list of basic blocks, onto which instructions are added. """
def __init__(self, name, program):
@@ -588,7 +548,6 @@ class Tape:
#print 'Compiling basic block', sub.name
def init_registers(self):
self.reset_registers()
self.reg_counter = RegType.create_dict(lambda: 0)
def init_names(self, name):
@@ -605,7 +564,6 @@ class Tape:
for block in self.basicblocks:
block.purge()
self._is_empty = (len(self.basicblocks) == 0)
del self.reg_values
del self.basicblocks
del self.active_basicblock
self.purged = True
@@ -840,13 +798,6 @@ class Tape:
else:
return self.reg_counter[reg_type]
def reset_registers(self):
""" Reset register values to zero. """
self.reg_values = RegType.create_dict(lambda: [])
def get_value(self, reg_type, i):
return self.reg_values[reg_type][i]
def __str__(self):
return self.name
@@ -883,12 +834,28 @@ class Tape:
def cost(self):
return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \
if req[1] != 'input' and req[0] != 'edabit')
def pretty(self):
t = lambda x: 'integer' if x == 'modp' else x
res = []
for req, num in self.items():
domain = t(req[0])
n = '%12.0f' % num
if req[1] == 'input':
res += ['%s %s inputs from player %d' \
% (n, domain, req[2])]
elif domain.endswith('edabit'):
if domain == 'sedabit':
eda = 'strict edabits'
else:
eda = 'loose edabits'
res += ['%s %s of length %d' % (n, eda, req[1])]
elif req[0] != 'all':
res += ['%s %s %ss' % (n, domain, req[1])]
if self['all','round']:
res += ['% 12.0f virtual machine rounds' % self['all','round']]
return res
def __str__(self):
return ", ".join('%s inputs in %s from player %d' \
% (num, req[0], req[2]) \
if req[1] == 'input' \
else '%s %ss in %s' % (num, req[1], req[0]) \
for req,num in list(self.items()))
return ', '.join(self.pretty())
def __repr__(self):
return repr(dict(self))
@@ -959,14 +926,12 @@ class Tape:
"""
Class for creating new registers. The register's index is automatically assigned
based on the block's reg_counter dictionary.
The 'value' property is for emulation.
"""
__slots__ = ["reg_type", "program", "absolute_i", "relative_i", \
"size", "vector", "vectorbase", "caller", \
"can_eliminate"]
def __init__(self, reg_type, program, value=None, size=None, i=None):
def __init__(self, reg_type, program, size=None, i=None):
""" Creates a new register.
reg_type must be one of those defined in RegType. """
if Compiler.instructions_base.get_global_instruction_type() == 'gf2n':
@@ -989,8 +954,6 @@ class Tape:
else:
self.i = float('inf')
self.vector = []
if value is not None:
self.value = value
self.can_eliminate = True
if Program.prog.DEBUG:
self.caller = [frame[1:] for frame in inspect.stack()[1:]]
@@ -1010,22 +973,9 @@ class Tape:
def set_size(self, size):
if self.size == size:
return
elif not self.program.program.options.assemblymode:
else:
raise CompilerError('Mismatch of instruction and register size:'
' %s != %s' % (self.size, size))
elif self.size == 1 and self.vectorbase is self:
if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS:
# create vector register in assembly mode
self.size = size
self.vector = [self]
for i in range(1,size):
reg = compilerLib.VARS['%s%d' % (self.reg_type, self.i + i)]
reg.set_vectorbase(self)
self.vector.append(reg)
else:
raise CompilerError('Cannot find %s in VARS' % str(self))
else:
raise CompilerError('Cannot reset size of vector register')
def set_vectorbase(self, vectorbase):
if self.vectorbase is not self:
@@ -1074,16 +1024,6 @@ class Tape:
def copy(self):
return Tape.Register(self.reg_type, Program.prog.curr_tape)
@property
def value(self):
return self.program.reg_values[self.reg_type][self.i]
@value.setter
def value(self, val):
while (len(self.program.reg_values[self.reg_type]) <= self.i):
self.program.reg_values[self.reg_type] += [0] * INIT_REG_MAX
self.program.reg_values[self.reg_type][self.i] = val
@property
def is_gf2n(self):
return self.reg_type == RegType.ClearGF2N or \

View File

@@ -232,7 +232,7 @@ def inputmixed(*args):
if isinstance(args[-1], int):
instructions.inputmixed(*args)
else:
instructions.inputmixedreg(*args)
instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),)))
class _number(object):
""" Number functionality. """
@@ -762,7 +762,11 @@ class cint(_clear, _int):
def load_int(self, val):
if val:
# +1 for sign
program.curr_tape.require_bit_length(1 + int(math.ceil(math.log(abs(val)))))
bit_length = 1 + int(math.ceil(math.log(abs(val))))
if program.options.ring:
assert(bit_length <= int(program.options.ring))
elif program.param != -1 or program.options.field:
program.curr_tape.require_bit_length(bit_length)
if self.in_immediate_range(val):
ldi(self, val)
else:
@@ -783,7 +787,7 @@ class cint(_clear, _int):
sum += sign * chunk
@vectorize
def to_regint(self, n_bits=None, dest=None):
def to_regint(self, n_bits=64, dest=None):
""" Convert to regint.
:param n_bits: bit length (int)
@@ -1146,23 +1150,6 @@ class regint(_register, _int):
else:
return res
@vectorized_classmethod
def read_client_public_key(cls, client_id):
""" Receive 8 register values from socket containing client public key."""
res = [cls() for i in range(8)]
readclientpublickey(client_id, *res)
return res
@vectorized_classmethod
def init_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8):
""" Use 8 register values containing client public key."""
initsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8)
@vectorized_classmethod
def resp_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8):
""" Receive 8 register values from socket containing client public key."""
respsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8)
@vectorize
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
writesocketint(client_id, message_type, self)
@@ -1439,7 +1426,7 @@ class _secret(_register):
def get_input_from(cls, player):
""" Secret input from player.
:param: player (compile-time int) """
:param player: public (regint/cint/int) """
res = cls()
asm_input(res, player)
return res
@@ -1648,9 +1635,12 @@ class _secret(_register):
@vectorize
def square(self):
""" Secret square. """
res = self.__class__()
sqrs(res, self)
return res
if program.use_square():
res = self.__class__()
sqrs(res, self)
return res
else:
return self * self
@set_instruction_type
@vectorize
@@ -1712,7 +1702,7 @@ class sint(_secret, _int):
def get_input_from(cls, player):
""" Secret input.
:param player: compile-time integer (int) """
:param player: public (regint/cint/int) """
res = cls()
inputmixed('int', res, player)
return res
@@ -1757,8 +1747,7 @@ class sint(_secret, _int):
@classmethod
def get_raw_input_from(cls, player):
res = cls()
startinput(player, 1)
stopinput(player, res)
rawinput(player, res)
return res
@classmethod
@@ -2056,8 +2045,7 @@ class sgf2n(_secret, _gf2n):
@classmethod
def get_raw_input_from(cls, player):
res = cls()
gstartinput(player, 1)
gstopinput(player, res)
grawinput(player, res)
return res
def add(self, other):
@@ -3293,7 +3281,7 @@ class sfix(_fix):
def get_input_from(cls, player):
""" Secret fixed-point input.
:param player: int """
:param player: public (regint/cint/int) """
v = cls.int_type()
inputmixed('fix', v, cls.f, player)
return cls._new(v)
@@ -3674,7 +3662,7 @@ class sfloat(_number, _structure):
def get_input_from(cls, player):
""" Secret floating-point input.
:param player: int """
:param player: public (regint/cint/int) """
v = sint()
p = sint()
z = sint()
@@ -4195,7 +4183,7 @@ class Array(object):
def input_from(self, player, budget=None):
""" Fill with inputs from player if supported by type.
:param player: compile-time (int) """
:param player: public (regint/cint/int) """
self.assign(self.value_type.get_input_from(player, size=len(self)))
def __add__(self, other):
@@ -4351,7 +4339,7 @@ class SubMultiArray(object):
def input_from(self, player, budget=None):
""" Fill with inputs from player if supported by type.
:param player: compile-time (int) """
:param player: public (regint/cint/int) """
@library.for_range_opt(self.sizes[0], budget=budget)
def _(i):
self[i].input_from(player, budget=budget)
@@ -4597,6 +4585,12 @@ class Matrix(MultiArray):
MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \
address=address)
def set_column(self, index, vector):
assert self.value_type.n_elements() == 1
addresses = regint.inc(self.sizes[0], self.address + index,
self.sizes[1])
vector.store_in_mem(addresses)
class VectorArray(object):
def __init__(self, length, value_type, vector_size, address=None):
self.array = Array(length * vector_size, value_type, address)

View File

@@ -182,6 +182,12 @@ def expand(x, size):
except AttributeError:
return x
def mem_size(x):
try:
return x.mem_size()
except AttributeError:
return 1
class set_by_id(object):
def __init__(self, init=[]):
self.content = {}

View File

@@ -1,4 +1,32 @@
The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
The ExternalIO directory contains an example of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
## Working Examples
[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a
client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc)
and demonstrates sending input and receiving output as described by
[Damgård et al.](https://eprint.iacr.org/2015/1006) The computation
allows up to eight clients to input a number and computes the client
with the largest input. You can run it as follows from the main
directory:
```
make bankers-bonus-client.x
./compile.py bankers_bonus 1
Scripts/setup-ssl.sh <nparties>
Scripts/setup-clients.sh 3
Scripts/<protocol>.sh &
./bankers-bonus-client.x 0 <nparties 100 0 &
./bankers-bonus-client.x 1 <nparties> 200 0 &
./bankers-bonus-client.x 2 <nparties> 50 1
```
This should output that the winning id is 1. Note that the ids have to
be incremental, and the client with the highest id has to input 1 as
the last argument while the others have to input 0 there. Furthermore,
`<nparties>` refers to the number of parties running the computation
not the number of clients, and `<protocol>` can be the name of
protocol script. The setup scripts generate the necessary SSL
certificates and keys. Therefore, if you run the computation on
different hosts, you will have to distribute the `*.pem` files.
## I/O MPC Instructions
@@ -55,49 +83,3 @@ Receive shares of private inputs from a client, blocking on client send. This is
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
*[inputs]* - returned list of shares of private input.
## Securing communications
Two cryptographic protocols have been implemented for use in particular applications and are included here for completeness:
1. Communication security using a Station to Station key agreement and libsodium Secret Box using a nonce counter for message ordering.
2. Authenticated Diffie-Hellman without message ordering.
Please note these are **NOT** required to allow external client I/O. Your mileage may vary, for example in a web setting TLS may be sufficient to secure communications between processes.
[client-setup.cpp](../client-setup.cpp) is a utility which is run to generate the key material for both the external clients and SPDZ parties for both protocols.
#### MPC instructions
**regint.init_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*)
STS protocol initiator. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec.
*client_socket_id* - an identifier used to refer to the client socket.
*public_signing_key* - client public key supplied as list of 8 32-bit ints.
**regint.resp_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*)
STS protocol responder. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec.
*client_socket_id* - an identifier used to refer to the client socket.
*public_signing_key* - client public key supplied as list of 8 32-bit ints.
*[regint public_key]* **regint.read_client_public_key**(*regint client_socket_id*)
Instruction to read the client public key and run setup for the authenticated Diffie-Hellman encryption. All subsequent write_socket instructions are encrypted. Only the sint.read_from_socket instruction is encrypted.
*client_socket_id* - an identifier used to refer to the client socket.
*public_key* - client public key made available to mpc programs as list of 8 32-bit ints.
## Working Examples
See [bankers-bonus-client.cpp](./bankers-bonus-client.cpp) which acts as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output with no communications security.
See [bankers-bonus-commsec-client.cpp](./bankers-bonus-commsec-client.cpp) which acts as a client to [bankers_bonus_commsec.mpc](../Programs/Source/bankers_bonus_commsec.mpc) which runs the same algorithm but includes both the available crypto protocols.
More instructions on how to run these are provided in the *-client files.

View File

@@ -16,11 +16,10 @@
* - share of random value [r]
* - share of winning unique id * random value [w]
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
*
* No communications security is used.
*
* To run with 2 parties / SPDZ engines:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./Scripts/setup-clients.sh to create SSL keys and certificates for clients
* ./compile.py bankers_bonus
* ./Scripts/run-online.sh bankers_bonus to run the engines.
*
@@ -34,6 +33,7 @@
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Networking/sockets.h"
#include "Networking/ssl_sockets.h"
#include "Tools/int.h"
#include "Math/Setup.h"
#include "Protocols/fake-stuff.h"
@@ -46,12 +46,13 @@
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties)
template<class T>
void send_private_inputs(const vector<T>& values, vector<ssl_socket*>& sockets, int nparties)
{
int num_inputs = values.size();
octetStream os;
vector< vector<gfp> > triples(num_inputs, vector<gfp>(3));
vector<gfp> triple_shares(3);
vector< vector<T> > triples(num_inputs, vector<T>(3));
vector<T> triple_shares(3);
// Receive num_inputs triples from SPDZ
for (int j = 0; j < nparties; j++)
@@ -59,6 +60,10 @@ void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int np
os.reset_write_head();
os.Receive(sockets[j]);
#ifdef VERBOSE_COMM
cerr << "received " << os.get_length() << " from " << j << endl;
#endif
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
@@ -72,49 +77,30 @@ void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int np
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
if (triples[i][0] * triples[i][1] != triples[i][2])
if (T(triples[i][0] * triples[i][1]) != triples[i][2])
{
cerr << triples[i][2] << " != " << triples[i][0] << " * " << triples[i][1] << endl;
cerr << "Incorrect triple at " << i << ", aborting\n";
exit(1);
throw mac_fail();
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
{
gfp y = values[i] + triples[i][0];
T y = values[i] + triples[i][0];
y.pack(os);
}
for (int j = 0; j < nparties; j++)
os.Send(sockets[j]);
}
// Assumes that Scripts/setup-online.sh has been run to compute prime
void initialise_fields(const string& dir_prefix)
{
int lg2;
bigint p;
string filename = dir_prefix + "Params-Data";
cout << "loading params from: " << filename << endl;
ifstream inpf(filename.c_str());
if (inpf.fail()) { throw file_error(filename.c_str()); }
inpf >> p;
inpf >> lg2;
inpf.close();
gfp::init_field(p);
gf2n::init_field(lg2);
}
// Receive shares of the result and sum together.
// Also receive authenticating values.
gfp receive_result(vector<int>& sockets, int nparties)
template<class T>
T receive_result(vector<ssl_socket*>& sockets, int nparties)
{
vector<gfp> output_values(3);
vector<T> output_values(3);
octetStream os;
for (int i = 0; i < nparties; i++)
{
@@ -122,20 +108,32 @@ gfp receive_result(vector<int>& sockets, int nparties)
os.Receive(sockets[i]);
for (unsigned int j = 0; j < 3; j++)
{
gfp value;
T value;
value.unpack(os);
output_values[j] += value;
}
}
if (output_values[0] * output_values[1] != output_values[2])
if (T(output_values[0] * output_values[1]) != output_values[2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
exit(1);
throw mac_fail();
}
return output_values[0];
}
template<class T>
void run(int salary_value, vector<ssl_socket*>& sockets, int nparties)
{
// Run the computation
send_private_inputs<T>({salary_value}, sockets, nparties);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back (client_id of winning client)
T result = receive_result<T>(sockets, nparties);
cout << "Winning client id is : " << result << endl;
}
int main(int argc, char** argv)
{
@@ -162,34 +160,65 @@ int main(int argc, char** argv)
if (argc > 6)
port_base = atoi(argv[6]);
// init static gfp
string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree());
initialise_fields(prep_data_prefix);
bigint::init_thread();
// Setup connections from this client to each party socket
vector<int> sockets(nparties);
vector<int> plain_sockets(nparties);
vector<ssl_socket*> sockets(nparties);
ssl_ctx ctx("C" + to_string(my_client_id));
ssl_service io_service;
octetStream specification;
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send(sockets[i], (octet*) &my_client_id, sizeof(int));
set_up_client_socket(plain_sockets[i], host_name.c_str(), port_base + i);
send(plain_sockets[i], (octet*) &my_client_id, sizeof(int));
sockets[i] = new ssl_socket(io_service, ctx, plain_sockets[i],
"P" + to_string(i), "C" + to_string(my_client_id), true);
if (i == 0)
specification.Receive(sockets[0]);
octetStream os;
os.store(finish);
os.Send(sockets[i]);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;
// Run the commputation
send_private_inputs({salary_value}, sockets, nparties);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
int type = specification.get<int>();
switch (type)
{
case 'p':
{
gfp::init_field(specification.get<bigint>());
cerr << "using prime " << gfp::pr() << endl;
run<gfp>(salary_value, sockets, nparties);
break;
}
case 'R':
{
int R = specification.get<int>();
switch (R)
{
case 64:
run<Z2<64>>(salary_value, sockets, nparties);
break;
case 104:
run<Z2<104>>(salary_value, sockets, nparties);
break;
case 128:
run<Z2<128>>(salary_value, sockets, nparties);
break;
default:
cerr << R << "-bit ring not implemented";
exit(1);
}
break;
}
default:
cerr << "Type " << type << " not implemented";
exit(1);
}
// Get the result back (client_id of winning client)
gfp result = receive_result(sockets, nparties);
cout << "Winning client id is : " << result << endl;
for (unsigned int i = 0; i < sockets.size(); i++)
close_client_socket(sockets[i]);
for (int i = 0; i < nparties; i++)
delete sockets[i];
return 0;
}

View File

@@ -1,407 +0,0 @@
/*
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.
* Uses SPDZ implemented encryption for external client communication, see bankers-bonus-client.cpp
* for a simpler client with no crypto.
*
* Provides a client to bankers_bonus_commsec.mpc program to calculate which banker pays for lunch based on
* the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running
* the bankers_bonus.mpc program.
*
* Each connecting client:
* - sends an increasing id to identify the client, starting with 0
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
* - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security.
* - sends an integer input (bonus value to compare)
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
* - share of random value [r]
* - share of winning unique id * random value [w]
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
*
* To run with 2 parties (SPDZ engines) and 3 external clients:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients.
* ./compile.py bankers_bonus_commsec
* ./Scripts/run-online.sh bankers_bonus_commsec to run the engines.
*
* ./bankers-bonus-commsec-client.x 0 2 100 0
* ./bankers-bonus-commsec-client.x 1 2 200 0
* ./bankers-bonus-commsec-client.x 2 2 50 1
*
* Expect winner to be second client with id 1.
* Note here client id must match id used in generating client key material, Client-Keys-C<client id>.
*/
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Networking/sockets.h"
#include "Networking/STS.h"
#include "Tools/int.h"
#include "Math/Setup.h"
#include "Protocols/fake-stuff.h"
#include <sodium.h>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <fstream>
typedef pair< vector<octet>, vector<octet> > keypair_t; // A pair of send/recv keys for talking to SPDZ
typedef vector< keypair_t > commsec_t; // A database of send/recv keys indexed by server
typedef struct {
unsigned char client_secretkey[crypto_sign_SECRETKEYBYTES];
unsigned char client_publickey[crypto_sign_PUBLICKEYBYTES];
vector<int> client_publickey_ints;
vector< vector<unsigned char> >server_publickey;
} sign_key_container_t;
keypair_t sts_response_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
octetStream os;
os.Receive(sockets[server_id]);
os.consume(m1.bytes, sizeof m1.bytes);
m2 = ke.recv_msg1(m1);
os.reset_write_head();
os.append(m2.pubkey, sizeof m2.pubkey);
os.append(m2.sig, sizeof m2.sig);
os.Send(sockets[server_id]);
os.Receive(sockets[server_id]);
os.consume(m3.bytes, sizeof m3.bytes);
ke.recv_msg3(m3);
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
return make_pair(sendKey,recvKey);
}
keypair_t sts_initiator_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
octetStream os;
m1 = ke.send_msg1();
cout << "m1: ";
for (unsigned int j = 0; j < 32; j++)
cout << setfill('0') << setw(2) << hex << (int) m1.bytes[j];
cout << dec << endl;
os.reset_write_head();
os.append(m1.bytes, sizeof m1.bytes);
os.Send(sockets[server_id]);
os.reset_write_head();
os.Receive(sockets[server_id]);
os.consume(m2.pubkey, sizeof m2.pubkey);
os.consume(m2.sig, sizeof m2.sig);
m3 = ke.recv_msg2(m2);
os.reset_write_head();
os.append(m3.bytes, sizeof m3.bytes);
os.Send(sockets[server_id]);
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
return make_pair(sendKey,recvKey);
}
pair< vector<octet>, vector<octet> > sts_response_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
pair< vector<octet>, vector<octet> > res;
try {
res = sts_response_role_exceptions(keys, sockets, server_id);
} catch(char const *e) {
cerr << "Error in STS: " << e << endl;
exit(1);
}
return res;
}
pair< vector<octet>, vector<octet> > sts_initiator_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
pair< vector<octet>, vector<octet> > res;
try {
res = sts_initiator_role_exceptions(keys, sockets, server_id);
} catch(char const *e) {
cerr << "Error in STS: " << e << endl;
exit(1);
}
return res;
}
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(const vector<gfp>& values, vector<int>& sockets, int nparties,
commsec_t commsec, vector<octet*>& keys)
{
int num_inputs = values.size();
octetStream os;
vector< vector<gfp> > triples(num_inputs, vector<gfp>(3));
vector<gfp> triple_shares(3);
// Receive num_inputs triples from SPDZ
for (int j = 0; j < nparties; j++)
{
os.reset_write_head();
os.Receive(sockets[j]);
os.decrypt_sequence(&commsec[j].second[0],0);
os.decrypt(keys[j]);
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
{
triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k];
}
}
}
// Check triple relations
for (int i = 0; i < num_inputs; i++)
{
if (triples[i][0] * triples[i][1] != triples[i][2])
{
cerr << "Incorrect triple at " << i << ", aborting\n";
exit(1);
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
{
gfp y = values[i] + triples[i][0];
y.pack(os);
}
for (int j = 0; j < nparties; j++)
{
octetStream temp = os;
temp.encrypt_sequence(&commsec[j].first[0], 0);
temp.Send(sockets[j]);
}
}
// Send public key in clear to each SPDZ engine.
void send_public_key(vector<int>& pubkey, int socket)
{
octetStream os;
os.reset_write_head();
for (unsigned int i = 0; i < pubkey.size(); i++)
{
os.store(pubkey[i]);
}
os.Send(socket);
}
// Assumes that Scripts/setup-online.sh has been run to compute prime
void initialise_fields(const string& dir_prefix)
{
int lg2;
bigint p;
string filename = dir_prefix + "Params-Data";
cout << "loading params from: " << filename << endl;
ifstream inpf(filename.c_str());
if (inpf.fail()) { throw file_error(filename.c_str()); }
inpf >> p;
inpf >> lg2;
inpf.close();
gfp::init_field(p);
gf2n::init_field(lg2);
}
// Assumes that client-setup has been run to create key pairs for clients and parties
void generate_symmetric_keys(vector<octet*>& keys, vector<int>& client_public_key_ints,
sign_key_container_t *sts_key, const string& dir_prefix, int client_no)
{
unsigned char client_publickey[crypto_box_PUBLICKEYBYTES];
unsigned char client_secretkey[crypto_box_SECRETKEYBYTES];
unsigned char server_publickey[crypto_box_PUBLICKEYBYTES];
unsigned char scalarmult_q[crypto_scalarmult_BYTES];
crypto_generichash_state h;
// read client public/secret keys + SPDZ server public keys
ifstream keyfile;
stringstream client_filename;
client_filename << dir_prefix << "Client-Keys-C" << client_no;
keyfile.open(client_filename.str().c_str());
if (keyfile.fail())
throw file_error(client_filename.str());
keyfile.read((char*)client_publickey, sizeof client_publickey);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "client public key" );
// Convert client public key unsigned char to int, reverse endianness
for(unsigned int j = 0; j < client_public_key_ints.size(); j++) {
int keybyte = 0;
for(unsigned int k = 0; k < 4; k++) {
keybyte = keybyte + (((int)client_publickey[j*4+k]) << ((3-k) * 8));
}
client_public_key_ints[j] = keybyte;
}
keyfile.read((char*)client_secretkey, sizeof client_secretkey);
if (keyfile.eof()) {
throw end_of_file(client_filename.str(), "client private key" );
}
keyfile.read((char*)sts_key->client_publickey, crypto_sign_PUBLICKEYBYTES);
keyfile.read((char*)sts_key->client_secretkey, crypto_sign_SECRETKEYBYTES);
// Convert client public key unsigned char to int, reverse endianness
sts_key->client_publickey_ints.resize(8);
for(unsigned int j = 0; j < sts_key->client_publickey_ints.size(); j++) {
int keybyte = 0;
for(unsigned int k = 0; k < 4; k++) {
keybyte = keybyte + (((int)sts_key->client_publickey[j*4+k]) << ((3-k) * 8));
}
sts_key->client_publickey_ints[j] = keybyte;
}
for (unsigned int i = 0; i < keys.size(); i++)
{
keys[i] = new octet[crypto_generichash_BYTES];
keyfile.read((char*)server_publickey, crypto_box_PUBLICKEYBYTES);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "server public key for party " + to_string(i));
keyfile.read((char*)(&sts_key->server_publickey[i][0]), crypto_sign_PUBLICKEYBYTES);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "server public signing key for party " + to_string(i));
// Derive a shared key from this server's secret key and the client's public key
// shared key = h(q || client_secretkey || server_publickey)
if (crypto_scalarmult(scalarmult_q, client_secretkey, server_publickey) != 0) {
cerr << "Scalar mult failed\n";
exit(1);
}
crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES);
crypto_generichash_update(&h, scalarmult_q, sizeof scalarmult_q);
crypto_generichash_update(&h, client_publickey, sizeof client_publickey);
crypto_generichash_update(&h, server_publickey, sizeof server_publickey);
crypto_generichash_final(&h, keys[i], crypto_generichash_BYTES);
}
keyfile.close();
cout << "My public key is: ";
for (unsigned int j = 0; j < 32; j++)
cout << setfill('0') << setw(2) << hex << (int) client_publickey[j];
cout << dec << endl;
}
// Receive shares of the result and sum together.
// Also receive authenticating values.
gfp receive_result(vector<int>& sockets, int nparties, commsec_t commsec, vector<octet*>& keys)
{
vector<gfp> output_values(3);
octetStream os;
for (int i = 0; i < nparties; i++)
{
os.reset_write_head();
os.Receive(sockets[i]);
os.decrypt_sequence(&commsec[i].second[0],1);
os.decrypt(keys[i]);
for (unsigned int j = 0; j < 3; j++)
{
gfp value;
value.unpack(os);
output_values[j] += value;
}
}
if (output_values[0] * output_values[1] != output_values[2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
exit(1);
}
return output_values[0];
}
int main(int argc, char** argv)
{
int my_client_id;
int nparties;
int salary_value;
int finish;
int port_base = 14000;
sign_key_container_t sts_key;
string host_name = "localhost";
if (argc < 5) {
cout << "Usage is external-client <client identifier> <number of spdz parties> "
<< "<salary to compare> <finish (0 false, 1 true)> <optional host name, default localhost> "
<< "<optional spdz party port base number, default 14000>" << endl;
exit(0);
}
my_client_id = atoi(argv[1]);
nparties = atoi(argv[2]);
salary_value = atoi(argv[3]);
finish = atoi(argv[4]);
if (argc > 5)
host_name = argv[5];
if (argc > 6)
port_base = atoi(argv[6]);
sts_key.server_publickey.resize(nparties);
for(int i = 0 ; i < nparties; i++) {
sts_key.server_publickey[i].resize(crypto_sign_PUBLICKEYBYTES);
}
// init static gfp
string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree());
initialise_fields(prep_data_prefix);
bigint::init_thread();
// Generate session keys to decrypt data sent from each spdz engine (party)
vector<octet*> session_keys(nparties);
vector<int> client_public_key_ints(8);
generate_symmetric_keys(session_keys, client_public_key_ints, &sts_key, prep_data_prefix, my_client_id);
// Setup connections from this client to each party socket and send the client public keys
vector<int> sockets(nparties);
// vector< pair <vector<octet>, vector <octet> > > commseckey(nparties);
commsec_t commseckey(nparties);
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send(sockets[i], (octet*) &my_client_id, sizeof(int));
octetStream os;
os.store(finish);
os.Send(sockets[i]);
send_public_key(sts_key.client_publickey_ints, sockets[i]);
send_public_key(client_public_key_ints, sockets[i]);
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;
// Send the inputs to the SPDZ Engines
send_private_inputs({salary_value}, sockets, nparties, commseckey, session_keys);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back
gfp result = receive_result(sockets, nparties, commseckey, session_keys);
cout << "Winning client id is : " << result << endl;
for (unsigned int i = 0; i < sockets.size(); i++)
close_client_socket(sockets[i]);
return 0;
}

View File

@@ -180,6 +180,13 @@ void FFT_Data::pack(octetStream& o) const
{
R.pack(o);
prData.pack(o);
o.store(root);
o.store(twop);
o.store(two_root);
o.store(b);
iphi.pack(o);
o.store(powers);
o.store(powers_i);
}
@@ -187,7 +194,13 @@ void FFT_Data::unpack(octetStream& o)
{
R.unpack(o);
prData.unpack(o);
init(R, prData);
o.get(root);
o.get(twop);
o.get(two_root);
o.get(b);
iphi.unpack(o);
o.get(powers);
o.get(powers_i);
}

View File

@@ -345,19 +345,38 @@ void init(Ring& Rg,int m)
Rg.pi.resize(Rg.phim); Rg.pi_inv.resize(Rg.mm);
for (int i=0; i<Rg.mm; i++) { Rg.pi_inv[i]=-1; }
int k=0;
for (int i=1; i<Rg.mm; i++)
{ if (gcd(i,Rg.mm)==1)
{ Rg.pi[k]=i;
Rg.pi_inv[i]=k;
k++;
if (((m - 1) & m) == 0)
{
// m is power of two
// no need to generate poly
int k = 0;
for (int i = 1; i < Rg.mm; i++)
{
// easy GCD
if (i % 2 == 1)
{
Rg.pi[k] = i;
Rg.pi_inv[i] = k;
k++;
}
}
}
else
{
int k=0;
for (int i=1; i<Rg.mm; i++)
{ if (gcd(i,Rg.mm)==1)
{ Rg.pi[k]=i;
Rg.pi_inv[i]=k;
k++;
}
}
ZZX P=Cyclotomic(Rg.mm);
Rg.poly.resize(Rg.phim+1);
for (int i=0; i<Rg.phim+1; i++)
{ Rg.poly[i]=to_int(coeff(P,i)); }
ZZX P=Cyclotomic(Rg.mm);
Rg.poly.resize(Rg.phim+1);
for (int i=0; i<Rg.phim+1; i++)
{ Rg.poly[i]=to_int(coeff(P,i)); }
}
}

View File

@@ -16,7 +16,9 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
{
if (sigma <= 0)
this->sigma = sigma = FHE_Params().get_R();
#ifdef VERBOSE
cerr << "Standard deviation: " << this->sigma << endl;
#endif
h += extra_h * sec;
produce_epsilon_constants();

View File

@@ -29,19 +29,27 @@ istream& operator>>(istream& s,Ring& R)
void Ring::pack(octetStream& o) const
{
o.store(mm);
o.store(phim);
o.store(pi);
o.store(pi_inv);
o.store(poly);
if (((mm - 1) & mm) != 0)
{
o.store(phim);
o.store(pi);
o.store(pi_inv);
o.store(poly);
}
}
void Ring::unpack(octetStream& o)
{
o.get(mm);
o.get(phim);
o.get(pi);
o.get(pi_inv);
o.get(poly);
if (((mm - 1) & mm) != 0)
{
o.get(phim);
o.get(pi);
o.get(pi_inv);
o.get(poly);
}
else
init(*this, mm);
}
bool Ring::operator !=(const Ring& other) const

View File

@@ -7,6 +7,7 @@
#include <vector>
#include <iostream>
#include <assert.h>
using namespace std;
#include "Tools/octetStream.h"
@@ -31,7 +32,7 @@ class Ring
int p(int i) const { return pi.at(i); }
int p_inv(int i) const { return pi_inv.at(i); }
const vector<int>& Phi() const { return poly; }
const vector<int>& Phi() const { assert(poly.size()); return poly; }
friend ostream& operator<<(ostream& s,const Ring& R);
friend istream& operator>>(istream& s,Ring& R);

View File

@@ -96,6 +96,8 @@ public:
res += x.from == from;
return res;
}
int n_interactive_inputs_from_me(int my_num);
};
#endif /* GC_ARGTUPLES_H_ */

View File

@@ -76,14 +76,7 @@ unsigned Instruction::get_mem(RegType reg_type) const
inline
void Instruction::parse(istream& s, int pos)
{
n = 0;
start.resize(0);
::memset(r, 0, sizeof(r));
int file_pos = s.tellg();
opcode = ::get_int(s);
parse_operands(s, pos, file_pos);
BaseInstruction::parse(s, pos);
switch(opcode)
{

View File

@@ -42,6 +42,7 @@ MAYBE_INLINE bool Instruction::execute(Processor<T>& processor,
cout << endl;
#endif
const Instruction& instruction = *this;
auto& Ci = processor.I;
switch (opcode)
{
#define X(NAME, CODE) case NAME: CODE; return true;

View File

@@ -65,6 +65,8 @@ public:
void reset(const U& program);
long long get_input(const int* params, bool interactive = false);
bigint get_long_input(const int* params, ProcessorBase& input_proc,
bool interactive = false);
void bitcoms(T& x, const vector<int>& regs) { x.bitcom(S, regs); }
void bitdecs(const vector<int>& regs, const T& x) { x.bitdec(S, regs); }
@@ -84,6 +86,10 @@ public:
template<class U>
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
template<class U>
void mem_op(int n, Memory<U>& dest, const Memory<U>& source,
Integer dest_address, Integer source_address);
void xors(const vector<int>& args);
void andm(const ::BaseInstruction& instruction);
void and_(const vector<int>& args, bool repeat);
@@ -95,9 +101,9 @@ public:
void reveal(const ::BaseInstruction& instruction);
void print_reg(int reg, int n);
void print_reg(int reg, int n, int size);
void print_reg_plain(Clear& value);
void print_reg_signed(unsigned n_bits, Clear& value);
void print_reg_signed(unsigned n_bits, Integer value);
void print_chr(int n);
void print_str(int n);
void print_float(const vector<int>& args);

View File

@@ -67,11 +67,19 @@ void Processor<T>::reset(const U& program)
template<class T>
inline long long GC::Processor<T>::get_input(const int* params, bool interactive)
{
bigint res = ProcessorBase::get_input<FixInput>(interactive, &params[1]).items[0];
assert(params[0] <= 64);
return get_long_input(params, *this, interactive).get_si();
}
template<class T>
bigint GC::Processor<T>::get_long_input(const int* params,
ProcessorBase& input_proc, bool interactive)
{
bigint res = input_proc.get_input<FixInput_<bigint>>(interactive,
&params[1]).items[0];
int n_bits = *params;
check_input(res, n_bits);
assert(n_bits <= 64);
return res.get_si();
return res;
}
template<class T>
@@ -171,6 +179,17 @@ void GC::Processor<T>::store_clear_in_dynamic(const vector<int>& args,
T::store_clear_in_dynamic(dynamic_memory, accesses);
}
template<class T>
template<class U>
void Processor<T>::mem_op(int n, Memory<U>& dest, const Memory<U>& source,
Integer dest_address, Integer source_address)
{
for (int i = 0; i < n; i++)
{
dest[dest_address + i] = source[source_address + i];
}
}
template <class T>
void Processor<T>::xors(const vector<int>& args)
{
@@ -234,12 +253,15 @@ void Processor<T>::reveal(const vector<int>& args)
}
template <class T>
void Processor<T>::print_reg(int reg, int n)
void Processor<T>::print_reg(int reg, int n, int size)
{
#ifdef DEBUG_VALUES
cout << "print_reg " << typeid(T).name() << " " << reg << " " << &C[reg] << endl;
#endif
T::out << "Reg[" << reg << "] = " << hex << showbase << C[reg] << dec << " # ";
bigint output;
for (int i = 0; i < size; i++)
output += bigint((unsigned long)C[reg + i].get()) << (T::default_length * i);
T::out << "Reg[" << reg << "] = " << hex << showbase << output << dec << " # ";
print_str(n);
T::out << endl << flush;
}
@@ -251,14 +273,29 @@ void Processor<T>::print_reg_plain(Clear& value)
}
template <class T>
void Processor<T>::print_reg_signed(unsigned n_bits, Clear& value)
void Processor<T>::print_reg_signed(unsigned n_bits, Integer reg)
{
unsigned n_shift = 0;
if (n_bits > 1)
n_shift = sizeof(value.get()) * 8 - n_bits;
if (n_shift > 63)
n_shift = 0;
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
if (n_bits <= Clear::N_BITS)
{
auto value = C[reg];
unsigned n_shift = 0;
if (n_bits > 1)
n_shift = sizeof(value.get()) * 8 - n_bits;
if (n_shift > 63)
n_shift = 0;
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
}
else
{
bigint tmp = 0;
for (int i = 0; i < DIV_CEIL(n_bits, Clear::N_BITS); i++)
{
tmp += bigint((unsigned long)C[reg + i].get()) << (i * Clear::N_BITS);
}
if (tmp >= bigint(1) << (n_bits - 1))
tmp -= bigint(1) << n_bits;
T::out << dec << tmp << flush;
}
}
template <class T>

View File

@@ -21,6 +21,8 @@
#include <fstream>
class ProcessorBase;
namespace GC
{
@@ -116,6 +118,10 @@ public:
static void inputb(Processor<U>& processor, const vector<int>& args)
{ T::inputb(processor, args); }
template<class U>
static void inputb(Processor<U>& processor, ProcessorBase& input_proc,
const vector<int>& args)
{ T::inputb(processor, input_proc, args); }
template<class U>
static void reveal_inst(Processor<U>& processor, const vector<int>& args)
{ processor.reveal(args); }

View File

@@ -290,12 +290,14 @@ void Secret<T>::trans(Processor<U>& processor, int n_outputs,
const vector<int>& args)
{
int n_inputs = args.size() - n_outputs;
int dl = U::default_length;
for (int i = 0; i < n_outputs; i++)
{
processor.S[args[i]].resize_regs(n_inputs);
for (int j = 0; j < DIV_CEIL(n_inputs, dl); j++)
processor.S[args[i] + j].resize_regs(min(dl, n_inputs - j * dl));
for (int j = 0; j < n_inputs; j++)
processor.S[args[i]].registers[j] =
processor.S[args[n_outputs + j]].registers[i];
processor.S[args[i] + j / dl].registers[j % dl] =
processor.S[args[n_outputs + j] + i / dl].registers[i % dl];
}
}

View File

@@ -25,12 +25,25 @@ SemiSecret::MC* SemiSecret::new_mc(mac_key_type)
void SemiSecret::trans(Processor<SemiSecret>& processor, int n_outputs,
const vector<int>& args)
{
square64 square;
for (size_t i = n_outputs; i < args.size(); i++)
square.rows[i - n_outputs] = processor.S[args[i]].get();
square.transpose(args.size() - n_outputs, n_outputs);
for (int i = 0; i < n_outputs; i++)
processor.S[args[i]] = square.rows[i];
int N_BITS = default_length;
for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++)
for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++)
{
square64 square;
size_t input_base = n_outputs + l * N_BITS;
for (size_t i = input_base;
i < min(input_base + N_BITS, args.size()); i++)
square.rows[i - input_base] = processor.S[args[i] + j].get();
square.transpose(
min(size_t(N_BITS), args.size() - n_outputs - l * N_BITS),
min(N_BITS, n_outputs - j * N_BITS));
int output_base = j * N_BITS;
for (int i = output_base; i < min(n_outputs, output_base + N_BITS);
i++)
{
processor.S[args[i] + l] = square.rows[i - output_base];
}
}
}
void SemiSecret::load_clear(int n, const Integer& x)

View File

@@ -20,6 +20,7 @@ using namespace std;
#include "Protocols/Replicated.h"
#include "Protocols/ReplicatedMC.h"
#include "Processor/DummyProtocol.h"
#include "Processor/ProcessorBase.h"
namespace GC
{
@@ -58,7 +59,10 @@ public:
{ and_(processor, args, false); }
static void and_(Processor<U>& processor, const vector<int>& args, bool repeat);
static void xors(Processor<U>& processor, const vector<int>& args);
static void inputb(Processor<U>& processor, const vector<int>& args);
static void inputb(Processor<U>& processor, const vector<int>& args)
{ inputb(processor, processor, args); }
static void inputb(Processor<U>& processor, ProcessorBase& input_processor,
const vector<int>& args);
static void reveal_inst(Processor<U>& processor, const vector<int>& args);
static void convcbit(Integer& dest, const Clear& source) { dest = source; }

View File

@@ -26,7 +26,7 @@ namespace GC
{
template<class U>
const int VectorSecret<U>::default_length;
const int ReplicatedSecret<U>::N_BITS;
template<class U>
const int ReplicatedSecret<U>::default_length;
@@ -92,6 +92,7 @@ void ShareSecret<U>::store_clear_in_dynamic(Memory<U>& mem,
template<class U>
void ShareSecret<U>::inputb(Processor<U>& processor,
ProcessorBase& input_processor,
const vector<int>& args)
{
auto& party = ShareThread<U>::s();
@@ -99,16 +100,22 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
input.reset_all(*party.P);
InputArgList a(args);
bool interactive = Thread<U>::s().n_interactive_inputs_from_me(a) > 0;
bool interactive = a.n_interactive_inputs_from_me(party.P->my_num()) > 0;
int dl = U::default_length;
for (auto x : a)
{
if (x.from == party.P->my_num())
{
input.add_mine(processor.get_input(x.params, interactive), x.n_bits);
bigint whole_input = processor.get_long_input(x.params,
input_processor, interactive);
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
input.add_mine(bigint(whole_input >> (i * dl)).get_si(),
min(dl, x.n_bits - i * dl));
}
else
input.add_other(x.from);
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
input.add_other(x.from);
}
if (interactive)
@@ -120,8 +127,12 @@ void ShareSecret<U>::inputb(Processor<U>& processor,
{
int from = x.from;
int n_bits = x.n_bits;
auto& res = processor.S[x.dest];
res = input.finalize(from, n_bits).mask(n_bits);
for (int i = 0; i < DIV_CEIL(x.n_bits, dl); i++)
{
auto& res = processor.S[x.dest + i];
int n = min(dl, n_bits - i * dl);
res = input.finalize(from, n).mask(n);
}
}
}
@@ -139,7 +150,11 @@ void ShareSecret<U>::reveal_inst(Processor<U>& processor,
if (n > max(U::default_length, Clear::N_BITS))
assert(U::default_length == Clear::N_BITS);
for (int j = 0; j < DIV_CEIL(n, U::default_length); j++)
shares.push_back(processor.S[r1 + j].mask(n));
{
shares.push_back(
processor.S[r1 + j].mask(
min(U::default_length, n - j * U::default_length)));
}
}
assert(party.MC);
PointerVector<typename U::open_type> opened;
@@ -149,7 +164,10 @@ void ShareSecret<U>::reveal_inst(Processor<U>& processor,
int n = args[i];
int r0 = args[i + 1];
for (int j = 0; j < DIV_CEIL(n, U::default_length); j++)
processor.C[r0 + j] = opened.next().mask(n);
{
processor.C[r0 + j] = opened.next().mask(
min(U::default_length, n - j * U::default_length));
}
}
}
@@ -180,12 +198,22 @@ void ReplicatedSecret<U>::trans(Processor<U>& processor,
assert(length == 2);
for (int k = 0; k < 2; k++)
{
square64 square;
for (size_t i = n_outputs; i < args.size(); i++)
square.rows[i - n_outputs] = processor.S[args[i]][k].get();
square.transpose(args.size() - n_outputs, n_outputs);
for (int i = 0; i < n_outputs; i++)
processor.S[args[i]][k] = square.rows[i];
for (int j = 0; j < DIV_CEIL(n_outputs, N_BITS); j++)
for (int l = 0; l < DIV_CEIL(args.size() - n_outputs, N_BITS); l++)
{
square64 square;
size_t input_base = n_outputs + l * N_BITS;
for (size_t i = input_base; i < min(input_base + N_BITS, args.size()); i++)
square.rows[i - input_base] = processor.S[args[i] + j][k].get();
square.transpose(
min(size_t(N_BITS), args.size() - n_outputs - l * N_BITS),
min(N_BITS, n_outputs - j * N_BITS));
int output_base = j * N_BITS;
for (int i = output_base; i < min(n_outputs, output_base + N_BITS); i++)
{
processor.S[args[i] + l][k] = square.rows[i - output_base];
}
}
}
}

View File

@@ -75,7 +75,9 @@ void ShareThread<T>::post_run()
{
MC->Check(*this->P);
#ifndef INSECURE
#ifdef VERBOSE
cerr << "Removing used pre-processed data" << endl;
#endif
DataF.prune();
#endif
}

View File

@@ -90,18 +90,23 @@ void Thread<T>::finish()
pthread_join(thread, 0);
}
template<class T>
int GC::Thread<T>::n_interactive_inputs_from_me(InputArgList& args)
int Thread<T>::n_interactive_inputs_from_me(InputArgList& args)
{
return args.n_interactive_inputs_from_me(P->my_num());
}
} /* namespace GC */
inline int InputArgList::n_interactive_inputs_from_me(int my_num)
{
int res = 0;
if (thread_num == 0 and master.opts.interactive)
res = args.n_inputs_from(P->my_num());
if (ArithmeticProcessor().use_stdin())
res = n_inputs_from(my_num);
if (res > 0)
cout << "Please enter " << res << " numbers:" << endl;
return res;
}
} /* namespace GC */
#endif

View File

@@ -218,6 +218,9 @@ public:
}
};
template<class U>
const int VectorSecret<U>::default_length;
template<class T>
inline VectorSecret<T> operator*(const BitVec& clear, const VectorSecret<T>& share)
{

View File

@@ -30,12 +30,11 @@
#define IMM instruction.get_n()
#define EXTRA instruction.get_start()
#define SIZE instruction.get_size()
#define MSD processor.memories.MS[IMM]
#define MMC processor.memories.MC[IMM]
#define MMS processor.memories.MS
#define MMC processor.memories.MC
#define MID MACH->MI[IMM]
#define MSI processor.memories.MS[PI1.get()]
#define MII MACH->MI[PI1.get()]
#define BIT_INSTRUCTIONS \
@@ -44,7 +43,6 @@
X(XORCBI, C0.xor_(PC1, IMM)) \
X(ANDRS, T::andrs(PROC, EXTRA)) \
X(ANDS, T::ands(PROC, EXTRA)) \
X(INPUTB, T::inputb(PROC, EXTRA)) \
X(ADDCB, C0 = PC1 + PC2) \
X(ADDCBI, C0 = PC1 + IMM) \
X(MULCBI, C0 = PC1 * IMM) \
@@ -54,24 +52,25 @@
X(SHRCBI, C0 = PC1 >> IMM) \
X(SHLCBI, C0 = PC1 << IMM) \
X(LDBITS, S0.load_clear(REG1, IMM)) \
X(LDMSB, S0 = MSD) \
X(STMSB, MSD = S0) \
X(LDMCB, C0 = MMC) \
X(STMCB, MMC = C0) \
X(LDMSB, PROC.mem_op(SIZE, PROC.S, MMS, R0, IMM)) \
X(STMSB, PROC.mem_op(SIZE, MMS, PROC.S, IMM, R0)) \
X(LDMCB, PROC.mem_op(SIZE, PROC.C, MMC, R0, IMM)) \
X(STMCB, PROC.mem_op(SIZE, MMC, PROC.C, IMM, R0)) \
X(LDMSBI, PROC.mem_op(SIZE, PROC.S, MMS, R0, Ci[REG1])) \
X(STMSBI, PROC.mem_op(SIZE, MMS, PROC.S, Ci[REG1], R0)) \
X(MOVSB, S0 = PS1) \
X(TRANS, T::trans(PROC, IMM, EXTRA)) \
X(BITB, PROC.random_bit(S0)) \
X(REVEAL, T::reveal_inst(PROC, EXTRA)) \
X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, C0)) \
X(PRINTREGB, PROC.print_reg(R0, IMM)) \
X(PRINTREGSIGNED, PROC.print_reg_signed(IMM, R0)) \
X(PRINTREGB, PROC.print_reg(R0, IMM, SIZE)) \
X(PRINTREGPLAINB, PROC.print_reg_plain(C0)) \
X(PRINTFLOATPLAINB, PROC.print_float(EXTRA)) \
X(CONDPRINTSTRB, if(C0.get()) PROC.print_str(IMM)) \
#define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \
X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \
X(ANDM, processor.andm(instruction)) \
X(LDMSBI, S0 = processor.memories.MS[Proc.read_Ci(REG1)]) \
X(STMSBI, processor.memories.MS[Proc.read_Ci(REG1)] = S0) \
X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \
X(CONVCINT, C0 = Proc.read_Ci(REG1)) \
X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \
@@ -84,8 +83,7 @@
X(SPLIT, Proc.split(INST)) \
#define GC_INSTRUCTIONS \
X(LDMSBI, S0 = MSI) \
X(STMSBI, MSI = S0) \
X(INPUTB, T::inputb(PROC, EXTRA)) \
X(LDMSD, PROC.load_dynamic_direct(EXTRA, MD)) \
X(STMSD, PROC.store_dynamic_direct(EXTRA, MD)) \
X(LDMSDI, PROC.load_dynamic_indirect(EXTRA, MD)) \

View File

@@ -7,6 +7,7 @@
#include "Tools/cpu_support.h"
#include <stdexcept>
#include <iostream>
#include <assert.h>
using namespace std;
union matrix32x8
@@ -98,6 +99,9 @@ void square64::transpose(int n_rows, int n_cols)
print();
#endif
assert(n_rows <= 64);
assert(n_cols <= 64);
square64 tmp = *this;
*this = {};

View File

@@ -3,9 +3,7 @@
#include "Math/Setup.h"
#include "Protocols/Share.h"
#include "Tools/ezOptionParser.h"
#include "Tools/Config.h"
#include "Networking/Server.h"
#include "GC/TinierSecret.h"
#include <iostream>
#include <map>

View File

@@ -59,7 +59,7 @@ offline: $(OT_EXE) Check-Offline.x
gen_input: gen_input_f2n.x gen_input_fp.x
externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x
externalIO: bankers-bonus-client.x
bmr: bmr-program-party.x bmr-program-tparty.x
@@ -134,9 +134,6 @@ bmr-clean:
bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
simple-offline.x: $(FHEOFFLINE)
pairwise-offline.x: $(FHEOFFLINE)
cnc-offline.x: $(FHEOFFLINE)
@@ -195,6 +192,9 @@ OT/BaseOT.o: SimpleOT/Makefile
SimpleOT/Makefile:
git submodule update --init SimpleOT
Programs/Circuits:
git submodule update --init Programs/Circuits
.PHONY: mpir-setup mpir-global mpir
mpir-setup:
git submodule update --init mpir

View File

@@ -73,6 +73,8 @@ public:
static void reqbl(int n);
static bool allows(Dtype dtype);
static void specification(octetStream& os);
typedef Z2 next;
typedef Z2 Scalar;
@@ -254,12 +256,18 @@ public:
}
template<int L>
SignedZ2<K + L> operator*(const Z2<L>& other) const
SignedZ2<K + L> operator*(const SignedZ2<L>& other) const
{
assert((K % 64 == 0) and (L % 64 == 0));
return Z2<K>::operator*(other);
}
template<int L>
Z2<K + L> operator*(const Z2<L>& other) const
{
return Z2<K>::operator*(other);
}
SignedZ2<K> operator*(int other) const
{
return operator*(SignedZ2<64>(other));

View File

@@ -36,6 +36,12 @@ bool Z2<K>::allows(Dtype dtype)
return Integer::allows(dtype);
}
template<int K>
void Z2<K>::specification(octetStream& os)
{
os.store(K);
}
template<int K>
Z2<K>::Z2(const bigint& x) : Z2()
{

View File

@@ -205,6 +205,12 @@ bool gfp_<X, L>::allows(Dtype type)
}
}
template<int X, int L>
void gfp_<X, L>::specification(octetStream& os)
{
os.store(pr());
}
void to_signed_bigint(bigint& ans, const gfp& x)
{
to_bigint(ans, x);

View File

@@ -92,6 +92,8 @@ class gfp_
static bool allows(Dtype type);
static void specification(octetStream& os);
static const bool invertible = true;
static gfp_ Mul(gfp_ a, gfp_ b) { return a * b; }

View File

@@ -61,6 +61,9 @@ class modp_
void pack(octetStream& o,const Zp_Data& ZpD) const;
void unpack(octetStream& o,const Zp_Data& ZpD);
void pack(octetStream& o) const;
void unpack(octetStream& o);
bool operator==(const modp_& other) const { return 0 == mpn_cmp(x, other.x, L); }
bool operator!=(const modp_& other) const { return not (*this == other); }

View File

@@ -20,6 +20,17 @@ void modp_<L>::unpack(octetStream& o,const Zp_Data& ZpD)
o.consume((octet*) x,ZpD.t*sizeof(mp_limb_t));
}
template<int L>
void modp_<L>::unpack(octetStream& o)
{
o.consume((octet*) x,L*sizeof(mp_limb_t));
}
template<int L>
void modp_<L>::pack(octetStream& o) const
{
o.append((octet*) x,L*sizeof(mp_limb_t));
}
template<int L>
void Negate(modp_<L>& ans,const modp_<L>& x,const Zp_Data& ZpD)

View File

@@ -14,29 +14,19 @@ void check_ssl_file(string filename)
"You can use `Scripts/setup-ssl.sh <nparties>`.");
}
void ssl_error(string side, string pronoun, int other, int server)
void ssl_error(string side, string pronoun, string other, string server)
{
cerr << side << "-side handshake with party " << other
cerr << side << "-side handshake with " << other
<< " failed. Make sure " << pronoun
<< " have the necessary certificate (" << PREP_DIR "P" << server
<< " have the necessary certificate (" << PREP_DIR << server
<< ".pem in the default configuration),"
<< " and run `c_rehash <directory>` on its location." << endl;
}
CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
MultiPlayer<ssl_socket*>(Nms, id_base), plaintext_player(Nms, id_base),
ctx(boost::asio::ssl::context::tlsv12)
ctx("P" + to_string(my_num()))
{
string prefix = PREP_DIR "P" + to_string(my_num());
string cert_file = prefix + ".pem";
string key_file = prefix + ".key";
check_ssl_file(cert_file);
check_ssl_file(key_file);
ctx.use_certificate_file(cert_file, ctx.pem);
ctx.use_private_key_file(key_file, ctx.pem);
ctx.add_verify_path("Player-Data");
sockets.resize(num_players());
senders.resize(num_players());
@@ -49,30 +39,8 @@ CryptoPlayer::CryptoPlayer(const Names& Nms, int id_base) :
continue;
}
sockets[i] = new ssl_socket(io_service, ctx);
sockets[i]->lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_player.socket(i));
sockets[i]->set_verify_mode(boost::asio::ssl::verify_peer);
sockets[i]->set_verify_callback(boost::asio::ssl::rfc2818_verification("P" + to_string(i)));
if (i < my_num())
try
{
sockets[i]->handshake(ssl_socket::client);
}
catch (...)
{
ssl_error("Client", "we", i, i);
throw;
}
if (i > my_num())
try
{
sockets[i]->handshake(ssl_socket::server);
}
catch (...)
{
ssl_error("Server", "they", i, my_num());
throw;
}
sockets[i] = new ssl_socket(io_service, ctx, plaintext_player.socket(i),
"P" + to_string(i), "P" + to_string(my_num()), i < my_num());
senders[i] = new Sender<ssl_socket*>(sockets[i]);
}

View File

@@ -15,7 +15,7 @@
class CryptoPlayer : public MultiPlayer<ssl_socket*>
{
PlainPlayer plaintext_player;
boost::asio::ssl::context ctx;
ssl_ctx ctx;
boost::asio::io_service io_service;
vector<Sender<ssl_socket*>*> senders;

View File

@@ -2,7 +2,6 @@
#include "Player.h"
#include "ssl_sockets.h"
#include "Exceptions/Exceptions.h"
#include "Networking/STS.h"
#include "Tools/int.h"
#include "Tools/NetworkOptions.h"
#include "Networking/Server.h"

View File

@@ -1,228 +0,0 @@
#include "Networking/STS.h"
#include <sodium.h>
#include <string>
#include <string.h>
#include <unistd.h>
#include <stdio.h>
#include <iomanip>
#include <fcntl.h>
void STS::kdf_block(unsigned char *block)
{
crypto_hash_sha512_state state;
crypto_hash_sha512_init(&state);
unsigned char ctrbytes[sizeof kdf_counter];
kdf_counter++;
// Little endian serialization
for(size_t i=0; i<sizeof(kdf_counter); i++) {
ctrbytes[i] = (unsigned char)((kdf_counter >> i*8) & 0xFF);
}
crypto_hash_sha512_update(&state,ctrbytes,sizeof ctrbytes);
crypto_hash_sha512_update(&state,raw_secret,crypto_hash_sha512_BYTES);
crypto_hash_sha512_final(&state, block);
}
vector<unsigned char> STS::unsafe_derive_secret(size_t sz)
{
// KDF ~ H(cnt || raw_secret)
vector<unsigned char> resultSecret(sz + crypto_hash_sha512_BYTES - (sz % crypto_hash_sha512_BYTES));
size_t total=0;
while(total < sz) {
unsigned char *block = &resultSecret[total];
kdf_block(block);
total += crypto_hash_sha512_BYTES;
}
return resultSecret;
}
STS::STS()
{
phase = UNDEFINED;
}
void STS::init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES])
{
phase = UNKNOWN;
memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES);
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
kdf_counter = 0;
}
STS::STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES])
{
phase = UNKNOWN;
memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES);
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
kdf_counter = 0;
}
STS::~STS()
{
memset(their_public_sign_key, 0, crypto_sign_PUBLICKEYBYTES);
memset(my_private_sign_key, 0, crypto_sign_SECRETKEYBYTES);
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(raw_secret, 0, crypto_hash_sha512_BYTES);
kdf_counter = 0;
phase = UNKNOWN;
}
sts_msg1_t STS::send_msg1()
{
sts_msg1_t m;
if(UNKNOWN != phase) {
throw "STS BAD PHASE";
}
crypto_box_keypair(ephemeral_public_key, ephemeral_private_key);
memcpy(m.bytes,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
phase = SENT1;
return m;
}
// If the incoming signature is valid, compute:
// shared secret = H(DH(pubB,privA) || pubA || pubB)
// msg = Sign_{privED-A} (pubA || pubB )
//
sts_msg3_t STS::recv_msg2(sts_msg2_t msg2)
{
unsigned char *theirPublicKey = msg2.pubkey;
unsigned char *theirSig = msg2.sig;
unsigned char theirSigDec[crypto_sign_BYTES];
unsigned char scalar_result[crypto_scalarmult_SCALARBYTES];
const unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
int ret;
crypto_hash_sha512_state state;
sts_msg3_t msg;
if(SENT1 != phase) {
throw "STS BAD PHASE";
}
ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey);
if(0 != ret) {
throw "crypto_scalarmult failed";
}
crypto_hash_sha512_init(&state);
crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES);
crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_final(&state,raw_secret);
vector<unsigned char> keKey = unsafe_derive_secret(crypto_stream_KEYBYTES);
vector<unsigned char> expectedMessage;
expectedMessage.insert(expectedMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey[0]);
int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key);
if(badSig) {
throw "Bad signature received in message 2.";
} else {
unsigned char *mySigEnc = msg.bytes;
unsigned char mySig[crypto_sign_BYTES];
vector<unsigned char> signMessage;
signMessage.insert(signMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
signMessage.insert(signMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
if(0 != crypto_sign_detached(mySig, NULL, &signMessage[0], signMessage.size(), my_private_sign_key)) {
throw "Signing failed.";
}
vector<unsigned char> keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES);
crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey2[0]);
phase = FINISHED;
return msg;
}
}
sts_msg2_t STS::recv_msg1(sts_msg1_t msg1)
{
unsigned char *theirPublicKey = msg1.bytes;
unsigned char scalar_result[crypto_scalarmult_SCALARBYTES];
crypto_hash_sha512_state state;
sts_msg2_t m;
int ret;
if(UNKNOWN != phase) {
throw "recv_msg1 called on non-unknown phase";
}
memcpy(their_ephemeral_public_key, theirPublicKey, crypto_box_PUBLICKEYBYTES);
crypto_box_keypair(ephemeral_public_key, ephemeral_private_key);
memcpy(m.pubkey,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey);
if(0 != ret) {
throw "crypto_scalarmult failed when processing message 1";
}
crypto_hash_sha512_init(&state);
crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES);
crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_final(&state,raw_secret);
vector<unsigned char> livenessProof;
livenessProof.insert(livenessProof.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
livenessProof.insert(livenessProof.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
unsigned char mySig[crypto_sign_BYTES];
unsigned char *mySigEnc = m.sig;
vector<unsigned char> keKey = unsafe_derive_secret(crypto_stream_KEYBYTES);
unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
if(0 != crypto_sign_detached(mySig, NULL, &livenessProof[0], livenessProof.size(), my_private_sign_key)) {
throw "Signing failed.";
}
crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey[0]);
phase = SENT2;
return m;
}
void STS::recv_msg3(sts_msg3_t msg3)
{
unsigned char *theirSig=msg3.bytes;
unsigned char theirSigDec[crypto_sign_BYTES];
vector<unsigned char> expectedMessage;
if(SENT2 != phase) {
throw "recv_msg3 called out of order";
}
expectedMessage.insert(expectedMessage.end(), their_ephemeral_public_key , their_ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
vector<unsigned char> keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES);
crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey2[0]);
int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key);
if(badSig) {
throw "Bad signature received in message 3.";
} else {
phase = FINISHED;
}
}
vector<unsigned char> STS::derive_secret(size_t sz)
{
if(phase != FINISHED) {
throw "Can not derive secrets till the key exchange has completed.";
}
return unsafe_derive_secret(sz);
}

View File

@@ -1,70 +0,0 @@
#ifndef _NETWORK_STS
#define _NETWORK_STS
/* The Station to Station protocol
*/
#include <iostream>
#include <fstream>
#include <vector>
#include <sodium.h>
using namespace std;
typedef enum
{ UNKNOWN // Have not started the interaction or have cleared the memory
, SENT1 // Sent initial message
, SENT2 // Received 1, sent 2
, FINISHED // Done (received msg 2 & sent 3 or received msg 3)
, UNDEFINED // For arrays/vectors/etc of STS classes that are initialized later.
} phase_t;
struct msg1_st {
unsigned char bytes[crypto_box_PUBLICKEYBYTES];
};
typedef struct msg1_st sts_msg1_t;
struct msg2_st {
unsigned char pubkey[crypto_box_PUBLICKEYBYTES];
unsigned char sig[crypto_sign_BYTES];
};
typedef struct msg2_st sts_msg2_t;
struct msg3_st {
unsigned char bytes[crypto_sign_BYTES];
};
typedef struct msg3_st sts_msg3_t;
class STS
{
phase_t phase;
unsigned char their_public_sign_key[crypto_sign_PUBLICKEYBYTES];
unsigned char my_public_sign_key[crypto_sign_PUBLICKEYBYTES];
unsigned char my_private_sign_key[crypto_sign_SECRETKEYBYTES];
unsigned char ephemeral_private_key[crypto_box_SECRETKEYBYTES];
unsigned char ephemeral_public_key[crypto_box_PUBLICKEYBYTES];
unsigned char their_ephemeral_public_key[crypto_box_PUBLICKEYBYTES];
unsigned char raw_secret[crypto_hash_sha512_BYTES];
uint64_t kdf_counter;
public:
STS();
STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]);
~STS();
void init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]);
sts_msg1_t send_msg1();
sts_msg3_t recv_msg2(sts_msg2_t msg2);
sts_msg2_t recv_msg1(sts_msg1_t msg1);
void recv_msg3(sts_msg3_t msg3);
vector<unsigned char> derive_secret(size_t);
private:
vector<unsigned char> unsafe_derive_secret(size_t);
void kdf_block(unsigned char *block);
};
#endif /* _NETWORK_STS */

View File

@@ -74,13 +74,61 @@ void ServerSocket::init()
pthread_create(&thread, 0, accept_thread, this);
}
class ServerJob
{
ServerSocket& server;
int socket;
sockaddr dest;
public:
pthread_t thread;
ServerJob(ServerSocket& server, int socket, sockaddr dest) :
server(server), socket(socket), dest(dest), thread(0)
{
}
static void* run(void* job)
{
auto& server_job = *(ServerJob*)(job);
server_job.server.wait_for_client_id(server_job.socket, server_job.dest);
return 0;
}
};
ServerSocket::~ServerSocket()
{
for (auto& job : jobs)
{
pthread_cancel(job->thread);
pthread_join(job->thread, 0);
delete job;
}
pthread_cancel(thread);
pthread_join(thread, 0);
if (close(main_socket)) { error("close(main_socket"); };
}
void ServerSocket::wait_for_client_id(int socket, sockaddr dest)
{
(void) dest;
int client_id;
try
{
receive(socket, (unsigned char*) &client_id, sizeof(client_id));
process_connection(socket, client_id);
}
catch (closed_connection&)
{
#ifdef DEBUG_NETWORKING
auto& conn = *(sockaddr_in*) &dest;
fprintf(stderr, "client on %s:%d left without identification\n",
inet_ntoa(conn.sin_addr), ntohs(conn.sin_port));
#endif
}
}
void ServerSocket::accept_clients()
{
while (true)
@@ -92,25 +140,19 @@ void ServerSocket::accept_clients()
if (consocket<0) { error("set_up_socket:accept"); }
int client_id;
try
{
receive(consocket, (unsigned char*)&client_id, sizeof(client_id));
}
catch (closed_connection&)
{
if (receive_all_or_nothing(consocket, (unsigned char*)&client_id, sizeof(client_id)))
process_connection(consocket, client_id);
else
{
#ifdef DEBUG_NETWORKING
auto& conn = *(sockaddr_in*)&dest;
cerr << "client on " << inet_ntoa(conn.sin_addr) << ":"
<< ntohs(conn.sin_port) << " left without identification"
<< endl;
auto& conn = *(sockaddr_in*) &dest;
fprintf(stderr, "deferring client on %s:%d to thread\n",
inet_ntoa(conn.sin_addr), ntohs(conn.sin_port));
#endif
}
data_signal.lock();
process_client(client_id);
clients[client_id] = consocket;
data_signal.broadcast();
data_signal.unlock();
// defer to thread
jobs.push_back(new ServerJob(*this, consocket, dest));
pthread_create(&jobs.back()->thread, 0, ServerJob::run, jobs.back());
}
#ifdef __APPLE__
int flags = fcntl(consocket, F_GETFL, 0);
@@ -121,15 +163,19 @@ void ServerSocket::accept_clients()
}
}
int ServerSocket::get_connection_count()
void ServerSocket::process_connection(int consocket, int client_id)
{
data_signal.lock();
int connection_count = clients.size();
#ifdef DEBUG_NETWORKING
cerr << "client " << hex << client_id << " is on socket " << dec << consocket
<< endl;
#endif
process_client(client_id);
clients[client_id] = consocket;
data_signal.broadcast();
data_signal.unlock();
return connection_count;
}
int ServerSocket::get_connection_socket(int id)
{
data_signal.lock();
@@ -163,16 +209,10 @@ void AnonymousServerSocket::init()
pthread_create(&thread, 0, anonymous_accept_thread, this);
}
int AnonymousServerSocket::get_connection_count()
{
return num_accepted_clients;
}
void AnonymousServerSocket::process_client(int client_id)
{
if (clients.find(client_id) != clients.end())
close_client_socket(clients[client_id]);
num_accepted_clients++;
client_connection_queue.push(client_id);
}

View File

@@ -12,10 +12,13 @@
using namespace std;
#include <pthread.h>
#include <netinet/tcp.h>
#include "Tools/WaitQueue.h"
#include "Tools/Signal.h"
class ServerJob;
class ServerSocket
{
protected:
@@ -25,9 +28,13 @@ protected:
Signal data_signal;
pthread_t thread;
vector<ServerJob*> jobs;
// disable copying
ServerSocket(const ServerSocket& other);
void process_connection(int socket, int client_id);
virtual void process_client(int) {}
public:
@@ -38,14 +45,11 @@ public:
virtual void accept_clients();
void wait_for_client_id(int socket, sockaddr dest);
// This depends on clients sending their id as int.
// Has to be thread-safe.
int get_connection_socket(int number);
// How many client connections have been made.
virtual int get_connection_count();
void close_socket();
};
/*
@@ -55,18 +59,15 @@ class AnonymousServerSocket : public ServerSocket
{
private:
// No. of accepted connections in this instance
int num_accepted_clients;
queue<int> client_connection_queue;
void process_client(int client_id);
public:
AnonymousServerSocket(int Portnum) :
ServerSocket(Portnum), num_accepted_clients(0) { };
ServerSocket(Portnum) { };
void init();
virtual int get_connection_count();
// Get socket and id for the last client who connected
int get_connection_socket(int& client_id);
};

View File

@@ -12,7 +12,65 @@
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket;
typedef boost::asio::io_service ssl_service;
void check_ssl_file(string filename);
void ssl_error(string side, string pronoun, string other, string server);
class ssl_ctx : public boost::asio::ssl::context
{
public:
ssl_ctx(string me) :
boost::asio::ssl::context(boost::asio::ssl::context::tlsv12)
{
string prefix = PREP_DIR + me;
string cert_file = prefix + ".pem";
string key_file = prefix + ".key";
check_ssl_file(cert_file);
check_ssl_file(key_file);
use_certificate_file(cert_file, pem);
use_private_key_file(key_file, pem);
add_verify_path(PREP_DIR);
}
};
class ssl_socket : public boost::asio::ssl::stream<boost::asio::ip::tcp::socket>
{
typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket> parent;
public:
ssl_socket(boost::asio::io_service& io_service,
boost::asio::ssl::context& ctx, int plaintext_socket, string other,
string me, bool client) :
parent(io_service, ctx)
{
lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket);
set_verify_mode(boost::asio::ssl::verify_peer);
set_verify_callback(boost::asio::ssl::rfc2818_verification(other));
if (client)
try
{
handshake(ssl_socket::client);
} catch (...)
{
ssl_error("Client", "we", other, other);
throw;
}
else
{
try
{
handshake(ssl_socket::server);
} catch (...)
{
ssl_error("Server", "they", other, me);
throw;
}
}
}
};
inline size_t send_non_blocking(ssl_socket* socket, octet* data, size_t length)
{

View File

@@ -5,41 +5,24 @@
#include <thread>
ExternalClients::ExternalClients(int party_num, const string& prep_data_dir):
party_num(party_num), prep_data_dir(prep_data_dir), server_connection_count(-1)
party_num(party_num), prep_data_dir(prep_data_dir),
ctx(0)
{
}
ExternalClients::~ExternalClients()
{
// close client sockets
for (map<int,int>::iterator it = external_client_sockets.begin();
for (auto it = external_client_sockets.begin();
it != external_client_sockets.end(); it++)
{
if (close(it->second))
{
error("failed to close external client connection socket)");
}
delete it->second;
}
for (map<int,AnonymousServerSocket*>::iterator it = client_connection_servers.begin();
it != client_connection_servers.end(); it++)
{
delete it->second;
}
for (map<int,octet*>::iterator it = symmetric_client_keys.begin();
it != symmetric_client_keys.end(); it++)
{
delete[] it->second;
}
for (map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = symmetric_client_commsec_send_keys.begin();
it_cs != symmetric_client_commsec_send_keys.end(); it_cs++)
{
memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size());
}
for (map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = symmetric_client_commsec_recv_keys.begin();
it_cs != symmetric_client_commsec_recv_keys.end(); it_cs++)
{
memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size());
}
}
void ExternalClients::start_listening(int portnum_base)
@@ -62,125 +45,21 @@ int ExternalClients::get_client_connection(int portnum_base)
cerr << "Thread " << this_thread::get_id() << " found server." << endl;
int client_id, socket;
socket = client_connection_servers[portnum_base]->get_connection_socket(client_id);
external_client_sockets[client_id] = socket;
if (symmetric_client_keys.find(client_id) != symmetric_client_keys.end())
delete symmetric_client_keys[client_id];
symmetric_client_commsec_send_keys.erase(client_id);
symmetric_client_commsec_recv_keys.erase(client_id);
if (ctx == 0)
ctx = new ssl_ctx("P" + to_string(get_party_num()));
external_client_sockets[client_id] = new ssl_socket(io_service, *ctx, socket,
"C" + to_string(client_id), "P" + to_string(get_party_num()), false);
cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl;
return client_id;
}
int ExternalClients::connect_to_server(int portnum_base, int ipv4_address)
{
struct in_addr addr = { (unsigned int)ipv4_address };
int csocket;
const char* address_str = inet_ntoa(addr);
cerr << "Party " << get_party_num() << " connecting to server at " << address_str << " on port " << portnum_base + get_party_num() << endl;
set_up_client_socket(csocket, address_str, portnum_base + get_party_num());
cerr << "Party " << get_party_num() << " connected to server at " << address_str << " on port " << portnum_base + get_party_num() << endl;
int server_id = server_connection_count;
// server identifiers are -1, -2, ... to avoid conflict with client identifiers
server_connection_count--;
external_client_sockets[server_id] = csocket;
return server_id;
}
void ExternalClients::curve25519_ints_to_bytes(unsigned char *bytes, const vector<int>& key_ints)
{
for(unsigned int j = 0; j < key_ints.size(); j++) {
for(unsigned int k = 0; k < 4; k++) {
bytes[j*sizeof(int) + k] = (key_ints[j] >> ((3-k)*8)) & 0xFF;
}
}
}
// Generate sesssion key for a newly connected client, store in symmetric_client_keys
// public_key is expected to be size 8 and contain integer values of public key bytes.
// Assumes load_server_keys has been run.
void ExternalClients::generate_session_key_for_client(int client_id, const vector<int>& public_key)
{
assert(public_key.size() * sizeof(int) == crypto_box_PUBLICKEYBYTES);
load_server_keys_once();
unsigned char client_publickey[crypto_box_PUBLICKEYBYTES];
curve25519_ints_to_bytes(client_publickey, public_key);
cerr << "Recevied client public key for client " << dec << client_id << " :";
for (unsigned int j = 0; j < crypto_box_PUBLICKEYBYTES; j++)
cerr << hex << (int) client_publickey[j] << " ";
cerr << dec << endl;
unsigned char scalarmult_q_by_server[crypto_scalarmult_BYTES];
crypto_generichash_state h;
symmetric_client_keys[client_id] = new octet[crypto_generichash_BYTES];
// Derive a shared key from this server's secret key and the client's public key
// shared key = h(q || server_secretkey || client_publickey)
if (crypto_scalarmult(scalarmult_q_by_server, server_secretkey, client_publickey) != 0) {
cerr << "Scalar mult failed\n";
exit(1);
}
crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES);
crypto_generichash_update(&h, scalarmult_q_by_server, sizeof scalarmult_q_by_server);
crypto_generichash_update(&h, client_publickey, sizeof client_publickey);
crypto_generichash_update(&h, server_publickey, sizeof server_publickey);
crypto_generichash_final(&h, symmetric_client_keys[client_id], crypto_generichash_BYTES);
}
// Read pre-computed server keys from client-setup for this SPDZ engine.
// Only needs to be done once per run, but is only necessary if an external connection
// is being requested.
void ExternalClients::load_server_keys_once()
{
if (server_keys_loaded) {
return;
}
ifstream keyfile;
stringstream filename;
filename << prep_data_dir << "Player-SPDZ-Keys-P" << get_party_num();
keyfile.open(filename.str().c_str());
if (keyfile.fail())
throw file_error(filename.str().c_str());
keyfile.read((char*)server_publickey, sizeof server_publickey);
if (keyfile.eof())
throw end_of_file(filename.str(), "server public key" );
keyfile.read((char*)server_secretkey, sizeof server_secretkey);
if (keyfile.eof())
throw end_of_file(filename.str(), "server private key" );
bool loaded_ed25519 = true;
keyfile.read((char*)server_publickey_ed25519, sizeof server_publickey_ed25519);
if (keyfile.eof() || keyfile.bad())
loaded_ed25519 = false;
keyfile.read((char*)server_secretkey_ed25519, sizeof server_secretkey_ed25519);
if (keyfile.eof() || keyfile.bad())
loaded_ed25519 = false;
keyfile.close();
ed25519_keys_loaded = loaded_ed25519;
server_keys_loaded = true;
}
void ExternalClients::require_ed25519_keys()
{
if (!ed25519_keys_loaded)
throw "Ed25519 keys required but not found in player key files";
}
int ExternalClients::get_party_num()
{
return party_num;
}
int ExternalClients::get_socket(int id)
ssl_socket* ExternalClients::get_socket(int id)
{
if (external_client_sockets.find(id) == external_client_sockets.end())
throw runtime_error("external connection not found for id " + to_string(id));

View File

@@ -2,6 +2,7 @@
#define _ExternalClients
#include "Networking/sockets.h"
#include "Networking/ssl_sockets.h"
#include "Exceptions/Exceptions.h"
#include <vector>
#include <map>
@@ -23,24 +24,15 @@ class ExternalClients
int party_num;
const string prep_data_dir;
int server_connection_count;
unsigned char server_publickey[crypto_box_PUBLICKEYBYTES];
unsigned char server_secretkey[crypto_box_SECRETKEYBYTES];
bool server_keys_loaded = false;
bool ed25519_keys_loaded = false;
// Maps holding per client values (indexed by unique 32-bit id)
std::map<int,int> external_client_sockets;
std::map<int,ssl_socket*> external_client_sockets;
ssl_service io_service;
ssl_ctx* ctx;
public:
unsigned char server_publickey_ed25519[crypto_sign_ed25519_PUBLICKEYBYTES];
unsigned char server_secretkey_ed25519[crypto_sign_ed25519_SECRETKEYBYTES];
std::map<int,octet*> symmetric_client_keys;
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_send_keys;
std::map<int,pair<vector<octet>,uint64_t>> symmetric_client_commsec_recv_keys;
ExternalClients(int party_num, const string& prep_data_dir);
~ExternalClients();
@@ -48,18 +40,10 @@ class ExternalClients
int get_client_connection(int portnum_base);
int connect_to_server(int portnum_base, int ipv4_address);
// return the socket for a given client or server identifier
int get_socket(int socket_id);
void curve25519_ints_to_bytes(unsigned char bytes[crypto_box_PUBLICKEYBYTES], const vector<int>& key_ints);
void generate_session_key_for_client(int client_id, const vector<int>& public_key);
void load_server_keys_once();
ssl_socket* get_socket(int socket_id);
int get_party_num();
void require_ed25519_keys();
};
#endif

View File

@@ -5,17 +5,18 @@
#include "FixInput.h"
const char* FixInput::NAME = "real number";
void FixInput::read(std::istream& in, const int* params)
template<>
void FixInput_<Integer>::read(std::istream& in, const int* params)
{
#ifdef LOW_PREC_INPUT
double x;
in >> x;
items[0] = x * (1 << *params);
#else
}
template<>
void FixInput_<bigint>::read(std::istream& in, const int* params)
{
mpf_class x;
in >> x;
items[0] = x << *params;
#endif
}

View File

@@ -11,7 +11,8 @@
#include "Math/bigint.h"
#include "Math/Integer.h"
class FixInput
template<class T>
class FixInput_
{
public:
const static int N_DEST = 1;
@@ -20,13 +21,18 @@ public:
const static int TYPE = 1;
#ifdef LOW_PREC_INPUT
Integer items[N_DEST];
#else
bigint items[N_DEST];
#endif
T items[N_DEST];
void read(std::istream& in, const int* params);
};
template<class T>
const char* FixInput_<T>::NAME = "real number";
#ifdef LOW_PREC_INPUT
typedef FixInput_<Integer> FixInput;
#else
typedef FixInput_<bigint> FixInput;
#endif
#endif /* PROCESSOR_FIXINPUT_H_ */

View File

@@ -57,6 +57,8 @@ public:
virtual T finalize_mine() = 0;
virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0;
T finalize(int player, int n_bits = -1);
void raw_input(SubProcessor<T>& proc, const vector<int>& args);
};
template<class T>
@@ -88,9 +90,6 @@ public:
T finalize_mine();
void finalize_other(int player, T& target, octetStream& o, int n_bits = -1);
void start(int player, int n_inputs);
void stop(int player, const vector<int>& targets);
};
#endif /* PROCESSOR_INPUT_H_ */

View File

@@ -100,7 +100,7 @@ template<class T>
void Input<T>::add_other(int player)
{
open_type t;
shares[player].push_back({});
shares.at(player).push_back({});
prep.get_input(shares[player].back(), t, player);
}
@@ -131,12 +131,16 @@ void InputBase<T>::exchange()
}
template<class T>
void Input<T>::start(int player, int n_inputs)
void InputBase<T>::raw_input(SubProcessor<T>& proc, const vector<int>& args)
{
reset(player);
if (player == P.my_num())
auto& P = proc.P;
reset_all(P);
for (auto it = args.begin(); it != args.end();)
{
for (int i = 0; i < n_inputs; i++)
int player = *it++;
it++;
if (player == P.my_num())
{
clear t;
try
@@ -149,33 +153,21 @@ void Input<T>::start(int player, int n_inputs)
}
add_mine(t);
}
send_mine();
}
else
{
for (int i = 0; i < n_inputs; i++)
add_other(player);
}
}
template<class T>
void Input<T>::stop(int player, const vector<int>& targets)
{
assert(proc != 0);
if (P.my_num() == player)
for (unsigned int i = 0; i < targets.size(); i++)
proc->get_S_ref(targets[i]) = finalize_mine();
else
{
octetStream o;
this->timer.start();
P.receive_player(player, o, true);
this->timer.stop();
for (unsigned int i = 0; i < targets.size(); i++)
else
{
finalize_other(player, proc->get_S_ref(targets[i]), o);
add_other(player);
}
}
timer.start();
exchange();
timer.stop();
for (auto it = args.begin(); it != args.end();)
{
int player = *it++;
proc.get_S_ref(*it++) = finalize(player);
}
}
template<class T>

View File

@@ -115,6 +115,7 @@ enum
INPUTFLOAT = 0xF1,
INPUTMIXED = 0xF2,
INPUTMIXEDREG = 0xF3,
RAWINPUT = 0xF4,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
@@ -248,6 +249,7 @@ enum
GSTOPINPUT = 0x162,
GREADSOCKETS = 0x164,
GWRITESOCKETS = 0x166,
GRAWINPUT = 0x1F4,
// Bitwise logic
GANDC = 0x170,
GXORC = 0x171,
@@ -328,6 +330,8 @@ public:
int get_opcode() const { return opcode; }
int get_size() const { return size; }
// Reads a single instruction from the istream
void parse(istream& s, int inst_pos);
void parse_operands(istream& s, int pos, int file_pos);
bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); }
@@ -347,9 +351,6 @@ class DataPositions;
class Instruction : public BaseInstruction
{
public:
// Reads a single instruction from the istream
void parse(istream& s, int inst_pos);
// Return whether usage is known
bool get_offline_data_usage(DataPositions& usage);
@@ -361,6 +362,5 @@ public:
void execute(Processor<sint, sgf2n>& Proc) const;
};
#endif

View File

@@ -34,7 +34,7 @@
#include "Tools/callgrind.h"
inline
void Instruction::parse(istream& s, int inst_pos)
void BaseInstruction::parse(istream& s, int inst_pos)
{
n=0; start.resize(0);
r[0]=0; r[1]=0; r[2]=0; r[3]=0;
@@ -224,7 +224,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case STARTPRIVATEOUTPUT:
case GSTARTPRIVATEOUTPUT:
case DIGESTC:
case CONNECTIPV4: // write socket handle, read IPv4 address, portnum
r[0]=get_int(s);
r[1]=get_int(s);
n = get_int(s);
@@ -254,8 +253,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case PRINTREGB:
case GPRINTREG:
case LDINT:
case STARTINPUT:
case GSTARTINPUT:
case STOPPRIVATEOUTPUT:
case GSTOPPRIVATEOUTPUT:
case INPUTMASK:
@@ -310,6 +307,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case INPUTFLOAT:
case INPUTMIXED:
case INPUTMIXEDREG:
case RAWINPUT:
case GRAWINPUT:
case TRUNC_PR:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
@@ -336,7 +335,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case READSOCKETC:
case READSOCKETS:
case READSOCKETINT:
case READCLIENTPUBLICKEY:
num_var_args = get_int(s) - 1;
r[0] = get_int(s);
get_vector(num_var_args, start, s);
@@ -352,20 +350,18 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
r[1] = get_int(s);
get_vector(num_var_args, start, s);
break;
case CONNECTIPV4:
throw runtime_error("parties as clients not supported any more");
case READCLIENTPUBLICKEY:
case INITSECURESOCKET:
case RESPSECURESOCKET:
num_var_args = get_int(s) - 1;
r[0] = get_int(s);
get_vector(num_var_args, start, s);
break;
throw runtime_error("VM-controlled encryption not supported any more");
// raw input
case STARTINPUT:
case GSTARTINPUT:
case STOPINPUT:
case GSTOPINPUT:
// subtract player number argument
num_var_args = get_int(s) - 1;
n = get_int(s);
get_vector(num_var_args, start, s);
break;
throw runtime_error("two-stage input not supported any more");
case GBITDEC:
case GBITCOM:
num_var_args = get_int(s) - 2;
@@ -621,6 +617,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
case INPUTB:
skip = 4;
offset = 3;
size_offset = -2;
break;
case ANDM:
size = DIV_CEIL(n, 64);
@@ -733,6 +730,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
typedef typename sint::bit_type T;
auto& processor = Proc.Procb;
auto& instruction = *this;
auto& Ci = Proc.get_Ci();
// optimize some instructions
switch (opcode)
@@ -1292,17 +1290,11 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case INPUTMIXEDREG:
sint::Input::input_mixed(Proc.Procp, start, size, true);
return;
case STARTINPUT:
Proc.Procp.input.start(r[0],n);
case RAWINPUT:
Proc.Procp.input.raw_input(Proc.Procp, start);
break;
case GSTARTINPUT:
Proc.Proc2.input.start(r[0],n);
break;
case STOPINPUT:
Proc.Procp.input.stop(n,start);
break;
case GSTOPINPUT:
Proc.Proc2.input.stop(n,start);
case GRAWINPUT:
Proc.Proc2.input.raw_input(Proc.Proc2, start);
break;
case ANDC:
Proc.get_Cp_ref(r[0]).AND(Proc.read_Cp(r[1]),Proc.read_Cp(r[2]));
@@ -1670,26 +1662,16 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
ss << "No connection on port " << r[0] << endl;
throw Processor_Error(ss.str());
}
if (Proc.P.my_num() == 0)
{
octetStream os;
os.store(int(sint::open_type::type_char()));
sint::open_type::specification(os);
os.Send(Proc.external_clients.get_socket(client_handle));
}
Proc.write_Ci(r[0], client_handle);
break;
}
case CONNECTIPV4:
{
// connect to server at port n + my_num()
int ipv4 = Proc.read_Ci(r[1]);
int server_handle = Proc.external_clients.connect_to_server(n, ipv4);
Proc.write_Ci(r[0], server_handle);
break;
}
case READCLIENTPUBLICKEY:
Proc.read_client_public_key(Proc.read_Ci(r[0]), start);
break;
case INITSECURESOCKET:
Proc.init_secure_socket(Proc.read_Ci(r[i]), start);
break;
case RESPSECURESOCKET:
Proc.resp_secure_socket(Proc.read_Ci(r[i]), start);
break;
case READSOCKETINT:
Proc.read_socket_ints(Proc.read_Ci(r[0]), start);
break;

View File

@@ -35,7 +35,6 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
// Set up the fields
prep_dir_prefix = get_prep_dir(N.num_players(), opts.lgp, lg2);
char filename[2048];
bool read_mac_keys = false;
sgf2n::clear::init_field(lg2);
@@ -96,16 +95,7 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
sint::clear::next::template init<typename sint::clear>(false);
// Initialize the global memory
if (memtype.compare("new")==0)
{sprintf(filename, PREP_DIR "Player-Memory-P%d", my_number);
ifstream memfile(filename);
if (memfile.fail()) { throw file_error(filename); }
M2.Load_Memory(memfile);
Mp.Load_Memory(memfile);
Mi.Load_Memory(memfile);
memfile.close();
}
else if (memtype.compare("old")==0)
if (memtype.compare("old")==0)
{
inpf.open(memory_filename(), ios::in | ios::binary);
if (inpf.fail()) { throw file_error(memory_filename()); }

View File

@@ -77,21 +77,6 @@ class Memory
friend ostream& operator<< <>(ostream& s,const Memory<T>& M);
friend istream& operator>> <>(istream& s,Memory<T>& M);
/* This function loads a un-shared global memory from disk and
* produces the memory
*
* The global unshared memory is of the form
* sz <- Size
* n val <- Clear values
* n val <- Clear values
* -1 -1 <- End of clear values
* n val <- Shared values
* n val <- Shared values
* -1 -1
*/
void Load_Memory(ifstream& inpf);
};
#endif

View File

@@ -106,40 +106,3 @@ istream& operator>>(istream& s,Memory<T>& M)
return s;
}
template<class T>
void Memory<T>::Load_Memory(ifstream& inpf)
{
Memory<T>& M = *this;
int a;
typename T::clear val;
T S;
inpf >> a;
M.resize_s(a);
inpf >> a;
M.resize_c(a);
cerr << "Reading Clear Memory" << endl;
// Read clear memory
inpf >> a;
val.input(inpf,true);
while (a!=-1)
{ M.write_C(a,val);
inpf >> a;
val.input(inpf,true);
}
cerr << "Reading Shared Memory" << endl;
// Read shared memory
inpf >> a;
S.input(inpf,true);
while (a!=-1)
{ M.write_S(a,S);
inpf >> a;
S.input(inpf,true);
}
}

View File

@@ -112,6 +112,10 @@ public:
OnlineOptions opts;
ArithmeticProcessor() :
ArithmeticProcessor(OnlineOptions::singleton, BaseMachine::thread_num)
{
}
ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num),
sent(0), rounds(0), opts(opts) {}
@@ -217,12 +221,6 @@ class Processor : public ArithmeticProcessor
// Access to external client sockets for reading clear/shared data
void read_socket_ints(int client_id, const vector<int>& registers);
// Setup client public key
void read_client_public_key(int client_id, const vector<int>& registers);
void init_secure_socket(int client_id, const vector<int>& registers);
void init_secure_socket_internal(int client_id, const vector<int>& registers);
void resp_secure_socket(int client_id, const vector<int>& registers);
void resp_secure_socket_internal(int client_id, const vector<int>& registers);
void write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs,
int socket_id, int message_type, const vector<int>& registers);
@@ -239,8 +237,6 @@ class Processor : public ArithmeticProcessor
friend ostream& operator<<(ostream& s,const Processor<T, U>& P);
private:
void maybe_decrypt_sequence(int client_id);
void maybe_encrypt_sequence(int client_id);
template<class T> friend class SPDZ;
template<class T> friend class SubProcessor;

View File

@@ -3,7 +3,6 @@
#include "Processor/Processor.h"
#include "Processor/Program.h"
#include "Networking/STS.h"
#include "Protocols/fake-stuff.h"
#include "GC/square64.h"
@@ -68,7 +67,7 @@ Processor<sint, sgf2n>::Processor(int thread_num,Player& P,
Procb(machine.bit_memories),
Proc2(*this,MC2,DataF.DataF2,P),Procp(*this,MCp,DataF.DataFp,P),
privateOutput2(Proc2),privateOutputp(Procp),
external_clients(ExternalClients(P.my_num(), machine.prep_dir_prefix)),
external_clients(P.my_num(), machine.prep_dir_prefix),
binary_file_io(Binary_File_IO())
{
reset(program,0);
@@ -222,7 +221,6 @@ void Processor<sint, sgf2n>::split(const Instruction& instruction)
// RegType and SecrecyType determines how registers are read and the socket stream is packed.
// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to
// determine the data structure being sent in a message.
// Encryption is enabled if key material (for DH Auth Encryption and/or STS protocol) has been already setup.
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs,
int socket_id, int message_type, const vector<int>& registers)
@@ -239,7 +237,11 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
{
if (reg_type == MODP && secrecy_type == SECRET) {
// Send vector of secret shares and optionally macs
get_Sp_ref(registers[i]).pack(socket_stream, send_macs);
if (send_macs)
get_Sp_ref(registers[i]).pack(socket_stream);
else
get_Sp_ref(registers[i]).pack(socket_stream,
sint::get_rec_factor(P.my_num(), P.num_players()));
}
else if (reg_type == MODP && secrecy_type == CLEAR) {
// Send vector of clear public field elements
@@ -257,15 +259,7 @@ void Processor<sint, sgf2n>::write_socket(const RegType reg_type, const SecrecyT
}
}
// Apply DH Auth encryption if session keys have been created.
map<int,octet*>::iterator it = external_clients.symmetric_client_keys.find(socket_id);
if (it != external_clients.symmetric_client_keys.end()) {
socket_stream.encrypt(it->second);
}
// Apply STS commsec encryption if session keys have been created.
try {
maybe_encrypt_sequence(socket_id);
socket_stream.Send(external_clients.get_socket(socket_id));
}
catch (bad_value& e) {
@@ -282,7 +276,6 @@ void Processor<sint, sgf2n>::read_socket_ints(int client_id, const vector<int>&
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
for (int i = 0; i < m; i++)
{
int val;
@@ -298,7 +291,6 @@ void Processor<sint, sgf2n>::read_socket_vector(int client_id, const vector<int>
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
for (int i = 0; i < m; i++)
{
get_Cp_ref(registers[i]).unpack(socket_stream);
@@ -312,146 +304,13 @@ void Processor<sint, sgf2n>::read_socket_private(int client_id, const vector<int
int m = registers.size();
socket_stream.reset_write_head();
socket_stream.Receive(external_clients.get_socket(client_id));
maybe_decrypt_sequence(client_id);
map<int,octet*>::iterator it = external_clients.symmetric_client_keys.find(client_id);
if (it != external_clients.symmetric_client_keys.end())
{
socket_stream.decrypt(it->second);
}
for (int i = 0; i < m; i++)
{
get_Sp_ref(registers[i]).unpack(socket_stream, read_macs);
}
}
// Read socket for client public key as 8 ints, calculate session key for client.
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::read_client_public_key(int client_id, const vector<int>& registers) {
read_socket_ints(client_id, registers);
// After read into registers, need to extract values
vector<int> client_public_key (registers.size(), 0);
for(unsigned int i = 0; i < registers.size(); i++) {
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
}
external_clients.generate_session_key_for_client(client_id, client_public_key);
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::init_secure_socket_internal(int client_id, const vector<int>& registers) {
external_clients.symmetric_client_commsec_send_keys.erase(client_id);
external_clients.symmetric_client_commsec_recv_keys.erase(client_id);
unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES];
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
external_clients.load_server_keys_once();
external_clients.require_ed25519_keys();
// Validate inputs and state
if(registers.size() != 8) {
throw "Invalid call to init_secure_socket.";
}
// Extract client long term public key into bytes
vector<int> client_public_key (registers.size(), 0);
for(unsigned int i = 0; i < registers.size(); i++) {
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
}
external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key);
// Start Station to Station Protocol
STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519);
m1 = ke.send_msg1();
socket_stream.reset_write_head();
socket_stream.append(m1.bytes, sizeof m1.bytes);
socket_stream.Send(external_clients.get_socket(client_id));
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
96);
socket_stream.consume(m2.pubkey, sizeof m2.pubkey);
socket_stream.consume(m2.sig, sizeof m2.sig);
m3 = ke.recv_msg2(m2);
socket_stream.reset_write_head();
socket_stream.append(m3.bytes, sizeof m3.bytes);
socket_stream.Send(external_clients.get_socket(client_id));
// Use results of STS to generate send and receive keys.
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0);
external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0);
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::init_secure_socket(int client_id, const vector<int>& registers) {
try {
init_secure_socket_internal(client_id, registers);
} catch (char const *e) {
cerr << "STS initiator role failed with: " << e << endl;
throw Processor_Error("STS initiator failed");
}
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::resp_secure_socket(int client_id, const vector<int>& registers) {
try {
resp_secure_socket_internal(client_id, registers);
} catch (char const *e) {
cerr << "STS responder role failed with: " << e << endl;
throw Processor_Error("STS responder failed");
}
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::resp_secure_socket_internal(int client_id, const vector<int>& registers) {
external_clients.symmetric_client_commsec_send_keys.erase(client_id);
external_clients.symmetric_client_commsec_recv_keys.erase(client_id);
unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES];
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
external_clients.load_server_keys_once();
external_clients.require_ed25519_keys();
// Validate inputs and state
if(registers.size() != 8) {
throw "Invalid call to init_secure_socket.";
}
vector<int> client_public_key (registers.size(), 0);
for(unsigned int i = 0; i < registers.size(); i++) {
client_public_key[i] = (int&)get_Ci_ref(registers[i]);
}
external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key);
// Start Station to Station Protocol for the responder
STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519);
socket_stream.reset_read_head();
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
32);
socket_stream.consume(m1.bytes, sizeof m1.bytes);
m2 = ke.recv_msg1(m1);
socket_stream.reset_write_head();
socket_stream.append(m2.pubkey, sizeof m2.pubkey);
socket_stream.append(m2.sig, sizeof m2.sig);
socket_stream.Send(external_clients.get_socket(client_id));
socket_stream.ReceiveExpected(external_clients.get_socket(client_id),
64);
socket_stream.consume(m3.bytes, sizeof m3.bytes);
ke.recv_msg3(m3);
// Use results of STS to generate send and receive keys.
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0);
external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0);
}
// Read share data from a file starting at file_pos until registers filled.
// file_pos_register is written with new file position (-1 is eof).
@@ -722,26 +581,4 @@ ostream& operator<<(ostream& s,const Processor<sint, sgf2n>& P)
return s;
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::maybe_decrypt_sequence(int client_id)
{
map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_recv_keys.find(client_id);
if (it_cs != external_clients.symmetric_client_commsec_recv_keys.end())
{
socket_stream.decrypt_sequence(&it_cs->second.first[0], it_cs->second.second);
it_cs->second.second++;
}
}
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::maybe_encrypt_sequence(int client_id)
{
map<int, pair<vector<octet>,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_send_keys.find(client_id);
if (it_cs != external_clients.symmetric_client_commsec_send_keys.end())
{
socket_stream.encrypt_sequence(&it_cs->second.first[0], it_cs->second.second);
it_cs->second.second++;
}
}
#endif

1
Programs/Circuits Submodule

Submodule Programs/Circuits added at 82dfda9d12

View File

@@ -0,0 +1,8 @@
from circuit import Circuit
sb128 = sbits.get_type(128)
key = sb128(0x2b7e151628aed2a6abf7158809cf4f3c)
plaintext = sb128(0x6bc1bee22e409f96e93d7e117393172a)
n = 1000
aes128 = Circuit('aes_128')
ciphertexts = aes128(sbitvec([key] * n), sbitvec([plaintext] * n))
ciphertexts.elements()[n - 1].reveal().print_reg()

View File

@@ -4,8 +4,6 @@
to deduce the maximum value from a range of integer input.
Demonstrate clients external to computing parties supplying input and receiving an authenticated result. See bankers-bonus-client.cpp for client (and setup instructions).
For an implementation with communications security see bankers_bonus_commsec.mpc.
Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent
before calculating the maximum.

View File

@@ -1,144 +0,0 @@
# coding=latin1
"""
Solve Bankers bonus, aka Millionaires problem.
to deduce the maximum value from a range of integer input.
Demonstrate clients external to computing parties supplying input and receiving
an authenticated result. See bankers-bonus-commsec-client.cpp for client (and setup instructions).
For an implementation without communications security see bankers_bonus.mpc.
Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent
before calculating the maximum.
Note each client connects in a single thread and so is potentially blocked.
Each round / game will reset and so this runs indefinitiely.
"""
from Compiler.types import sint, regint, Array, Matrix, MemValue
from Compiler.instructions import listen, acceptclientconnection
from Compiler.library import print_ln, do_while, if_e, else_, for_range
from Compiler.util import if_else
PORTNUM = 14000
MAX_NUM_CLIENTS = 8
n_rounds = 0
if len(program.args) > 1:
n_rounds = int(program.args[1])
def accept_client():
client_socket_id = regint()
acceptclientconnection(client_socket_id, PORTNUM)
last = regint.read_from_socket(client_socket_id)
# Crypto setup
public_signing_key = regint.read_from_socket(client_socket_id, 8)
public_key = regint.read_client_public_key(client_socket_id)
regint.resp_secure_socket(client_socket_id,*public_signing_key)
return client_socket_id, last
def client_input(client_socket_id):
"""
Send share of random value, receive input and deduce share.
"""
client_inputs = sint.receive_from_client(1, client_socket_id)
return client_inputs[0]
def determine_winner(number_clients, client_values, client_ids):
"""Work out and return client_id which corresponds to max client_value"""
max_value = Array(1, sint)
max_value[0] = client_values[0]
win_client_id = Array(1, sint)
win_client_id[0] = client_ids[0]
@for_range(number_clients-1)
def loop_body(i):
# Is this client input a new maximum, will be sint(1) if true, else sint(0)
is_new_max = max_value[0] < client_values[i+1]
# Keep latest max_value
max_value[0] = if_else(is_new_max, client_values[i+1], max_value[0])
# Keep current winning client id
win_client_id[0] = if_else(is_new_max, client_ids[i+1], win_client_id[0])
return win_client_id[0]
def write_winner_to_clients(sockets, number_clients, winning_client_id):
"""Send share of winning client id to all clients who joined game."""
# Setup authenticate result using share of random.
# client can validate ∑ winning_client_id * ∑ rnd_from_triple = ∑ auth_result
rnd_from_triple = sint.get_random_triple()[0]
auth_result = winning_client_id * rnd_from_triple
@for_range(number_clients)
def loop_body(i):
sint.write_shares_to_socket(sockets[i], [winning_client_id, rnd_from_triple, auth_result])
def main():
"""Listen in while loop for players to join a game.
Once maxiumum reached or have notified that round finished, run comparison and return result."""
# Start listening for client socket connections
listen(PORTNUM)
print_ln('Listening for client connections on base port %s', PORTNUM)
def game_loop(_=None):
print_ln('Starting a new round of the game.')
# Clients socket id (integer).
client_sockets = Array(MAX_NUM_CLIENTS, regint)
# Number of clients
number_clients = MemValue(regint(0))
# Clients secret input.
client_values = Array(MAX_NUM_CLIENTS, sint)
# Client ids to identity client
client_ids = Array(MAX_NUM_CLIENTS, sint)
# Keep track of received inputs
seen = Array(MAX_NUM_CLIENTS, regint)
seen.assign_all(0)
# Loop round waiting for each client to connect
@do_while
def client_connections():
client_id, last = accept_client()
@if_(client_id >= MAX_NUM_CLIENTS)
def _():
print_ln('client id too high')
crash()
client_sockets[client_id] = client_id
client_ids[client_id] = client_id
seen[client_id] = 1
@if_(last == 1)
def _():
number_clients.write(client_id + 1)
return (sum(seen) < number_clients) + (number_clients == 0)
@for_range(number_clients)
def _(client_id):
client_values[client_id] = client_input(client_id)
winning_client_id = determine_winner(number_clients, client_values, client_ids)
print_ln('Found winner, index: %s.', winning_client_id.reveal())
write_winner_to_clients(client_sockets, number_clients, winning_client_id)
return True
if n_rounds > 0:
print('run %d rounds' % n_rounds)
for_range(n_rounds)(game_loop)
else:
print('run forever')
do_while(game_loop)
main()

View File

@@ -1,5 +1,6 @@
import ml
import random
import re
program.use_trunc_pr = True
sfix.round_nearest = True
@@ -10,6 +11,13 @@ cfix.set_precision(16, 31)
N = int(program.args[1])
n_features = int(program.args[2])
n_threads = None
for arg in program.args:
m = re.match('n_threads=(.*)', arg)
if m:
n_threads = int(m.group(1))
program.allocated_mem['s'] = 1 + n_features
b = sfix.load_mem(0)
@@ -24,13 +32,15 @@ dense.W.assign_vector(W)
print_ln('b=%s W[-1]=%s', dense.b[0].reveal(),
dense.W[n_features - 1][0][0].reveal())
@for_range_opt(n_features)
@for_range_opt_multithread(n_threads, n_features)
def _(i):
@for_range_opt(N)
def _(j):
dense.X[j][0][i] = sfix.get_input_from(0)
dense.forward()
batch = regint.Array(N)
batch.assign(regint.inc(N))
dense.forward(batch)
print_str('predictions: ')

View File

@@ -1,28 +1,52 @@
import ml
import random
import re
program.use_trunc_pr = True
sfix.round_nearest = True
sfix.set_precision(16, 31)
cfix.set_precision(16, 31)
sfloat.vlen = sfix.f
n_epochs = 200
n_epochs = 100
n_normal = int(program.args[1])
n_pos = int(program.args[2])
n_features = int(program.args[3])
if 'approx' in program.args:
approx = 3
elif 'approx5' in program.args:
approx = 5
else:
approx = False
if 'split' in program.args:
program.use_split(3)
n_threads = None
for arg in program.args:
m = re.match('n_threads=(.*)', arg)
if m:
n_threads = int(m.group(1))
debug = 'debug' in program.args
ml.set_n_threads(n_threads)
n_examples = n_normal + n_pos
N = max(n_normal, n_pos) * 2
if 'mini' in program.args:
batch_size = 32
else:
batch_size = N
X_normal = sfix.Matrix(n_normal, n_features)
X_pos = sfix.Matrix(n_pos, n_features)
@for_range_opt(n_features)
@for_range_opt_multithread(n_threads, n_features)
def _(i):
@for_range_opt(n_normal)
def _(j):
@@ -32,11 +56,11 @@ def _(i):
X_pos[j][i] = sfix.get_input_from(0)
dense = ml.Dense(N, n_features, 1)
layers = [dense, ml.Output(N)]
layers = [dense, ml.Output(N, approx=approx)]
sgd = ml.SGD(layers, n_epochs, report_loss=debug)
sgd.reset([X_normal, X_pos])
sgd.run()
sgd.run(batch_size)
if debug:
@for_range(N)

View File

@@ -32,8 +32,6 @@ sgd = ml.SGD(layers, batch // 128 * 10 , debug=debug, report_loss=False)
sgd.reset([X_normal, X_pos])
sgd.run(batch_size=batch)
ml.approx_sigmoid.special = False
# @for_range(1000)
# def _(i):
# sgd.backward()

View File

@@ -82,7 +82,14 @@ if 'quant' in program.args:
else:
dense = ml.Dense(N, n_features, 1)
layers = [dense, ml.Output(N, debug=debug, approx='approx' in program.args)]
if 'approx' in program.args:
approx = 3
elif 'approx5' in program.args:
approx = 5
else:
approx = False
layers = [dense, ml.Output(N, debug=debug, approx=approx)]
Y = sfix.Array(n_examples)
X = sfix.Matrix(n_examples, n_features)

View File

@@ -97,8 +97,3 @@ test(c[0], 0)
test(c[1], 1)
test(c[2], 1)
test(c[3], 0)
k = 41
a = int(2.9142 * 2**k)
alpha = sbitint.get_type(2 * k)(a)
test(sbits.bit_compose((alpha >> 64).bit_decompose()[:64]), 0)

View File

@@ -127,12 +127,14 @@ public:
void pack(octetStream& os, bool full = true) const
{
(void)full;
FixedVec<T, 2>::pack(os);
if (full)
FixedVec<T, 2>::pack(os);
else
(*this)[0].pack(os);
}
void unpack(octetStream& os, bool full = true)
{
(void)full;
assert(full);
FixedVec<T, 2>::unpack(os);
}
};

View File

@@ -14,18 +14,14 @@ template <class T>
class PrepLessInput : public InputBase<T>
{
protected:
SubProcessor<T>* processor;
vector<T> shares;
size_t i_share;
public:
PrepLessInput(SubProcessor<T>* proc) :
InputBase<T>(proc ? proc->Proc : 0), processor(proc), i_share(0) {}
InputBase<T>(proc ? proc->Proc : 0), i_share(0) {}
virtual ~PrepLessInput() {}
void start(int player, int n_inputs);
void stop(int player, vector<int> targets);
virtual void reset(int player) = 0;
virtual void add_mine(const typename T::open_type& input,
int n_bits = -1) = 0;

View File

@@ -57,47 +57,6 @@ void ReplicatedInput<T>::exchange()
}
}
template<class T>
void PrepLessInput<T>::start(int player, int n_inputs)
{
assert(processor != 0);
auto& proc = *processor;
reset(player);
if (player == proc.P.my_num())
{
for (int i = 0; i < n_inputs; i++)
{
typename T::clear t;
this->buffer.input(t);
add_mine(t);
}
send_mine();
}
}
template<class T>
void PrepLessInput<T>::stop(int player, vector<int> targets)
{
assert(processor != 0);
auto& proc = *processor;
if (proc.P.my_num() == player)
{
for (unsigned int i = 0; i < targets.size(); i++)
proc.get_S_ref(targets[i]) = finalize_mine();
}
else
{
octetStream o;
this->timer.start();
proc.P.receive_player(player, o, true);
this->timer.stop();
for (unsigned int i = 0; i < targets.size(); i++)
finalize_other(player, proc.get_S_ref(targets[i]), o);
}
}
template<class T>
inline void ReplicatedInput<T>::finalize_other(int player, T& target,
octetStream& o, int n_bits)

View File

@@ -57,6 +57,11 @@ public:
return ShamirMachine::s().threshold;
}
static T get_rec_factor(int i, int n)
{
return Protocol::get_rec_factor(i, n);
}
static ShamirShare constant(T value, int my_num, const T& alphai = {})
{
return ShamirShare(value, my_num, alphai);
@@ -135,14 +140,17 @@ public:
throw runtime_error("never call this");
}
void pack(octetStream& os, bool full = true) const
void pack(octetStream& os, const T& rec_factor) const
{
(*this * rec_factor).pack(os);
}
void pack(octetStream& os) const
{
(void)full;
T::pack(os);
}
void unpack(octetStream& os, bool full = true)
{
(void)full;
assert(full);
T::unpack(os);
}
};

View File

@@ -24,6 +24,8 @@ public:
template<class T, class U>
static void split(vector<U>, vector<int>, int, T*, int, Player&)
{ throw runtime_error("split not implemented"); }
static bool get_rec_factor(int, int) { return false; }
};
#endif /* PROTOCOLS_SHAREINTERFACE_H_ */

View File

@@ -25,7 +25,9 @@ public:
}
void buffer_triples();
void buffer_squares();
void buffer_inverses();
void buffer_bits();
};
#endif /* PROTOCOLS_SOHOPREP_H_ */

View File

@@ -70,9 +70,58 @@ void SohoPrep<T>::buffer_triples()
ci.element(i)}});
}
template<class T>
void SohoPrep<T>::buffer_squares()
{
auto& proc = this->proc;
assert(proc != 0);
lock.lock();
if (not setup)
{
PlainPlayer P(proc->P.N, T::clear::type_char());
basic_setup(P);
}
lock.unlock();
Plaintext_<FD> ai(setup->FieldD);
SeededPRNG G;
ai.randomize(G);
Ciphertext Ca = setup->pk.encrypt(ai);
octetStream os;
Ca.pack(os);
for (int i = 1; i < proc->P.num_players(); i++)
{
proc->P.pass_around(os);
Ca.add<0>(os);
}
Ciphertext Cc = Ca.mul(setup->pk, Ca);
Plaintext_<FD> ci(setup->FieldD);
SimpleDistDecrypt<FD> dd(proc->P, *setup);
EncCommitBase_<FD> EC;
dd.reshare(ci, Cc, EC);
for (unsigned i = 0; i < ai.num_slots(); i++)
this->squares.push_back({{ai.element(i), ci.element(i)}});
}
template<class T>
void SohoPrep<T>::buffer_inverses()
{
assert(this->proc != 0);
::buffer_inverses(this->inverses, *this, this->proc->MC, this->proc->P);
}
template<>
void SohoPrep<SohoShare<gfp>>::buffer_bits()
{
buffer_bits_from_squares(*this);
}
template<>
void SohoPrep<SohoShare<gf2n_short>>::buffer_bits()
{
buffer_bits_without_check();
}

View File

@@ -9,7 +9,6 @@
#include "MascotPrep.h"
#include "RingOnlyPrep.h"
#include "Spdz2kShare.h"
#include "GC/TinySecret.h"
template<class T, class U>
void bits_from_square_in_ring(vector<T>& bits, int buffer_size, U* bit_prep);

View File

@@ -147,7 +147,7 @@ compute the preprocessing time for a particular computation.
required. This includes mainstream processors released 2014 or later.
For older models you need to deactivate the respective
extensions in the `ARCH` variable.
- To benchmark online-only protocols or Overdrive, add the following line at the top: `MY_CFLAGS = -DINSECURE`
- To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE`
- `PREP_DIR` should point to should be a local, unversioned directory to store preprocessing data (default is `Player-Data` in the current directory).
- For homomorphic encryption, set `USE_NTL = 1`.
@@ -240,6 +240,29 @@ al.](https://eprint.iacr.org/2020/338) You can activate them by using
`-Y` instead of `-X`. Note that this also activates classic daBits
when useful.
#### Bristol Fashion circuits
Bristol Fashion is the name of a description format of binary circuits
used by
[SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA). You can
access such circuits from the high-level language if they are present
in `Programs/Circuits`. To run the AES-128 circuit provided with
SCALE-MAMBA, you can run the following:
```
make Programs/Circuits
./compile.py aes_circuit
Scripts/semi.sh aes_circuit
```
This downloads the circuit, compiles it to MP-SPDZ bytecode, and runs
it as semi-honest two-party computation 1000 times in parallel. It
should then output the AES test vector
`0x3ad77bb40d7a3660a89ecaf32466ef97`. You can run it with any other
protocol as well.
See the
[documentation](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.circuit)
for further examples.
#### Compiling and running programs from external directories
Programs can also be edited, compiled and run from any directory with the above basic structure. So for a source file in `./Programs/Source/`, all SPDZ scripts must be run from `./`. The `setup-online.sh` script must also be run from `./` to create the relevant data. For example:

13
Scripts/setup-clients.sh Executable file
View File

@@ -0,0 +1,13 @@
#!/bin/bash
n=$1
test -e Player-Data || mkdir Player-Data
echo Setting up SSL for $n parties
for i in `seq 0 $[n-1]`; do
openssl req -newkey rsa -nodes -x509 -out Player-Data/C$i.pem -keyout Player-Data/C$i.key -subj "/CN=C$i"
done
c_rehash Player-Data

View File

@@ -1,6 +1,6 @@
#!/bin/bash
make -j4 ecdsa Fake-ECDSA.x
make -j4 ecdsa Fake-ECDSA.x secure.x
run()
{
@@ -19,8 +19,11 @@ for i in rep mal-rep shamir mal-shamir; do
run $i 2
done
./Fake-ECDSA.x
for i in semi mascot fake-spdz; do
for i in semi mascot; do
run $i 1
done
if ! ./secure.x; then
./Fake-ECDSA.x
run fake-spdz 1
fi

View File

@@ -1,107 +0,0 @@
// Client key file format:
// X25519 Public Key
// X25519 Secret Key
// Ed25519 Public Key
// Ed25519 Secret Key
// Server 1 X25519 Public Key
// Server 1 Ed25519 Public Key
// ...
// Server N Public Key
// Server N Ed25519 Public Key
//
// Player key file format:
// X25519 Public Key
// X25519 Secret Key
// Ed25519 Public Key
// Ed25519 Secret Key
// Number of clients [64 bit little endian]
// Client 1 X25519 Public Key
// Client 1 Ed25519 Public Key
// ...
// Client N X25519 Public Key
// Client N Ed25519 Public Key
// Number of servers [64 bit little endian]
// Server 1 X25519 Public Key
// Server 1 Ed25519 Public Key
// ...
// Server N X25519 Public Key
// Server N Ed25519 Public Key
#include "Tools/octetStream.h"
#include "Networking/Player.h"
#include "Math/gf2n.h"
#include "Config.h"
#include <sodium.h>
#include <vector>
#include <iomanip>
namespace Config {
static void output(const vector<octet> &vec, ofstream &of)
{
copy(vec.begin(), vec.end(), ostreambuf_iterator<char>(of));
}
void putW64le(ofstream &outf, uint64_t nr)
{
char buf[8];
for(int i=0;i<8;i++) {
char byte = (uint8_t)(nr >> (i*8));
buf[i] = (char)byte;
}
outf.write(buf,sizeof buf);
}
void write_player_config_file(string config_dir
,int player_number, public_key my_pub, secret_key my_priv
, public_signing_key my_signing_pub, secret_signing_key my_signing_priv
, vector<public_key> client_pubs, vector<public_signing_key> client_signing_pubs
, vector<public_key> player_pubs, vector<public_signing_key> player_signing_pubs)
{
stringstream filename;
filename << config_dir << "Player-SPDZ-Keys-P" << player_number;
ofstream outf(filename.str().c_str(), ios::out | ios::binary);
if (outf.fail())
throw file_error(filename.str().c_str());
if(crypto_box_PUBLICKEYBYTES != my_pub.size() ||
crypto_box_SECRETKEYBYTES != my_priv.size() ||
crypto_sign_PUBLICKEYBYTES != my_signing_pub.size() ||
crypto_sign_SECRETKEYBYTES != my_signing_priv.size()) {
throw "Invalid key sizes";
} else if(client_pubs.size() != client_signing_pubs.size()) {
throw "Incorrect number of client keys";
} else if(player_pubs.size() != player_signing_pubs.size()) {
throw "Incorrect number of player keys";
} else {
for(size_t i = 0; i < client_pubs.size(); i++) {
if(crypto_box_PUBLICKEYBYTES != client_pubs[i].size() ||
crypto_sign_PUBLICKEYBYTES != client_signing_pubs[i].size()) {
throw "Incorrect size of client key.";
}
}
for(size_t i = 0; i < player_pubs.size(); i++) {
if(crypto_box_PUBLICKEYBYTES != player_pubs[i].size() ||
crypto_sign_PUBLICKEYBYTES != player_signing_pubs[i].size()) {
throw "Incorrect size of player key.";
}
}
}
// Write public and secret X25519 keys
output(my_pub, outf);
output(my_priv, outf);
output(my_signing_pub, outf);
output(my_signing_priv, outf);
putW64le(outf, (uint64_t)client_pubs.size());
// Write all client public keys
for (size_t j = 0; j < client_pubs.size(); j++) {
output(client_pubs[j], outf);
output(client_signing_pubs[j], outf);
}
putW64le(outf, (uint64_t)player_pubs.size());
for (size_t j = 0; j < player_pubs.size(); j++) {
output(player_pubs[j], outf);
output(player_signing_pubs[j], outf);
}
outf.flush();
outf.close();
}
}

View File

@@ -1,15 +0,0 @@
#include "Tools/octetStream.h"
#include "Networking/Player.h"
#include <sodium.h>
namespace Config {
typedef vector<octet> public_key;
typedef vector<octet> public_signing_key;
typedef vector<octet> secret_key;
typedef vector<octet> secret_signing_key;
void write_player_config_file(string config_dir
,int player_number, public_key my_pub, secret_key my_priv
, public_signing_key my_signing_pub, secret_signing_key my_signing_priv
, vector<public_key> client_pubs, vector<public_signing_key> client_signing_pubs
, vector<public_key> player_pubs, vector<public_signing_key> player_signing_pubs);
void putW64le(ofstream &outf, uint64_t nr);
}

View File

@@ -105,9 +105,7 @@ bigint octetStream::check_sum(int req_bytes) const
bool octetStream::equals(const octetStream& a) const
{
if (len!=a.len) { return false; }
for (size_t i=0; i<len; i++)
{ if (data[i]!=a.data[i]) { return false; } }
return true;
return memcmp(data, a.data, len) == 0;
}
@@ -267,106 +265,6 @@ void octetStream::exchange(T send_socket, T receive_socket, octetStream& receive
}
void octetStream::store(const vector<int>& v)
{
store(v.size());
for (int x : v)
store(x);
}
void octetStream::get(vector<int>& v)
{
size_t size;
get(size);
v.resize(size);
for (int& x : v)
get(x);
}
// Construct the ciphertext as `crypto_secretbox(pt, counter||random)`
void octetStream::encrypt_sequence(const octet* key, uint64_t counter)
{
octet nonce[crypto_secretbox_NONCEBYTES];
int i;
int message_len_bytes = len;
randombytes_buf(nonce, sizeof nonce);
if(counter == UINT64_MAX) {
throw Processor_Error("Encryption would overflow counter. Too many messages.");
} else {
counter++;
}
for(i=0; i<8; i++) {
nonce[i] = uint8_t ((counter >> (8*i)) & 0xFF);
}
int ciphertext_len = message_len_bytes + crypto_secretbox_MACBYTES;
octet ciphertext[ciphertext_len];
crypto_secretbox_easy(ciphertext, data, message_len_bytes, nonce, key);
// append the ciphertext to an empty octet stream
reset_read_head();
reset_write_head();
append(ciphertext, ciphertext_len*sizeof(octet));
// append the nonce
append(nonce, crypto_secretbox_NONCEBYTES * sizeof(octet));
}
void octetStream::decrypt_sequence(const octet* key, uint64_t counter)
{
int ciphertext_len = len - crypto_box_NONCEBYTES;
const octet *nonce = data + ciphertext_len;
int i;
uint64_t recvCounter=0;
// Numbers are typically 24U + 16U so cast to int is safe.
if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES))
{
throw Processor_Error("Cannot decrypt octetStream: ciphertext too short");
}
for(i=7; i>=0; i--) {
recvCounter |= ((uint64_t) *(nonce + i)) << (i*8);
}
if(recvCounter != counter + 1) {
throw Processor_Error("Incorrect counter on stream. Possible MITM.");
}
if (crypto_secretbox_open_easy(data, data, ciphertext_len, nonce, key) != 0)
{
throw Processor_Error("octetStream decryption failed!");
}
rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES);
//prepare for unpack after decryption by resetting the read head
reset_read_head();
}
void octetStream::encrypt(const octet* key)
{
octet nonce[crypto_secretbox_NONCEBYTES];
randombytes_buf(nonce, sizeof nonce);
int message_len_bytes = len;
resize(len + crypto_secretbox_MACBYTES + crypto_secretbox_NONCEBYTES);
// Encrypt data in-place
crypto_secretbox_easy(data, data, message_len_bytes, nonce, key);
// Adjust length to account for MAC, then append nonce
len += crypto_secretbox_MACBYTES;
append(nonce, sizeof nonce);
}
void octetStream::decrypt(const octet* key)
{
int ciphertext_len = len - crypto_box_NONCEBYTES;
// Numbers are typically 24U + 16U so cast to int is safe.
if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES))
{
throw Processor_Error("Cannot decrypt octetStream: ciphertext too short");
}
if (crypto_secretbox_open_easy(data, data, ciphertext_len, data + ciphertext_len, key) != 0)
{
throw Processor_Error("octetStream decryption failed!");
}
rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES);
}
void octetStream::input(istream& s)
{
size_t size;

View File

@@ -127,6 +127,8 @@ class octetStream
template<class T>
T get();
template<class T>
void get(T& ans);
// works for all statically allocated types
template <class T>
@@ -134,8 +136,10 @@ class octetStream
template <class T>
void unserialize(T& x) { consume((octet*)&x, sizeof(x)); }
void store(const vector<int>& v);
void get(vector<int>& v);
template <class T>
void store(const vector<T>& v);
template <class T>
void get(vector<T>& v);
void consume(octetStream& s,size_t l)
{ s.resize(l);
@@ -147,20 +151,8 @@ class octetStream
void Send(T socket_num) const;
template<class T>
void Receive(T socket_num);
void ReceiveExpected(int socket_num, size_t expected);
// In-place authenticated encryption using sodium; key of length crypto_generichash_BYTES
// ciphertext = Enc(message) | MAC | counter
//
// This is much like 'encrypt' but uses a deterministic counter for the nonce,
// allowing enforcement of message order.
void encrypt_sequence(const octet* key, uint64_t counter);
void decrypt_sequence(const octet* key, uint64_t counter);
// In-place authenticated encryption using sodium; key of length crypto_secretbox_KEYBYTES
// ciphertext = Enc(message) | MAC | nonce
void encrypt(const octet* key);
void decrypt(const octet* key);
template<class T>
void ReceiveExpected(T socket_num, size_t expected);
void input(istream& s);
void output(ostream& s);
@@ -278,7 +270,8 @@ inline void octetStream::Receive(T socket_num)
reset_read_head();
}
inline void octetStream::ReceiveExpected(int socket_num, size_t expected)
template<class T>
inline void octetStream::ReceiveExpected(T socket_num, size_t expected)
{
size_t nlen = 0;
receive(socket_num, nlen, LENGTH_SIZE);
@@ -310,11 +303,35 @@ T octetStream::get()
return res;
}
template<class T>
void octetStream::get(T& res)
{
res.unpack(*this);
}
template<>
inline int octetStream::get()
{
return get_int(sizeof(int));
}
template<class T>
void octetStream::store(const vector<T>& v)
{
store(v.size());
for (auto& x : v)
store(x);
}
template<class T>
void octetStream::get(vector<T>& v)
{
size_t size;
get(size);
v.resize(size);
for (auto& x : v)
get(x);
}
#endif

View File

@@ -3,6 +3,7 @@
#include "Math/bigint.h"
#include "Math/fixint.h"
#include "Math/Z2k.hpp"
#include "Math/gfp.h"
#include "Tools/Subroutines.h"
#include <stdio.h>
#include <sodium.h>

View File

@@ -1,178 +0,0 @@
// Preprocessing stage to:
// Create the public/private key pairs for each client
// Create the public/private key pairs for each spdz engine
// For each client store the client keys + all spdz engine public keys
// in a file named Client-Keys-C<client id>
// For each spdz engine store the spdz engine keys + all client public keys
// in a file named Player-SPDZ-Keys-P<player id>
//
#include <sodium.h>
#include "Math/gf2n.h"
#include "Math/gfp.h"
#include "Protocols/Share.h"
#include "Math/Setup.h"
#include "Protocols/fake-stuff.h"
#include "Exceptions/Exceptions.h"
#include "Math/Setup.h"
#include "Processor/Data_Files.h"
#include "Tools/mkpath.h"
#include "Tools/ezOptionParser.h"
#include "Tools/Config.h"
#include <sstream>
#include <fstream>
using namespace std;
static void output(const vector<octet> &vec, ofstream &of)
{
copy(vec.begin(), vec.end(), ostreambuf_iterator<char>(of));
}
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
opt.syntax = "./client-setup.x <nplayers> [OPTIONS]\n";
opt.add(
"0", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Number of external clients (default: nplayers)", // Help description.
"-nc", // Flag token.
"--numclients" // Flag token.
);
opt.add(
"128", // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Bit length of GF(p) field (default: 128)", // Help description.
"-lgp", // Flag token.
"--lgp" // Flag token.
);
opt.add(
to_string(gf2n::default_degree()).c_str(), // Default.
0, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
("Bit length of GF(2^n) field (default: " + to_string(gf2n::default_degree()) + ")").c_str(), // Help description.
"-lg2", // Flag token.
"--lg2" // Flag token.
);
opt.parse(argc, argv);
string prep_data_prefix;
string usage;
int nplayers;
if (opt.firstArgs.size() == 2)
{
nplayers = atoi(opt.firstArgs[1]->c_str());
}
else if (opt.lastArgs.size() == 1)
{
nplayers = atoi(opt.lastArgs[0]->c_str());
}
else
{
cerr << "ERROR: invalid number of arguments\n";
opt.getUsage(usage);
cout << usage;
return 1;
}
int lg2, lgp, nclients;
opt.get("--numclients")->getInt(nclients);
if (nclients <= 0)
nclients = nplayers;
opt.get("--lgp")->getInt(lgp);
opt.get("--lg2")->getInt(lg2);
cout << "nplayers = " << nplayers << endl;
cout << "nclients = " << nclients << endl;
cout << "lgp = " << lgp << endl;
cout << "lgp2 = " << lg2 << endl;
prep_data_prefix = get_prep_dir(nplayers, lgp, lg2);
cout << "prep dir = " << prep_data_prefix << endl;
vector<Config::public_key> client_publickeys;
vector<Config::secret_key> client_secretkeys;
client_publickeys.resize(nclients);
client_secretkeys.resize(nclients);
for (int i = 0; i < nclients; i++) {
client_secretkeys[i].resize(crypto_box_SECRETKEYBYTES);
client_publickeys[i].resize(crypto_box_PUBLICKEYBYTES);
randombytes_buf(&client_secretkeys[i][0], client_secretkeys[i].size());
crypto_scalarmult_base(&client_publickeys[i][0], &client_secretkeys[i][0]);
}
vector<Config::public_signing_key> client_signing_publickeys;
vector<Config::secret_signing_key> client_signing_secretkeys;
client_signing_publickeys.resize(nclients);
client_signing_secretkeys.resize(nclients);
for (int i = 0; i < nclients; i++) {
client_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES);
client_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES);
crypto_sign_keypair(&client_signing_publickeys[i][0], &client_signing_secretkeys[i][0]);
}
vector<Config::public_key> server_publickeys;
vector<Config::secret_key> server_secretkeys;
server_publickeys.resize(nplayers);
server_secretkeys.resize(nplayers);
for (int i = 0; i < nplayers; i++) {
server_publickeys[i].resize(crypto_box_PUBLICKEYBYTES);
server_secretkeys[i].resize(crypto_box_SECRETKEYBYTES);
randombytes_buf(&server_secretkeys[i][0], server_secretkeys[i].size());
crypto_scalarmult_base(&server_publickeys[i][0], &server_secretkeys[i][0]);
}
vector<Config::public_signing_key> server_signing_publickeys;
vector<Config::secret_signing_key> server_signing_secretkeys;
server_signing_publickeys.resize(nplayers);
server_signing_secretkeys.resize(nplayers);
for (int i = 0; i < nplayers; i++) {
server_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES);
server_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES);
crypto_sign_keypair(&server_signing_publickeys[i][0], &server_signing_secretkeys[i][0]);
}
/* Write client files */
for (int i = 0; i < nclients; i++) {
stringstream filename;
filename << prep_data_prefix << "Client-Keys-C" << i;
ofstream outf(filename.str().c_str());
if (outf.fail())
throw file_error(filename.str().c_str());
// Write public key and secret key
output(client_publickeys[i],outf);
output(client_secretkeys[i],outf);
output(client_signing_publickeys[i],outf);
output(client_signing_secretkeys[i],outf);
int keycount = 2;
// Write all spdz engine public keys
for (int j = 0; j < nplayers; j++) {
output(server_publickeys[j], outf);
output(server_signing_publickeys[j], outf);
keycount++;
}
outf.close();
cout << "Wrote " << keycount << " keys to " << filename.str() << endl;
}
/* Write spdz engine files */
for (int i = 0; i < nplayers; i++) {
Config::write_player_config_file( prep_data_prefix, i
, server_publickeys[i], server_secretkeys[i]
, server_signing_publickeys[i], server_signing_secretkeys[i]
, client_publickeys, client_signing_publickeys
, server_publickeys, server_signing_publickeys);
}
}

24
azure-pipelines.yml Normal file
View File

@@ -0,0 +1,24 @@
# C/C++ with GCC
# Build your C/C++ project with GCC using make.
# Add steps that publish test results, save build artifacts, deploy, and more:
# https://docs.microsoft.com/azure/devops/pipelines/apps/c-cpp/gcc
trigger:
- master
pool:
vmImage: 'ubuntu-latest'
steps:
- script: |
bash -c "sudo apt-get install libsodium-dev libntl-dev yasm texinfo libboost-dev libboost-thread-dev python3-gmpy2 libcrypto++-dev python-networkx"
- script: |
make mpir
- script:
echo USE_NTL=1 >> CONFIG.mine
- script: |
make
- script:
Scripts/setup-ssl.sh
- script:
Scripts/test_tutorial.sh -C

View File

@@ -26,24 +26,22 @@ def main():
help="specify output file")
parser.add_option("-a", "--asm-output", dest="asmoutfile",
help="asm output file for debugging")
parser.add_option("-l", "--asm-input", action="store_true", dest="assemblymode",
help="old-style asm input")
parser.add_option("-p", "--primesize", dest="param", default=-1,
help="bit length of modulus")
parser.add_option("-g", "--galoissize", dest="galois", default=40,
help="bit length of Galois field")
parser.add_option("-d", "--debug", action="store_true", dest="debug",
help="keep track of trace for debugging")
parser.add_option("-e", "--emulate", action="store_true", dest="emulate", default=False,
help="emulate register values for debugging")
parser.add_option("-c", "--comparison", dest="comparison", default="log",
help="comparison variant: log|plain|inv|sinv")
parser.add_option("-r", "--noreorder", dest="reorder_between_opens",
action="store_false", default=True,
help="don't attempt to place instructions between start/stop opens")
parser.add_option("-O", "--optimize-hard", action="store_false",
parser.add_option("-M", "--preserve-mem-order", action="store_true",
dest="preserve_mem_order", default=False,
help="don't preserve order of memory instructions; possible loss of correctness")
help="preserve order of memory instructions; possible efficiency loss")
parser.add_option("-O", "--optimize-hard", action="store_true",
dest="optimize_hard", help="currently not in use")
parser.add_option("-u", "--noreallocate", action="store_true", dest="noreallocate",
default=False, help="don't reallocate")
parser.add_option("-m", "--max-parallel-open", dest="max_parallel_open",
@@ -79,10 +77,13 @@ def main():
parser.print_help()
return
if options.optimize_hard:
print('Note that -O/--optimize-hard currently has no effect')
def compilation():
prog = Compiler.run(args, options, param=int(options.param),
merge_opens=options.merge_opens, emulate=options.emulate,
assemblymode=options.assemblymode, debug=options.debug)
merge_opens=options.merge_opens,
debug=options.debug)
prog.write_bytes(options.outfile)
if options.asmoutfile:

View File

@@ -56,3 +56,9 @@ Compiler.ml module
-------------------------
.. automodule:: Compiler.ml
Compiler.circuit module
-----------------------
.. automodule:: Compiler.circuit
:members: