Release to add compiler instructions for external client I/O.

This commit is contained in:
Jonathan Evans
2017-09-14 10:35:01 +01:00
parent 3ee4e9e18c
commit 987a78286f
172 changed files with 4325 additions and 859 deletions

39
.gitignore vendored
View File

@@ -4,10 +4,42 @@ Player-Data/*
Prep-Data/*
logs/*
Language-Definition/main.pdf
keys/*
# Personal CONFIG file #
##############################
CONFIG.mine
config_mine.py
# Temporary files #
###################
*.bak
*.orig
*.rej
*.tmp
callgrind.out.*
# Vim
.*.swp
tags
# Eclipse #
###########
.project
.cproject
.settings
# VS Code IDE #
###############
.vscode/**
# Temporary files #
###################
*.bak
*.orig
*.rej
*.tmp
callgrind.out.*
# Compiled source #
###################
@@ -25,6 +57,8 @@ Programs/Public-Input/*
*.bc
*.sch
*.a
*.static
*.d
# Packages #
############
@@ -59,6 +93,8 @@ Programs/Public-Input/*
*.log
*.sql
*.sqlite
*.data
Persistence/*
# OS generated files #
######################
@@ -69,4 +105,5 @@ Programs/Public-Input/*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
Thumbs.db
**/*.x.dSYM/**

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Auth/MAC_Check.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _MAC_Check
#define _MAC_Check

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Auth/Subroutines.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _Subroutines
#define _Subroutines

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Summer.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Summer.h

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Math/gf2n.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _fake_stuff

56
CHANGELOG.md Normal file
View File

@@ -0,0 +1,56 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here.
## 0.0.2 (Sep 13, 2017)
### Support sockets based external client input and output to a SPDZ MPC program.
See the [ExternalIO directory](./ExternalIO/README.md) for more details and examples.
Note that [libsodium](https://download.libsodium.org/doc/) is now a dependency on the SPDZ build.
Added compiler instructions:
* LISTEN
* ACCEPTCLIENTCONNECTION
* CONNECTIPV4
* WRITESOCKETSHARE
* WRITESOCKETINT
Removed instructions:
* OPENSOCKET
* CLOSESOCKET
Modified instructions:
* READSOCKETC
* READSOCKETS
* READSOCKETINT
* WRITESOCKETC
* WRITESOCKETS
Support secure external client input and output with new instructions:
* READCLIENTPUBLICKEY
* INITSECURESOCKET
* RESPSECURESOCKET
### Read/Write secret shares to disk to support persistence in a SPDZ MPC program.
Added compiler instructions:
* READFILESHARE
* WRITEFILESHARE
### Other instructions
Added compiler instructions:
* DIGESTC - Clear truncated hash computation
* PRINTINT - Print register value
## 0.0.1 (Sep 2, 2016)
### Initial Release
* See `README.md` and `tutorial.md`.

6
CONFIG
View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
ROOT = .
@@ -28,7 +28,7 @@ endif
# Default is 3, which suffices for 128-bit p
# MOD = -DMAX_MOD_SZ=3
LDLIBS = -lmpirxx -lmpir $(MY_LDLIBS) -lm -lpthread
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) -lm -lpthread
ifeq ($(USE_NTL),1)
LDLIBS := -lntl $(LDLIBS)
@@ -40,7 +40,7 @@ LDLIBS += -lrt
endif
CXX = g++
CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH)
CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH) --std=c++11 -Werror
CPPFLAGS = $(CFLAGS)
LD = g++

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Check-Offline.cpp
@@ -62,21 +62,27 @@ void check_bits(const T& key,int N,vector<Data_Files*>& dataF,DataFieldType fiel
vector<Share<T> > Sa(N),Sb(N),Sc(N);
int n = 0;
while (!dataF[0]->eof<T>(DATA_BIT))
{
for (int i = 0; i < N; i++)
dataF[i]->get_one(field_type, DATA_BIT, Sa[i]);
check_share(Sa, a, mac, N, key);
if (!(a.is_zero() || a.is_one()))
try {
while (!dataF[0]->eof<T>(DATA_BIT))
{
cout << n << ": " << a << " neither 0 or 1" << endl;
throw bad_value();
}
n++;
}
for (int i = 0; i < N; i++)
dataF[i]->get_one(field_type, DATA_BIT, Sa[i]);
check_share(Sa, a, mac, N, key);
cout << n << " bits of type " << T::type_string() << endl;
if (!(a.is_zero() || a.is_one()))
{
cout << n << ": " << a << " neither 0 or 1" << endl;
throw bad_value();
}
n++;
}
cout << n << " bits of type " << T::type_string() << endl;
}
catch (exception& e)
{
cout << "Error with bits of type " << T::type_string() << endl;
}
}
template<class T>
@@ -85,20 +91,26 @@ void check_inputs(const T& key,int N,vector<Data_Files*>& dataF)
T a, mac, x;
vector< Share<T> > Sa(N);
for (int player = 0; player < N; player++)
{
int n = 0;
while (!dataF[0]->input_eof<T>(player))
{
for (int i = 0; i < N; i++)
dataF[i]->get_input(Sa[i], x, player);
check_share(Sa, a, mac, N, key);
if (!a.equal(x))
throw bad_value();
n++;
}
cout << n << " input masks for player " << player << " of type " << T::type_string() << endl;
}
try {
for (int player = 0; player < N; player++)
{
int n = 0;
while (!dataF[0]->input_eof<T>(player))
{
for (int i = 0; i < N; i++)
dataF[i]->get_input(Sa[i], x, player);
check_share(Sa, a, mac, N, key);
if (!a.equal(x))
throw bad_value();
n++;
}
cout << n << " input masks for player " << player << " of type " << T::type_string() << endl;
}
}
catch (exception& e)
{
cout << "Error with inputs of type " << T::type_string() << endl;
}
}
int main(int argc, const char** argv)

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import compilerLib, program, instructions, types, library, floatingpoint
import inspect

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import itertools, time
from collections import defaultdict, deque
@@ -11,20 +11,20 @@ import Compiler.graph
import Compiler.program
import heapq, itertools
import operator
import sys
class StraightlineAllocator:
"""Allocate variables in a straightline program using n registers.
It is based on the precondition that every register is only defined once."""
def __init__(self, n):
self.free = defaultdict(set)
self.alloc = {}
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
self.defined = {}
self.dealloc = set()
self.n = n
def alloc_reg(self, reg, persistent_allocation):
def alloc_reg(self, reg, free):
base = reg.vectorbase
if base in self.alloc:
# already allocated
@@ -32,8 +32,8 @@ class StraightlineAllocator:
reg_type = reg.reg_type
size = base.size
if not persistent_allocation and self.free[reg_type, size]:
res = self.free[reg_type, size].pop()
if free[reg_type, size]:
res = free[reg_type, size].pop()
else:
if self.usage[reg_type] < self.n:
res = self.usage[reg_type]
@@ -48,7 +48,7 @@ class StraightlineAllocator:
else:
base.i = self.alloc[base]
def dealloc_reg(self, reg, inst):
def dealloc_reg(self, reg, inst, free):
self.dealloc.add(reg)
base = reg.vectorbase
@@ -57,14 +57,14 @@ class StraightlineAllocator:
if i not in self.dealloc:
# not all vector elements ready for deallocation
return
self.free[reg.reg_type, base.size].add(self.alloc[base])
free[reg.reg_type, base.size].add(self.alloc[base])
if inst.is_vec() and base.vector:
for i in base.vector:
self.defined[i] = inst
else:
self.defined[reg] = inst
def process(self, program, persistent_allocation=False):
def process(self, program, alloc_pool):
for k,i in enumerate(reversed(program)):
unused_regs = []
for j in i.get_def():
@@ -75,7 +75,7 @@ class StraightlineAllocator:
(j,i,format_trace(i.caller)))
else:
# unused register
self.alloc_reg(j, persistent_allocation)
self.alloc_reg(j, alloc_pool)
unused_regs.append(j)
if unused_regs and len(unused_regs) == len(i.get_def()):
# only report if all assigned registers are unused
@@ -83,9 +83,9 @@ class StraightlineAllocator:
(unused_regs,i,format_trace(i.caller))
for j in i.get_used():
self.alloc_reg(j, persistent_allocation)
self.alloc_reg(j, alloc_pool)
for j in i.get_def():
self.dealloc_reg(j, i)
self.dealloc_reg(j, i, alloc_pool)
if k % 1000000 == 0 and k > 0:
print "Allocated registers for %d instructions at" % k, time.asctime()
@@ -98,7 +98,7 @@ class StraightlineAllocator:
return self.usage
def determine_scope(block):
def determine_scope(block, options):
last_def = defaultdict(lambda: -1)
used_from_scope = set()
@@ -120,12 +120,16 @@ def determine_scope(block):
print '\tline %d: %s' % (n, instr)
print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t')
print '\tregister trace: %s' % format_trace(reg.caller, '\t\t')
if options.stop:
sys.exit(1)
def write(reg, n):
if last_def[reg] != -1:
print 'Warning: double write at register', reg
print '\tline %d: %s' % (n, instr)
print '\ttrace: %s' % format_trace(instr.caller, '\t\t')
if options.stop:
sys.exit(1)
last_def[reg] = n
for n,instr in enumerate(block.instructions):

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
"""
Functions for secure comparison of GF(p) types.
@@ -68,10 +68,10 @@ def divide_by_two(res, x):
""" Faster clear division by two using a cached value of 2^-1 mod p """
from program import Program
import types
tape = Program.prog.curr_block
if tape not in inverse_of_two:
inverse_of_two[tape] = types.cint(1) / 2
mulc(res, x, inverse_of_two[tape])
block = Program.prog.curr_block
if len(inverse_of_two) == 0 or block not in inverse_of_two:
inverse_of_two[block] = types.cint(1) / 2
mulc(res, x, inverse_of_two[block])
def LTZ(s, a, k, kappa):
"""

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from Compiler.program import Program
from Compiler.config import *

View File

@@ -1,5 +1,5 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from collections import defaultdict
#INIT_REG_MAX = 655360

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from Compiler.oram import *

View File

@@ -1,5 +1,5 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
class CompilerError(Exception):
"""Base class for compiler exceptions."""
pass

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from math import log, floor, ceil
from Compiler.instructions import *
@@ -404,8 +404,8 @@ def TruncPr(a, k, m, kappa=None):
return shift_two(a, m)
if kappa is None:
kappa = 40
kappa = 40
b = two_power(k-1) + a
r_prime, r_dprime = types.sint(), types.sint()
comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)],

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import heapq
from Compiler.exceptions import *

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import sys
import math

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
""" This module is for classes of actual assembly instructions.
@@ -446,6 +446,13 @@ class legendrec(base.Instruction):
code = base.opcodes['LEGENDREC']
arg_format = ['cw','c']
@base.vectorize
class digestc(base.Instruction):
r""" Clear truncated hash computation, $c_i = H(c_j)[bytes]$. """
__slots__ = []
code = base.opcodes['DIGESTC']
arg_format = ['cw','c','int']
###
### Bitwise operations
###
@@ -915,6 +922,11 @@ class print_float_plain(base.IOInstruction):
code = base.opcodes['PRINTFLOATPLAIN']
arg_format = ['c', 'c', 'c', 'c']
class print_int(base.IOInstruction):
r""" Print only the value of register \verb|ci| to stdout. """
__slots__ = []
code = base.opcodes['PRINTINT']
arg_format = ['ci']
class print_char(base.IOInstruction):
r""" Print a single character to stdout. """
@@ -952,43 +964,156 @@ class pubinput(base.PublicFileIOInstruction):
@base.vectorize
class readsocketc(base.IOInstruction):
"""Read an int from socket and store in register"""
"""Read a variable number of clear GF(p) values from socket for a specified client id and store in registers"""
__slots__ = []
code = base.opcodes['READSOCKETC']
arg_format = ['ciw', 'int']
arg_format = tools.chain(['ci'], itertools.repeat('cw'))
def has_var_args(self):
return True
@base.vectorize
class readsockets(base.IOInstruction):
"""Read a secret share + MAC from socket and store in register"""
"""Read a variable number of secret shares + MACs from socket for a client id and store in registers"""
__slots__ = []
code = base.opcodes['READSOCKETS']
arg_format = ['sw', 'int']
arg_format = tools.chain(['ci'], itertools.repeat('sw'))
def has_var_args(self):
return True
@base.vectorize
class readsocketint(base.IOInstruction):
"""Read variable number of 32-bit int from socket for a client id and store in registers"""
__slots__ = []
code = base.opcodes['READSOCKETINT']
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
def has_var_args(self):
return True
@base.vectorize
class writesocketc(base.IOInstruction):
"""Write int from register into socket"""
"""
Write a variable number of clear GF(p) values from registers into socket
for a specified client id, message_type
"""
__slots__ = []
code = base.opcodes['WRITESOCKETC']
arg_format = ['ci', 'int']
arg_format = tools.chain(['ci', 'int'], itertools.repeat('c'))
def has_var_args(self):
return True
@base.vectorize
class writesockets(base.IOInstruction):
"""Write secret share + MAC from register into socket"""
"""
Write a variable number of secret shares + MACs from registers into a socket
for a specified client id, message_type
"""
__slots__ = []
code = base.opcodes['WRITESOCKETS']
arg_format = ['s', 'int']
arg_format = tools.chain(['ci', 'int'], itertools.repeat('s'))
class opensocket(base.IOInstruction):
"""Open a server socket connection at the given port number"""
def has_var_args(self):
return True
@base.vectorize
class writesocketshare(base.IOInstruction):
"""
Write a variable number of secret shares (without MACs) from registers into socket
for a specified client id, message_type
"""
__slots__ = []
code = base.opcodes['OPENSOCKET']
code = base.opcodes['WRITESOCKETSHARE']
arg_format = tools.chain(['ci', 'int'], itertools.repeat('s'))
def has_var_args(self):
return True
@base.vectorize
class writesocketint(base.IOInstruction):
"""
Write a variable number of 32-bit ints from registers into socket
for a specified client id, message_type
"""
__slots__ = []
code = base.opcodes['WRITESOCKETINT']
arg_format = tools.chain(['ci', 'int'], itertools.repeat('ci'))
def has_var_args(self):
return True
class listen(base.IOInstruction):
"""Open a server socket on a party specific port number and listen for client connections (non-blocking)"""
__slots__ = []
code = base.opcodes['LISTEN']
arg_format = ['int']
class closesocket(base.IOInstruction):
"""Close a server socket connection"""
class acceptclientconnection(base.IOInstruction):
"""Wait for a connection at the given port and write socket handle to register """
__slots__ = []
code = base.opcodes['CLOSESOCKET']
arg_format = []
code = base.opcodes['ACCEPTCLIENTCONNECTION']
arg_format = ['ciw', 'int']
class connectipv4(base.IOInstruction):
"""Connect to server at IPv4 address in register \verb|cj| at given port. Write socket handle to register \verb|ci|"""
__slots__ = []
code = base.opcodes['CONNECTIPV4']
arg_format = ['ciw', 'ci', 'int']
class readclientpublickey(base.IOInstruction):
"""Read a client public key as 8 32-bit ints for a specified client id"""
__slots__ = []
code = base.opcodes['READCLIENTPUBLICKEY']
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
def has_var_args(self):
return True
class initsecuresocket(base.IOInstruction):
"""Read a client public key as 8 32-bit ints for a specified client id,
negotiate a shared key via STS and use it for replay resistant comms"""
__slots__ = []
code = base.opcodes['INITSECURESOCKET']
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
def has_var_args(self):
return True
class respsecuresocket(base.IOInstruction):
"""Read a client public key as 8 32-bit ints for a specified client id,
negotiate a shared key via STS and use it for replay resistant comms"""
__slots__ = []
code = base.opcodes['RESPSECURESOCKET']
arg_format = tools.chain(['ci'], itertools.repeat('ci'))
def has_var_args(self):
return True
class writesharestofile(base.IOInstruction):
"""Write shares to a file"""
__slots__ = []
code = base.opcodes['WRITEFILESHARE']
arg_format = itertools.repeat('s')
def has_var_args(self):
return True
class readsharesfromfile(base.IOInstruction):
"""
Read shares from a file. Pass in start posn, return finish posn, shares.
Finish posn will return:
-2 file not found
-1 eof reached
position in file after read finished
"""
__slots__ = []
code = base.opcodes['READFILESHARE']
arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw'))
def has_var_args(self):
return True
@base.gf2n
@base.vectorize
@@ -1173,7 +1298,7 @@ class gconvgf2n(base.Instruction):
@base.gf2n
@base.vectorize
class startopen(base.Instruction):
class startopen(base.VarArgsInstruction):
""" Start opening secret register $s_i$. """
__slots__ = []
code = base.opcodes['STARTOPEN']
@@ -1183,12 +1308,9 @@ class startopen(base.Instruction):
for arg in self.args[::-1]:
program.curr_block.open_queue.append(arg.value)
def has_var_args(self):
return True
@base.gf2n
@base.vectorize
class stopopen(base.Instruction):
class stopopen(base.VarArgsInstruction):
""" Store previous opened value in $c_i$. """
__slots__ = []
code = base.opcodes['STOPOPEN']
@@ -1198,9 +1320,6 @@ class stopopen(base.Instruction):
for arg in self.args:
arg.value = program.curr_block.open_queue.pop()
def has_var_args(self):
return True
###
### CISC-style instructions
###

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import itertools
from random import randint
@@ -78,6 +78,7 @@ opcodes = dict(
MODC = 0x36,
MODCI = 0x37,
LEGENDREC = 0x38,
DIGESTC = 0x39,
GMULBITC = 0x136,
GMULBITM = 0x137,
# Open
@@ -95,13 +96,18 @@ opcodes = dict(
# Input
INPUT = 0x60,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
READSOCKETS = 0x64,
WRITESOCKETC = 0x65,
WRITESOCKETS = 0x66,
OPENSOCKET = 0x67,
CLOSESOCKET = 0x68,
READSOCKETINT = 0x69,
WRITESOCKETINT = 0x6a,
WRITESOCKETSHARE = 0x6b,
LISTEN = 0x6c,
ACCEPTCLIENTCONNECTION = 0x6d,
CONNECTIPV4 = 0x6e,
READCLIENTPUBLICKEY = 0x6f,
# Bitwise logic
ANDC = 0x70,
XORC = 0x71,
@@ -131,6 +137,7 @@ opcodes = dict(
SUBINT = 0x9C,
MULINT = 0x9D,
DIVINT = 0x9E,
PRINTINT = 0x9F,
# Conversion
CONVINT = 0xC0,
CONVMODP = 0xC1,
@@ -149,8 +156,13 @@ opcodes = dict(
PRINTCHRINT = 0xBA,
PRINTSTRINT = 0xBB,
PRINTFLOATPLAIN = 0xBC,
WRITEFILESHARE = 0xBD,
READFILESHARE = 0xBE,
GBITDEC = 0x184,
GBITCOM = 0x185,
# Secure socket
INITSECURESOCKET = 0x1BA,
RESPSECURESOCKET = 0x1BB
)
@@ -329,13 +341,11 @@ class RegType(object):
@staticmethod
def create_dict(init_value_fn):
""" Create a dictionary with all the RegTypes as keys """
return {
RegType.ClearModp : init_value_fn(),
RegType.SecretModp : init_value_fn(),
RegType.ClearGF2N : init_value_fn(),
RegType.SecretGF2N : init_value_fn(),
RegType.ClearInt : init_value_fn(),
}
res = defaultdict(init_value_fn)
# initialization for legacy
for t in RegType.Types:
res[t]
return res
class ArgFormat(object):
@classmethod
@@ -481,7 +491,7 @@ class Instruction(object):
def get_encoding(self):
enc = int_to_bytes(self.get_code())
# add the number of registers to a start/stop open instruction
# add the number of registers if instruction flagged as has var args
if self.has_var_args():
enc += int_to_bytes(len(self.args))
for arg,format in zip(self.args, self.arg_format):
@@ -508,6 +518,8 @@ class Instruction(object):
except ArgumentError as e:
raise CompilerError('Invalid argument "%s" to instruction: %s'
% (e.arg, self) + '\n' + e.msg)
except KeyError as e:
raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
def get_used(self):
""" Return the set of registers that are read in this instruction. """
@@ -537,8 +549,15 @@ class Instruction(object):
def add_usage(self, req_node):
pass
# String version of instruction attempting to replicate encoded version
def __str__(self):
return self.__class__.__name__ + ' ' + self.get_pre_arg() + ', '.join(str(a) for a in self.args)
if self.has_var_args():
varargCount = str(len(self.args)) + ', '
else:
varargCount = ''
return self.__class__.__name__ + ' ' + self.get_pre_arg() + varargCount + ', '.join(str(a) for a in self.args)
def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
@@ -725,6 +744,11 @@ class JumpInstruction(Instruction):
return self.args[self.jump_arg]
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
class CISC(Instruction):
"""
Base class for a CISC instruction.

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat
from Compiler.instructions import *
@@ -72,9 +72,7 @@ def print_str(s, *args):
else:
val = args[i]
if isinstance(val, program.Tape.Register):
if val.reg_type == 'ci':
cint(val).print_reg_plain()
elif val.is_clear:
if val.is_clear:
val.print_reg_plain()
else:
raise CompilerError('Cannot print secret value:', args[i])
@@ -355,7 +353,7 @@ class FunctionBlock(Function):
parent_node = get_tape().req_node
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
block = get_tape().active_basicblock
block.persistent_allocation = True
block.alloc_pool = defaultdict(set)
del parent_node.children[-1]
self.node = get_tape().req_node
print 'Compiling function', self.name

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import random
import math
@@ -15,8 +15,6 @@ from Compiler import floatingpoint,comparison,permutation
from Compiler.util import *
sys.setrecursionlimit(1000000)
print_access = False
sint_bit_length = 6
max_demux_bits = 3
@@ -40,12 +38,6 @@ def maybe_stop_timer(n):
if detailed_timing:
stop_timer(n)
def reveal(a):
try:
return a.reveal()
except AttributeError:
return a
class Block(object):
def __init__(self, value, lengths):
self.value = self.value_type.hard_conv(value)
@@ -53,8 +45,7 @@ class Block(object):
def get_slice(self):
res = []
for length,start in zip(self.lengths, series(self.lengths)):
res.append(sum(b << i for i,b in \
enumerate(self.bits[start:start+length])))
res.append(util.bit_compose((self.bits[start:start+length])))
return res
def __repr__(self):
return '<' + str(self.value) + '>'
@@ -150,11 +141,17 @@ class gf2nBlock(Block):
self.value = self.lower + value * self.adjust + upper
return self
block_types = { sint: intBlock,
sgf2n: gf2nBlock,
}
def get_block(x, y, *args):
if isinstance(x, sgf2n) or isinstance(y, sgf2n):
return gf2nBlock(x, y, *args)
else:
return intBlock(x, y, *args)
for t in block_types:
if isinstance(x, t):
return block_types[t](x, y, *args)
elif isinstance(y, t):
return block_types[t](x, y, *args)
raise CompilerError('appropiate block type not found')
def get_bit(x, index, bit_length):
if isinstance(x, sgf2n):
@@ -242,14 +239,14 @@ class Value(object):
return (1 - self.empty) * (other == self.value)
return (1 - self.empty) * self.value.equal(other, length)
def reveal(self):
return Value(self.value.reveal(), self.empty.reveal())
return Value(reveal(self.value), reveal(self.empty))
def output(self):
@if_e(self.empty)
def f():
print_str('<>')
@else_
def f():
print_str('<%s>', self.value)
# @if_e(self.empty)
# def f():
# print_str('<>')
# @else_
# def f():
print_str('<%s:%s>', self.empty, self.value)
def __index__(self):
return int(self.value)
def __repr__(self):
@@ -344,12 +341,13 @@ class Entry(object):
def reveal(self):
return Entry(x.reveal() for x in self)
def output(self):
@if_e(self.is_empty)
def f():
print_str('{empty=%s}', self.is_empty)
@else_
def f():
print_str('{%s: %s}', self.v, self.x)
# @if_e(self.is_empty)
# def f():
# print_str('{empty=%s}', self.is_empty)
# @else_
# def f():
# print_str('{%s: %s}', self.v, self.x)\
print_str('{%s: %s,empty=%s}', self.v, self.x, self.is_empty)
class RefRAM(object):
""" RAM reference. """
@@ -362,8 +360,8 @@ class RefRAM(object):
crash()
self.size = oram.bucket_size
self.entry_type = oram.entry_type
self.l = [Array(self.size, t, array.address + \
index * oram.bucket_size) \
self.l = [t.dynamic_array(self.size, t, array.address + \
index * oram.bucket_size) \
for t,array in zip(self.entry_type,oram.ram.l)]
self.index = index
def init_mem(self, empty_entry):
@@ -410,7 +408,7 @@ class RefRAM(object):
Program.prog.curr_tape.start_new_basicblock()
return res
def output(self):
self.reveal().print_reg()
print_ln('%s', [x.reveal() for x in self])
def print_reg(self):
print_ln('listing of RAM at index %s', self.index)
Program.prog.curr_tape.start_new_basicblock()
@@ -428,7 +426,7 @@ class RAM(RefRAM):
#print_reg(cint(0), 'r in')
self.size = size
self.entry_type = entry_type
self.l = [Array(self.size, t) for t in entry_type]
self.l = [t.dynamic_array(self.size, t) for t in entry_type]
self.index = index
class AbstractORAM(object):
@@ -902,7 +900,7 @@ class List(object):
def __init__(self, size, value_type, value_length=1, init_rounds=None):
self.value_type = value_type
self.value_length = value_length
self.l = [Array(size, value_type) \
self.l = [value_type.dynamic_array(size, value_type) \
for i in range(value_length)]
for l in self.l:
l.assign_all(0)
@@ -1322,8 +1320,10 @@ def get_value_size(value_type):
""" Return element size. """
if value_type == sgf2n:
return Program.prog.galois_length
else:
elif value_type == sint:
return 127 - Program.prog.security
else:
return value_type.max_length
def get_parallel(index_size, value_type, value_length):
""" Returning the number of parallel readings feasible, based on
@@ -1410,7 +1410,7 @@ class PackedIndexStructure(object):
else:
self.l[i] = [0] * self.elements_per_block
time()
print_ln('packed ORAM init %s/%s', cint(i), real_init_rounds)
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
print 'index initialized, size', size
def translate_index(self, index):
""" Bit slicing *index* according parameters. Output is tuple
@@ -1425,18 +1425,17 @@ class PackedIndexStructure(object):
return 0, b, c
else:
return (index - rem) / self.entries_per_block, b, c
elif self.value_type == sgf2n:
else:
index_bits = bit_decompose(index, log2(self.size))
l1 = self.log_entries_per_element
l2 = self.log_entries_per_block
c = sum(bit << i for i,bit in enumerate(index_bits[:l1]))
b = sum(bit << i for i,bit in enumerate(index_bits[l1:l2]))
c = bit_compose(index_bits[:l1])
b = bit_compose(index_bits[l1:l2])
if self.small:
return 0, b, c
else:
a = sum(bit << i for i,bit in enumerate(index_bits[l2:]))
a = bit_compose(index_bits[l2:])
return a, b, c
else:
raise CompilerError('Cannot process indices of type', self.value_type)
class Slicer(object):
def __init__(self, pack, index):
@@ -1624,11 +1623,11 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty):
def test_oram(oram_type, N, value_type=sint, iterations=100):
oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0)
print 'initialized'
print_reg(cint(0), 'init')
print_ln('initialized')
stop_timer()
# synchronize
Program.prog.curr_tape.start_new_basicblock(name='sync')
sint(0).reveal()
value_type(0).reveal()
Program.prog.curr_tape.start_new_basicblock(name='sync')
start_timer()
#oram[value_type(0)] = -1

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
if '_Array' not in dir():
from oram import *
@@ -76,7 +76,10 @@ def XOR(a, b):
elif isinstance(a, sgf2n) or isinstance(b, sgf2n):
return a + b
else:
return a + b - 2*a*b
try:
return a ^ b
except TypeError:
return a + b - 2*a*b
def pow2_eq(a, i, bit_length=40):
""" Test for equality with 2**i, when a is a power of 2 (gf2n only)"""

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from random import randint
import math

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from Compiler.config import *
from Compiler.exceptions import *
@@ -65,6 +65,7 @@ class Program(object):
self.n_threads = 1
self.free_threads = set()
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w')
self.types = {}
Program.prog = self
self.reset_values()
@@ -230,7 +231,7 @@ class Program(object):
# runtime doesn't support 'new-style' parallelism yet
old_style = True
nonempty_tapes = [t for t in self.tapes if not t.is_empty()]
nonempty_tapes = [t for t in self.tapes]
sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name
sch_file = open(sch_filename, 'w')
@@ -327,12 +328,15 @@ class Program(object):
""" The basic block that is currently being created. """
return self.curr_tape.active_basicblock
def malloc(self, size, mem_type):
def malloc(self, size, mem_type, reg_type=None):
""" Allocate memory from the top """
if size == 0:
return
if isinstance(mem_type, type):
self.types[mem_type.reg_type] = mem_type
mem_type = mem_type.reg_type
elif reg_type is not None:
self.types[mem_type] = reg_type
key = size, mem_type
if self.free_mem_blocks[key]:
addr = self.free_mem_blocks[key].pop()
@@ -346,7 +350,8 @@ class Program(object):
def free(self, addr, mem_type):
""" Free memory """
if self.curr_block.persistent_allocation:
if self.curr_block.alloc_pool \
is not self.curr_tape.basicblocks[0].alloc_pool:
raise CompilerError('Cannot free memory within function block')
size = self.allocated_mem_blocks.pop((addr,mem_type))
self.free_mem_blocks[size,mem_type].add(addr)
@@ -354,10 +359,15 @@ class Program(object):
def finalize_memory(self):
import library
self.curr_tape.start_new_basicblock(None, 'memory-usage')
# reset register counter to 0
self.curr_tape.init_registers()
for mem_type,size in self.allocated_mem.items():
if size:
#print "Memory of type '%s' of size %d" % (mem_type, size)
library.load_mem(size - 1, mem_type)
if mem_type in self.types:
self.types[mem_type].load_mem(size - 1, mem_type)
else:
library.load_mem(size - 1, mem_type)
def public_input(self, x):
self.public_input_file.write('%s\n' % str(x))
@@ -407,9 +417,9 @@ class Tape:
self.children = []
if scope is not None:
scope.children.append(self)
self.persistent_allocation = scope.persistent_allocation
self.alloc_pool = scope.alloc_pool
else:
self.persistent_allocation = False
self.alloc_pool = defaultdict(set)
def new_reg(self, reg_type, size=None):
return self.parent.new_reg(reg_type, size=size)
@@ -511,7 +521,7 @@ class Tape:
print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)
for block in self.basicblocks:
al.determine_scope(block)
al.determine_scope(block, options)
# merge open instructions
# need to do this if there are several blocks
@@ -563,15 +573,15 @@ class Tape:
# allocate registers
reg_counts = self.count_regs()
if filter(lambda n: n > REG_MAX, reg_counts) and not options.noreallocate:
print 'Tape register usage:'
if not options.noreallocate:
print 'Tape register usage:', reg_counts
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
print 'Re-allocating...'
allocator = al.StraightlineAllocator(REG_MAX)
def alloc_loop(block):
for reg in block.used_from_scope:
allocator.alloc_reg(reg, block.persistent_allocation)
allocator.alloc_reg(reg, block.alloc_pool)
for child in block.children:
if child.instructions:
alloc_loop(child)
@@ -584,7 +594,7 @@ class Tape:
if isinstance(jump, (int,long)) and jump < 0 and \
block.exit_block.scope is not None:
alloc_loop(block.exit_block.scope)
allocator.process(block.instructions, block.persistent_allocation)
allocator.process(block.instructions, block.alloc_pool)
# offline data requirements
print 'Compile offline data requirements...'
@@ -614,10 +624,11 @@ class Tape:
if not self.is_empty():
# bit length requirement
self.basicblocks[-1].instructions.append(
Compiler.instructions.reqbl(self.req_bit_length['p'], add_to_prog=False))
self.basicblocks[-1].instructions.append(
Compiler.instructions.greqbl(self.req_bit_length['2'], add_to_prog=False))
for x in ('p', '2'):
if self.req_bit_length['p']:
self.basicblocks[-1].instructions.append(
Compiler.instructions.reqbl(self.req_bit_length['p'],
add_to_prog=False))
print 'Tape requires prime bit length', self.req_bit_length['p']
print 'Tape requires galois bit length', self.req_bit_length['2']

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import itertools

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
from Compiler.program import Tape
from Compiler.exceptions import *
@@ -11,6 +11,20 @@ import util
import operator
class ClientMessageType:
""" Enum to define type of message sent to external client. Each may be array of length n."""
# No client message type to be sent, for backwards compatibility - virtual machine relies on this value
NoType = 0
# 3 x sint x n
TripleShares = 1
# 1 x cint x n
ClearModpInt = 2
# 1 x regint x n
Int32 = 3
# 1 x cint (fixed point left shifted by precision) x n
ClearModpFix = 4
class MPCThread(object):
def __init__(self, target, name, args = [], runtime_arg = None):
""" Create a thread from a callable object. """
@@ -97,6 +111,10 @@ def read_mem_value(operation):
class _number(object):
@staticmethod
def bit_compose(bits):
return sum(b << i for i,b in enumerate(bits))
def square(self):
return self * self
@@ -152,7 +170,6 @@ class _gf2n(object):
else:
return tuple(t.conv(r) for r in res)
class _register(Tape.Register, _number):
MemValue = staticmethod(lambda value: MemValue(value))
@@ -340,6 +357,9 @@ class _clear(_register):
__rxor__ = __xor__
__ror__ = __or__
def reveal(self):
return self
class cint(_clear, _int):
" Clear mod p integer type. """
@@ -348,7 +368,25 @@ class cint(_clear, _int):
reg_type = 'c'
@vectorized_classmethod
def load_mem(cls, address):
def read_from_socket(cls, client_id, n=1):
res = [cls() for i in range(n)]
readsocketc(client_id, *res)
if n == 1:
return res[0]
else:
return res
@vectorize
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
writesocketc(client_id, message_type, self)
@vectorized_classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of modp integers to socket """
writesocketc(client_id, message_type, *values)
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
return cls._load_mem(address, ldmc, ldmci)
def store_in_mem(self, address):
@@ -464,6 +502,13 @@ class cint(_clear, _int):
legendrec(res, self)
return res
def digest(self, num_bytes):
res = cint()
digestc(res, self, num_bytes)
return res
class cgf2n(_clear, _gf2n):
__slots__ = []
@@ -478,7 +523,7 @@ class cgf2n(_clear, _gf2n):
return res
@vectorized_classmethod
def load_mem(cls, address):
def load_mem(cls, address, mem_type=None):
return cls._load_mem(address, gldmc, gldmci)
def store_in_mem(self, address):
@@ -560,7 +605,7 @@ class regint(_register, _int):
protectmemint(regint(start), regint(end))
@vectorized_classmethod
def load_mem(cls, address):
def load_mem(cls, address, mem_type=None):
return cls._load_mem(address, ldmint, ldminti)
def store_in_mem(self, address):
@@ -581,14 +626,40 @@ class regint(_register, _int):
return res
@vectorized_classmethod
def read_from_socket(cls):
res = cls()
readsocketc(res,0)
def read_from_socket(cls, client_id, n=1):
""" Receive n register values from socket """
res = [cls() for i in range(n)]
readsocketint(client_id, *res)
if n == 1:
return res[0]
else:
return res
@vectorized_classmethod
def read_client_public_key(cls, client_id):
""" Receive 8 register values from socket containing client public key."""
res = [cls() for i in range(8)]
readclientpublickey(client_id, *res)
return res
@vectorized_classmethod
def init_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8):
""" Use 8 register values containing client public key."""
initsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8)
@vectorized_classmethod
def resp_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8):
""" Receive 8 register values from socket containing client public key."""
respsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8)
@vectorize
def write_to_socket(self):
writesocketc(self,0)
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
writesocketint(client_id, message_type, self)
@vectorized_classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of integers to socket """
writesocketint(client_id, message_type, *values)
@vectorize_init
def __init__(self, val=None, size=None):
@@ -614,7 +685,11 @@ class regint(_register, _int):
elif isinstance(val, regint):
addint(self, val, regint(0))
else:
raise CompilerError("Cannot convert '%s' to integer" % type(val))
try:
val.to_regint(self)
except AttributeError:
raise CompilerError("Cannot convert '%s' to integer" % \
type(val))
@vectorize
@read_mem_value
@@ -652,10 +727,10 @@ class regint(_register, _int):
return self.int_op(other, divint, True)
def __mod__(self, other):
return cint(self) % other
return self - (self / other) * other
def __rmod__(self, other):
return other % cint(self)
return regint(other) % self
def __rpow__(self, other):
return other**cint(self)
@@ -679,10 +754,16 @@ class regint(_register, _int):
return 1 - (self < other)
def __lshift__(self, other):
return regint(cint(self) << other)
if isinstance(other, (int, long)):
return self * 2**other
else:
return regint(cint(self) << other)
def __rshift__(self, other):
return regint(cint(self) >> other)
if isinstance(other, (int, long)):
return self / 2**other
else:
return regint(cint(self) >> other)
def __rlshift__(self, other):
return regint(other << cint(self))
@@ -706,6 +787,31 @@ class regint(_register, _int):
def mod2m(self, *args, **kwargs):
return cint(self).mod2m(*args, **kwargs)
def bit_decompose(self, bit_length=None):
res = []
x = self
two = regint(2)
for i in range(bit_length or program.bit_length):
y = x / two
res.append(x - two * y)
x = y
return res
@staticmethod
def bit_compose(bits):
two = regint(2)
res = 0
for bit in reversed(bits):
res *= two
res += bit
return res
def reveal(self):
return self
def print_reg_plain(self):
print_int(self)
class _secret(_register):
__slots__ = []
@@ -875,18 +981,54 @@ class sint(_secret, _int):
stopinput(player, res)
return res
@classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of n values input by a client """
# send shares of a triple to client
triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n))))
sint.write_shares_to_socket(client_id, triples, message_type)
received = cint.read_from_socket(client_id, n)
y = [0] * n
for i in range(n):
y[i] = received[i] - triples[i * 3]
return y
@vectorized_classmethod
def read_from_socket(cls):
res = cls()
readsockets(res,0)
return res
def read_from_socket(cls, client_id, n=1):
""" Receive n shares and MAC shares from socket """
res = [cls() for i in range(n)]
readsockets(client_id, *res)
if n == 1:
return res[0]
else:
return res
@vectorize
def write_to_socket(self):
writesockets(self,0)
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
""" Send share and MAC share to socket """
writesockets(client_id, message_type, self)
@vectorized_classmethod
def load_mem(cls, address):
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of shares and MAC shares to socket """
writesockets(client_id, message_type, *values)
@vectorize
def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType):
""" Send only share to socket """
writesocketshare(client_id, message_type, self)
@vectorized_classmethod
def write_shares_to_socket(cls, client_id, values, message_type=ClientMessageType.NoType, include_macs=False):
""" Send shares of a list of values to a specified client socket """
if include_macs:
writesockets(client_id, message_type, *values)
else:
writesocketshare(client_id, message_type, *values)
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
return cls._load_mem(address, ldms, ldmsi)
def store_in_mem(self, address):
@@ -1035,7 +1177,7 @@ class sgf2n(_secret, _gf2n):
return super(sgf2n, self).mul(other)
@vectorized_classmethod
def load_mem(cls, address):
def load_mem(cls, address, mem_type=None):
return cls._load_mem(address, gldms, gldmsi)
def store_in_mem(self, address):
@@ -1100,9 +1242,10 @@ class sgf2n(_secret, _gf2n):
bit_length = bit_length or program.galois_length
random_bits = [self.get_random_bit() \
for i in range(0, bit_length, step)]
one = cgf2n(1)
masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal()
masked_bits = masked.bit_decompose(bit_length)
masked_bits = masked.bit_decompose(bit_length,step=step)
return [m + r for m,r in zip(masked_bits, random_bits)]
@vectorize
@@ -1456,6 +1599,29 @@ class cfix(_number):
res.append(cint.load_mem(address))
return cfix(*res)
@vectorized_classmethod
def read_from_socket(cls, client_id, n=1):
""" Read one or more cfix values from a socket.
Sender will have already bit shifted and sent as cints."""
cint_input = cint.read_from_socket(client_id, n)
if n == 1:
return cfix(cint_inputs)
else:
return map(cfix, cint_inputs)
@vectorize
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
""" Send cfix to socket. Value is sent as bit shifted cint. """
writesocketc(client_id, message_type, cint(self.v))
@vectorized_classmethod
def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType):
""" Send a list of cfix values to socket. Values are sent as bit shifted cints. """
def cfix_to_cint(fix_val):
return cint(fix_val.v)
cint_values = map(cfix_to_cint, values)
writesocketc(client_id, message_type, *cint_values)
@vectorize_init
def __init__(self, v=None, size=None):
f = self.f
@@ -1613,6 +1779,13 @@ class sfix(_number):
else:
cls.k = k
@classmethod
def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType):
""" Securely obtain shares of n values input by a client.
Assumes client has already run bit shift to convert fixed point to integer."""
sint_inputs = sint.receive_from_client(n, client_id, ClientMessageType.TripleShares)
return map(sfix, sint_inputs)
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
res = []
@@ -1787,7 +1960,7 @@ class sfloat(_number):
error = 0
@vectorized_classmethod
def load_mem(cls, address):
def load_mem(cls, address, mem_type=None):
res = []
for i in range(4):
res.append(sint.load_mem(address + i * get_global_vector_size()))
@@ -2075,10 +2248,13 @@ class Array(object):
if value_type in _types:
value_type = _types[value_type]
self.address = address
if address is None:
self.address = program.malloc(length, value_type.reg_type)
self.length = length
self.value_type = value_type
if address is None:
self.address = self._malloc()
def _malloc(self):
return program.malloc(self.length, self.value_type)
def delete(self):
if program:
@@ -2106,7 +2282,7 @@ class Array(object):
def f(i):
res[i] = self[start+i*step]
return res
return self.value_type.load_mem(self.get_address(index))
return self._load(self.get_address(index))
def __setitem__(self, index, value):
if isinstance(index, slice):
@@ -2117,7 +2293,13 @@ class Array(object):
self[i] = value[source_index]
source_index.iadd(1)
return
self.value_type.conv(value).store_in_mem(self.get_address(index))
self._store(self.value_type.conv(value), self.get_address(index))
def _load(self, address):
return self.value_type.load_mem(address)
def _store(self, value, address):
value.store_in_mem(address)
def __len__(self):
return self.length
@@ -2149,6 +2331,8 @@ class Array(object):
self[i] = mem_value
return self
sint.dynamic_array = Array
sgf2n.dynamic_array = Array
class Matrix(object):
def __init__(self, rows, columns, value_type, address=None):
@@ -2309,7 +2493,7 @@ class MemValue(_mem):
else:
self.value_type = type(value)
self.reg_type = self.value_type.reg_type
self.address = program.malloc(1, self.reg_type)
self.address = program.malloc(1, self.value_type)
self.deleted = False
self.write(value)
@@ -2339,7 +2523,7 @@ class MemValue(_mem):
if not isinstance(self.register, self.value_type):
raise CompilerError('Mismatch in register type, cannot write \
%s to %s' % (type(self.register), self.value_type))
library.store_in_mem(self.register, self.address)
self.register.store_in_mem(self.address)
self.last_write_block = program.curr_block
return self

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
import math
import operator
@@ -54,7 +54,14 @@ def bit_decompose(a, bits):
return a.bit_decompose(bits)
def bit_compose(bits):
return sum(b << i for i,b in enumerate(bits))
bits = list(bits)
try:
if bits:
return bits[0].bit_compose(bits)
else:
return 0
except AttributeError:
return sum(b << i for i,b in enumerate(bits))
def series(a):
sum = 0
@@ -103,3 +110,25 @@ OR = or_op
def pow2(bits):
powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)]
return tree_reduce(operator.mul, powers)
def irepeat(l, n):
return reduce(operator.add, ([i] * n for i in l))
def int_len(x):
return len(bin(x)) - 2
def reveal(x):
if isinstance(x, str):
return x
try:
return x.reveal()
except AttributeError:
pass
try:
return [reveal(y) for y in x]
except TypeError:
pass
return x
def is_constant(x):
return isinstance(x, (int, long, bool))

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _Exceptions
#define _Exceptions
@@ -121,9 +121,39 @@ class file_error: public exception
}
};
class end_of_file: public exception
{ virtual const char* what() const throw()
{ return "End of file reached"; }
{ string filename, context, ans;
public:
end_of_file(string pfilename="no filename", string pcontext="") :
filename(pfilename), context(pcontext)
{
ans="End of file when reading ";
ans+=filename;
ans+=" ";
ans+=context;
}
~end_of_file()throw() { }
virtual const char* what() const throw()
{
return ans.c_str();
}
};
class file_missing: public exception
{ string filename, context, ans;
public:
file_missing(string pfilename="no filename", string pcontext="") :
filename(pfilename), context(pcontext)
{
ans="File missing : ";
ans+=filename;
ans+=" ";
ans+=context;
}
~file_missing()throw() { }
virtual const char* what() const throw()
{
return ans.c_str();
}
};
class Processor_Error: public exception
{ string msg;
public:
@@ -137,6 +167,11 @@ class Processor_Error: public exception
return msg.c_str();
}
};
class Invalid_Instruction : public Processor_Error
{
public:
Invalid_Instruction(string m) : Processor_Error(m) {}
};
class max_mod_sz_too_small : public exception
{ int len;
public:

105
ExternalIO/README.md Normal file
View File

@@ -0,0 +1,105 @@
(C) 2017 University of Bristol. See License.txt.
The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md).
## I/O MPC Instructions
### Connection Setup
**listen**(*int port_num*)
Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background.
*port_num* - the port number to listen on.
**acceptclientconnection**(*regint client_socket_id*, *int port_num*)
Picks the first available client socket connection. Blocks if none available.
*client_socket_id* - an identifier used to refer to the client socket.
*port_num* - the port number identifies the socket server to accept connections on.
### Data Exchange
Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py).
*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*)
Read a share of an input from a client, blocking on the client send.
*client_socket_id* - an identifier used to refer to the client socket.
*number_of_inputs* - the number of inputs expected
*[inputs]* - returned list of shares of private input.
**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*)
Write shares of values including macs to an external client.
*client_socket_id* - an identifier used to refer to the client socket.
*[values]* - list of shares of values to send to client.
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message.
*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*)
Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf)
*number_of_inputs* - the number of inputs expected
*client_socket_id* - an identifier used to refer to the client socket.
*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client.
*[inputs]* - returned list of shares of private input.
## Securing communications
Two cryptographic protocols have been implemented for use in particular applications and are included here for completeness:
1. Communication security using a Station to Station key agreement and libsodium Secret Box using a nonce counter for message ordering.
2. Authenticated Diffie-Hellman without message ordering.
Please note these are **NOT** required to allow external client I/O. Your mileage may vary, for example in a web setting TLS may be sufficient to secure communications between processes.
[client-setup.cpp](../client-setup.cpp) is a utility which is run to generate the key material for both the external clients and SPDZ parties for both protocols.
#### MPC instructions
**regint.init_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*)
STS protocol initiator. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec.
*client_socket_id* - an identifier used to refer to the client socket.
*public_signing_key* - client public key supplied as list of 8 32-bit ints.
**regint.resp_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*)
STS protocol responder. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec.
*client_socket_id* - an identifier used to refer to the client socket.
*public_signing_key* - client public key supplied as list of 8 32-bit ints.
*[regint public_key]* **regint.read_client_public_key**(*regint client_socket_id*)
Instruction to read the client public key and run setup for the authenticated Diffie-Hellman encryption. All subsequent write_socket instructions are encrypted. Only the sint.read_from_socket instruction is encrypted.
*client_socket_id* - an identifier used to refer to the client socket.
*public_key* - client public key made available to mpc programs as list of 8 32-bit ints.
## Working Examples
See [bankers-bonus-client.cpp](./bankers-bonus-client.cpp) which acts as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output with no communications security.
See [bankers-bonus-commsec-client.cpp](./bankers-bonus-commsec-client.cpp) which acts as a client to [bankers_bonus_commsec.mpc](../Programs/Source/bankers_bonus_commsec.mpc) which runs the same algorithm but includes both the available crypto protocols.
More instructions on how to run these are provided in the *-client files.

View File

@@ -0,0 +1,198 @@
/*
* (C) 2017 University of Bristol. See License.txt
*
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.
*
* Provides a client to bankers_bonus.mpc program to calculate which banker pays for lunch based on
* the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running
* the bankers_bonus.mpc program.
*
* Each connecting client:
* - sends a unique id to identify the client
* - sends an integer input (bonus value to compare)
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
* - share of random value [r]
* - share of winning unique id * random value [w]
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
*
* No communications security is used.
*
* To run with 2 parties / SPDZ engines:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./compile.py bankers_bonus
* ./Scripts/run-online bankers_bonus to run the engines.
*
* ./bankers-bonus-client.x 123 2 100 0
* ./bankers-bonus-client.x 456 2 200 0
* ./bankers-bonus-client.x 789 2 50 1
*
* Expect winner to be second client with id 456.
*/
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Networking/sockets.h"
#include "Tools/int.h"
#include "Math/Setup.h"
#include "Auth/fake-stuff.h"
#include <sodium.h>
#include <iostream>
#include <sstream>
#include <fstream>
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(vector<gfp>& values, vector<int>& sockets, int nparties)
{
int num_inputs = values.size();
octetStream os;
vector< vector<gfp> > triples(num_inputs, vector<gfp>(3));
vector<gfp> triple_shares(3);
// Receive num_inputs triples from SPDZ
for (int j = 0; j < nparties; j++)
{
os.reset_write_head();
os.Receive(sockets[j]);
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
{
triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k];
}
}
}
// Check triple relations (is a party cheating?)
for (int i = 0; i < num_inputs; i++)
{
if (triples[i][0] * triples[i][1] != triples[i][2])
{
cerr << "Incorrect triple at " << i << ", aborting\n";
exit(1);
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
{
gfp y = values[i] + triples[i][0];
y.pack(os);
}
for (int j = 0; j < nparties; j++)
os.Send(sockets[j]);
}
// Assumes that Scripts/setup-online.sh has been run to compute prime
void initialise_fields(const string& dir_prefix)
{
int lg2;
bigint p;
string filename = dir_prefix + "Params-Data";
cout << "loading params from: " << filename << endl;
ifstream inpf(filename.c_str());
if (inpf.fail()) { throw file_error(filename.c_str()); }
inpf >> p;
inpf >> lg2;
inpf.close();
gfp::init_field(p);
gf2n::init_field(lg2);
}
// Receive shares of the result and sum together.
// Also receive authenticating values.
gfp receive_result(vector<int>& sockets, int nparties)
{
vector<gfp> output_values(3);
octetStream os;
for (int i = 0; i < nparties; i++)
{
os.reset_write_head();
os.Receive(sockets[i]);
for (unsigned int j = 0; j < 3; j++)
{
gfp value;
value.unpack(os);
output_values[j] += value;
}
}
if (output_values[0] * output_values[1] != output_values[2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
exit(1);
}
return output_values[0];
}
int main(int argc, char** argv)
{
int my_client_id;
int nparties;
int salary_value;
int finish;
int port_base = 14000;
string host_name = "localhost";
if (argc < 5) {
cout << "Usage is bankers-bonus-client <client identifier> <number of spdz parties> "
<< "<salary to compare> <finish (0 false, 1 true)> <optional host name, default localhost> "
<< "<optional spdz party port base number, default 14000>" << endl;
exit(0);
}
my_client_id = atoi(argv[1]);
nparties = atoi(argv[2]);
salary_value = atoi(argv[3]);
finish = atoi(argv[4]);
if (argc > 5)
host_name = argv[5];
if (argc > 6)
port_base = atoi(argv[6]);
// init static gfp
string prep_data_prefix = get_prep_dir(nparties, 128, 40);
initialise_fields(prep_data_prefix);
// Setup connections from this client to each party socket
vector<int> sockets(nparties);
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;
// Map inputs into gfp
vector<gfp> input_values_gfp(3);
input_values_gfp[0].assign(my_client_id);
input_values_gfp[1].assign(salary_value);
input_values_gfp[2].assign(finish);
// Run the commputation
send_private_inputs(input_values_gfp, sockets, nparties);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back (client_id of winning client)
gfp result = receive_result(sockets, nparties);
cout << "Winning client id is : " << result << endl;
for (unsigned int i = 0; i < sockets.size(); i++)
close_client_socket(sockets[i]);
return 0;
}

View File

@@ -0,0 +1,407 @@
/*
* (C) 2017 University of Bristol. See License.txt
*
* Demonstrate external client inputing and receiving outputs from a SPDZ process,
* following the protocol described in https://eprint.iacr.org/2015/1006.pdf.
* Uses SPDZ implemented encryption for external client communication, see bankers-bonus-client.cpp
* for a simpler client with no crypto.
*
* Provides a client to bankers_bonus_commsec.mpc program to calculate which banker pays for lunch based on
* the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running
* the bankers_bonus.mpc program.
*
* Each connecting client:
* - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security.
* - sends a unique id to identify the client
* - sends an integer input (bonus value to compare)
* - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result).
*
* The result is returned authenticated with a share of a random value:
* - share of winning unique id [y]
* - share of random value [r]
* - share of winning unique id * random value [w]
* winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w]
*
* To run with 2 parties (SPDZ engines) and 3 external clients:
* ./Scripts/setup-online.sh to create triple shares for each party (spdz engine).
* ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients.
* ./compile.py bankers_bonus_commsec
* ./Scripts/run-online bankers_bonus_commsec to run the engines.
*
* ./bankers-bonus-commsec-client.x 0 2 100 0
* ./bankers-bonus-commsec-client.x 1 2 200 0
* ./bankers-bonus-commsec-client.x 2 2 50 1
*
* Expect winner to be second client with id 1.
* Note here client id must match id used in generating client key material, Client-Keys-C<client id>.
*/
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "Networking/sockets.h"
#include "Networking/STS.h"
#include "Tools/int.h"
#include "Math/Setup.h"
#include "Auth/fake-stuff.h"
#include <sodium.h>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <fstream>
typedef pair< vector<octet>, vector<octet> > keypair_t; // A pair of send/recv keys for talking to SPDZ
typedef vector< keypair_t > commsec_t; // A database of send/recv keys indexed by server
typedef struct {
unsigned char client_secretkey[crypto_sign_SECRETKEYBYTES];
unsigned char client_publickey[crypto_sign_PUBLICKEYBYTES];
vector<int> client_publickey_ints;
vector< vector<unsigned char> >server_publickey;
} sign_key_container_t;
keypair_t sts_response_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
octetStream os;
os.Receive(sockets[server_id]);
os.consume(m1.bytes, sizeof m1.bytes);
m2 = ke.recv_msg1(m1);
os.reset_write_head();
os.append(m2.pubkey, sizeof m2.pubkey);
os.append(m2.sig, sizeof m2.sig);
os.Send(sockets[server_id]);
os.Receive(sockets[server_id]);
os.consume(m3.bytes, sizeof m3.bytes);
ke.recv_msg3(m3);
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
return make_pair(sendKey,recvKey);
}
keypair_t sts_initiator_role_exceptions(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey);
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
octetStream os;
m1 = ke.send_msg1();
cout << "m1: ";
for (unsigned int j = 0; j < 32; j++)
cout << setfill('0') << setw(2) << hex << (int) m1.bytes[j];
cout << dec << endl;
os.reset_write_head();
os.append(m1.bytes, sizeof m1.bytes);
os.Send(sockets[server_id]);
os.reset_write_head();
os.Receive(sockets[server_id]);
os.consume(m2.pubkey, sizeof m2.pubkey);
os.consume(m2.sig, sizeof m2.sig);
m3 = ke.recv_msg2(m2);
os.reset_write_head();
os.append(m3.bytes, sizeof m3.bytes);
os.Send(sockets[server_id]);
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
return make_pair(sendKey,recvKey);
}
pair< vector<octet>, vector<octet> > sts_response_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
pair< vector<octet>, vector<octet> > res;
try {
res = sts_response_role_exceptions(keys, sockets, server_id);
} catch(char const *e) {
cerr << "Error in STS: " << e << endl;
exit(1);
}
return res;
}
pair< vector<octet>, vector<octet> > sts_initiator_role(sign_key_container_t keys, vector<int>& sockets, int server_id)
{
pair< vector<octet>, vector<octet> > res;
try {
res = sts_initiator_role_exceptions(keys, sockets, server_id);
} catch(char const *e) {
cerr << "Error in STS: " << e << endl;
exit(1);
}
return res;
}
// Send the private inputs masked with a random value.
// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid.
// Add the private input value to triple[0] and send to each spdz engine.
void send_private_inputs(vector<gfp>& values, vector<int>& sockets, int nparties,
commsec_t commsec, vector<octet*>& keys)
{
int num_inputs = values.size();
octetStream os;
vector< vector<gfp> > triples(num_inputs, vector<gfp>(3));
vector<gfp> triple_shares(3);
// Receive num_inputs triples from SPDZ
for (int j = 0; j < nparties; j++)
{
os.reset_write_head();
os.Receive(sockets[j]);
os.decrypt_sequence(&commsec[j].second[0],0);
os.decrypt(keys[j]);
for (int j = 0; j < num_inputs; j++)
{
for (int k = 0; k < 3; k++)
{
triple_shares[k].unpack(os);
triples[j][k] += triple_shares[k];
}
}
}
// Check triple relations
for (int i = 0; i < num_inputs; i++)
{
if (triples[i][0] * triples[i][1] != triples[i][2])
{
cerr << "Incorrect triple at " << i << ", aborting\n";
exit(1);
}
}
// Send inputs + triple[0], so SPDZ can compute shares of each value
os.reset_write_head();
for (int i = 0; i < num_inputs; i++)
{
gfp y = values[i] + triples[i][0];
y.pack(os);
}
for (int j = 0; j < nparties; j++) {
os.encrypt_sequence(&commsec[j].first[0],0);
os.Send(sockets[j]);
}
}
// Send public key in clear to each SPDZ engine.
void send_public_key(vector<int>& pubkey, int socket)
{
octetStream os;
os.reset_write_head();
for (unsigned int i = 0; i < pubkey.size(); i++)
{
os.store(pubkey[i]);
}
os.Send(socket);
}
// Assumes that Scripts/setup-online.sh has been run to compute prime
void initialise_fields(const string& dir_prefix)
{
int lg2;
bigint p;
string filename = dir_prefix + "Params-Data";
cout << "loading params from: " << filename << endl;
ifstream inpf(filename.c_str());
if (inpf.fail()) { throw file_error(filename.c_str()); }
inpf >> p;
inpf >> lg2;
inpf.close();
gfp::init_field(p);
gf2n::init_field(lg2);
}
// Assumes that client-setup has been run to create key pairs for clients and parties
void generate_symmetric_keys(vector<octet*>& keys, vector<int>& client_public_key_ints,
sign_key_container_t *sts_key, const string& dir_prefix, int client_no)
{
unsigned char client_publickey[crypto_box_PUBLICKEYBYTES];
unsigned char client_secretkey[crypto_box_SECRETKEYBYTES];
unsigned char server_publickey[crypto_box_PUBLICKEYBYTES];
unsigned char scalarmult_q[crypto_scalarmult_BYTES];
crypto_generichash_state h;
// read client public/secret keys + SPDZ server public keys
ifstream keyfile;
stringstream client_filename;
client_filename << dir_prefix << "Client-Keys-C" << client_no;
keyfile.open(client_filename.str().c_str());
if (keyfile.fail())
throw file_error(client_filename.str());
keyfile.read((char*)client_publickey, sizeof client_publickey);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "client public key" );
// Convert client public key unsigned char to int, reverse endianness
for(unsigned int j = 0; j < client_public_key_ints.size(); j++) {
int keybyte = 0;
for(unsigned int k = 0; k < 4; k++) {
keybyte = keybyte + (((int)client_publickey[j*4+k]) << ((3-k) * 8));
}
client_public_key_ints[j] = keybyte;
}
keyfile.read((char*)client_secretkey, sizeof client_secretkey);
if (keyfile.eof()) {
throw end_of_file(client_filename.str(), "client private key" );
}
keyfile.read((char*)sts_key->client_publickey, crypto_sign_PUBLICKEYBYTES);
keyfile.read((char*)sts_key->client_secretkey, crypto_sign_SECRETKEYBYTES);
// Convert client public key unsigned char to int, reverse endianness
sts_key->client_publickey_ints.resize(8);
for(unsigned int j = 0; j < sts_key->client_publickey_ints.size(); j++) {
int keybyte = 0;
for(unsigned int k = 0; k < 4; k++) {
keybyte = keybyte + (((int)sts_key->client_publickey[j*4+k]) << ((3-k) * 8));
}
sts_key->client_publickey_ints[j] = keybyte;
}
for (unsigned int i = 0; i < keys.size(); i++)
{
keys[i] = new octet[crypto_generichash_BYTES];
keyfile.read((char*)server_publickey, crypto_box_PUBLICKEYBYTES);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "server public key for party " + i);
keyfile.read((char*)(&sts_key->server_publickey[i][0]), crypto_sign_PUBLICKEYBYTES);
if (keyfile.eof())
throw end_of_file(client_filename.str(), "server public signing key for party " + i);
// Derive a shared key from this server's secret key and the client's public key
// shared key = h(q || client_secretkey || server_publickey)
if (crypto_scalarmult(scalarmult_q, client_secretkey, server_publickey) != 0) {
cerr << "Scalar mult failed\n";
exit(1);
}
crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES);
crypto_generichash_update(&h, scalarmult_q, sizeof scalarmult_q);
crypto_generichash_update(&h, client_publickey, sizeof client_publickey);
crypto_generichash_update(&h, server_publickey, sizeof server_publickey);
crypto_generichash_final(&h, keys[i], crypto_generichash_BYTES);
}
keyfile.close();
cout << "My public key is: ";
for (unsigned int j = 0; j < 32; j++)
cout << setfill('0') << setw(2) << hex << (int) client_publickey[j];
cout << dec << endl;
}
// Receive shares of the result and sum together.
// Also receive authenticating values.
gfp receive_result(vector<int>& sockets, int nparties, commsec_t commsec, vector<octet*>& keys)
{
vector<gfp> output_values(3);
octetStream os;
for (int i = 0; i < nparties; i++)
{
os.reset_write_head();
os.Receive(sockets[i]);
os.decrypt_sequence(&commsec[i].second[0],1);
os.decrypt(keys[i]);
for (unsigned int j = 0; j < 3; j++)
{
gfp value;
value.unpack(os);
output_values[j] += value;
}
}
if (output_values[0] * output_values[1] != output_values[2])
{
cerr << "Unable to authenticate output value as correct, aborting." << endl;
exit(1);
}
return output_values[0];
}
int main(int argc, char** argv)
{
int my_client_id;
int nparties;
int salary_value;
int finish;
int port_base = 14000;
sign_key_container_t sts_key;
string host_name = "localhost";
if (argc < 5) {
cout << "Usage is external-client <client identifier> <number of spdz parties> "
<< "<salary to compare> <finish (0 false, 1 true)> <optional host name, default localhost> "
<< "<optional spdz party port base number, default 14000>" << endl;
exit(0);
}
my_client_id = atoi(argv[1]);
nparties = atoi(argv[2]);
salary_value = atoi(argv[3]);
finish = atoi(argv[4]);
if (argc > 5)
host_name = argv[5];
if (argc > 6)
port_base = atoi(argv[6]);
sts_key.server_publickey.resize(nparties);
for(int i = 0 ; i < nparties; i++) {
sts_key.server_publickey[i].resize(crypto_sign_PUBLICKEYBYTES);
}
// init static gfp
string prep_data_prefix = get_prep_dir(nparties, 128, 40);
initialise_fields(prep_data_prefix);
// Generate session keys to decrypt data sent from each spdz engine (party)
vector<octet*> session_keys(nparties);
vector<int> client_public_key_ints(8);
generate_symmetric_keys(session_keys, client_public_key_ints, &sts_key, prep_data_prefix, my_client_id);
// Setup connections from this client to each party socket and send the client public keys
vector<int> sockets(nparties);
// vector< pair <vector<octet>, vector <octet> > > commseckey(nparties);
commsec_t commseckey(nparties);
for (int i = 0; i < nparties; i++)
{
set_up_client_socket(sockets[i], host_name.c_str(), port_base + i);
send_public_key(sts_key.client_publickey_ints, sockets[i]);
send_public_key(client_public_key_ints, sockets[i]);
commseckey[i] = sts_initiator_role(sts_key, sockets, i);
}
cout << "Finish setup socket connections to SPDZ engines." << endl;
// Map inputs into gfp
vector<gfp> input_values_gfp(3);
input_values_gfp[0].assign(my_client_id);
input_values_gfp[1].assign(salary_value);
input_values_gfp[2].assign(finish);
// Send the inputs to the SPDZ Engines
send_private_inputs(input_values_gfp, sockets, nparties, commseckey, session_keys);
cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl;
// Get the result back
gfp result = receive_result(sockets, nparties, commseckey, session_keys);
cout << "Winning client id is : " << result << endl;
for (unsigned int i = 0; i < sockets.size(); i++)
close_client_socket(sockets[i]);
return 0;
}

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Math/gf2n.h"
@@ -490,8 +490,6 @@ int main(int argc, const char** argv)
bigint p;
generate_online_setup(outf, prep_data_prefix, p, lgp, lg2);
generate_keys(prep_data_prefix, nplayers);
/* Find number players and MAC keys etc*/
gfp keyp,pp; keyp.assign_zero();
gf2n key2,p2; key2.assign_zero();

View File

@@ -1,6 +1,6 @@
University of Bristol : Open Access Software Licence
Copyright (c) 2016, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom.
Copyright (c) 2017, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom.
All rights reserved

View File

@@ -1,4 +1,4 @@
# (C) 2016 University of Bristol. See License.txt
# (C) 2017 University of Bristol. See License.txt
include CONFIG
@@ -26,7 +26,7 @@ LIB = libSPDZ.a
LIBSIMPLEOT = SimpleOT/libsimpleot.a
all: gen_input online offline
all: gen_input online offline externalIO
online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x
@@ -34,6 +34,8 @@ offline: $(OT_EXE) Check-Offline.x
gen_input: gen_input_f2n.x gen_input_fp.x
externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x
Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
@@ -69,7 +71,14 @@ gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON)
gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON)
$(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS)
client-setup.x: client-setup.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON) $(PROCESSOR)
$(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS)
clean:
-rm */*.o *.o *.x core.* *.a gmon.out
-rm */*.o *.o */*.d *.d *.x core.* *.a gmon.out

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Integer.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Integer.h
@@ -15,10 +15,13 @@ using namespace std;
class Integer
{
protected:
long a;
public:
static string type_string() { return "integer"; }
Integer() { a = 0; }
Integer(long a) : a(a) {}
@@ -26,6 +29,31 @@ class Integer
void assign_zero() { a = 0; }
long operator+(const Integer& other) const { return a + other.a; }
long operator-(const Integer& other) const { return a - other.a; }
long operator*(const Integer& other) const { return a * other.a; }
long operator/(const Integer& other) const { return a / other.a; }
long operator>>(const Integer& other) const { return a >> other.a; }
long operator<<(const Integer& other) const { return a << other.a; }
long operator^(const Integer& other) const { return a ^ other.a; }
long operator&(const Integer& other) const { return a ^ other.a; }
long operator|(const Integer& other) const { return a ^ other.a; }
bool operator==(const Integer& other) const { return a == other.a; }
bool operator!=(const Integer& other) const { return a != other.a; }
bool operator<(const Integer& other) const { return a < other.a; }
bool operator<=(const Integer& other) const { return a <= other.a; }
bool operator>(const Integer& other) const { return a > other.a; }
bool operator>=(const Integer& other) const { return a >= other.a; }
long operator^=(const Integer& other) { return a ^= other.a; }
friend unsigned int& operator+=(unsigned int& x, const Integer& other) { return x += other.a; }
friend ostream& operator<<(ostream& s, const Integer& x) { x.output(s, true); return s; }
void output(ostream& s,bool human) const;
void input(istream& s,bool human);

View File

@@ -1,5 +1,5 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Math/Setup.h"
#include "Math/gfp.h"
@@ -111,8 +111,8 @@ void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, i
}
string get_prep_dir(int nparties, int lg2p, int gf2ndegree)
{
if (gf2ndegree == 0)
{
if (gf2ndegree == 0)
gf2ndegree = gf2n::default_length();
stringstream ss;
ss << PREP_DIR << nparties << "-" << lg2p << "-" << gf2ndegree << "/";

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Setup.h

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Share.h"
@@ -99,6 +99,23 @@ T combine(const vector< Share<T> >& S)
}
template<class T>
inline void Share<T>::pack(octetStream& os) const
{
a.pack(os);
mac.pack(os);
}
template<class T>
inline void Share<T>::unpack(octetStream& os)
{
a.unpack(os);
mac.unpack(os);
}
template<class T>
bool check_macs(const vector< Share<T> >& S,const T& key)
{

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _Share
@@ -69,6 +69,19 @@ class Share
void sub(const Share<T>& S1,const Share<T>& S2);
void add(const Share<T>& S1) { add(*this,S1); }
Share<T> operator+(const Share<T>& x) const
{ Share<T> res; res.add(*this, x); return res; }
template <class U>
Share<T> operator*(const U& x) const
{ Share<T> res; res.mul(*this, x); return res; }
Share<T>& operator+=(const Share<T>& x) { add(x); return *this; }
template <class U>
Share<T>& operator*=(const U& x) { mul(*this, x); return *this; }
Share<T> operator<<(int i) { return this->operator*(T(1) << i); }
Share<T>& operator<<=(int i) { return *this = *this << i; }
// Input and output from a stream
// - Can do in human or machine only format (later should be faster)
void output(ostream& s,bool human) const
@@ -80,6 +93,11 @@ class Share
mac.input(s,human);
}
friend ostream& operator<<(ostream& s, const Share<T>& x) { x.output(s, true); return s; }
void pack(octetStream& os) const;
void unpack(octetStream& os);
/* Takes a vector of shares, one from each player and
* determines the shared value
* - i.e. Partially open the shares

View File

@@ -1,5 +1,5 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Zp_Data.h"

View File

@@ -1,5 +1,5 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _Zp_Data
#define _Zp_Data

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "bigint.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _bigint
#define _bigint

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* types.h

View File

@@ -1,5 +1,5 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Math/gf2n.h"
@@ -58,12 +58,12 @@ void gf2n_short::init_tables()
void gf2n_short::init_field(int nn)
{
if (nn == 0)
{
nn = default_length();
cerr << "Using GF(2^" << nn << ")" << endl;
}
if (nn == 0)
{
nn = default_length();
cerr << "Using GF(2^" << nn << ")" << endl;
}
gf2n_short::init_tables();
int i,j=-1;
for (i=0; i<num_2_fields && j==-1; i++)
@@ -322,7 +322,7 @@ void gf2n_short::randomize(PRNG& G)
void gf2n_short::output(ostream& s,bool human) const
{
if (human)
{ s << hex << a << dec << " "; }
{ s << hex << showbase << a << dec << " "; }
else
{ s.write((char*) &a,sizeof(word)); }
}

View File

@@ -1,5 +1,5 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _gf2n
#define _gf2n
@@ -67,8 +67,8 @@ class gf2n_short
static string type_string() { return "gf2n"; }
static int size() { return sizeof(a); }
static int t() { return 0; }
static int t() { return 0; }
static int default_length() { return 40; }
word get() const { return a; }
@@ -80,6 +80,7 @@ class gf2n_short
void assign_one() { a=1; }
void assign_x() { a=2; }
void assign(word aa) { a=aa&mask; }
void assign(long aa) { assign(word(aa)); }
void assign(int aa) { a=static_cast<unsigned int>(aa)&mask; }
void assign(const char* buffer) { a = *(word*)buffer; }
@@ -93,8 +94,10 @@ class gf2n_short
}
gf2n_short() { a=0; }
gf2n_short(const gf2n_short& g) { assign(g); }
gf2n_short(int g) { assign(g); }
gf2n_short(word a) { assign(a); }
gf2n_short(long a) { assign(a); }
gf2n_short(int a) { assign(a); }
gf2n_short(const char* a) { assign(a); }
~gf2n_short() { ; }
gf2n_short& operator=(const gf2n_short& g)
@@ -167,7 +170,7 @@ class gf2n_short
void input(istream& s,bool human);
friend ostream& operator<<(ostream& s,const gf2n_short& x)
{ s << hex << "0x" << x.a << dec;
{ s << hex << showbase << x.a << dec;
return s;
}
friend istream& operator>>(istream& s,gf2n_short& x)

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* gf2n_longlong.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* gf2nlong.h

View File

@@ -1,5 +1,5 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Math/gfp.h"
@@ -71,10 +71,15 @@ void gfp::SHL(const gfp& x,int n)
{
if (!x.is_zero())
{
bigint bi;
to_bigint(bi,x,false);
mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
to_gfp(*this, bi);
if (n != 0)
{
bigint bi;
to_bigint(bi,x,false);
mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
to_gfp(*this, bi);
}
else
assign(x);
}
else
{
@@ -87,10 +92,15 @@ void gfp::SHR(const gfp& x,int n)
{
if (!x.is_zero())
{
bigint bi;
to_bigint(bi,x);
mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
to_gfp(*this, bi);
if (n != 0)
{
bigint bi;
to_bigint(bi,x);
mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n);
to_gfp(*this, bi);
}
else
assign(x);
}
else
{

View File

@@ -1,5 +1,5 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _gfp
#define _gfp
@@ -93,7 +93,7 @@ class gfp
bool is_zero() const { return isZero(a,ZpD); }
bool is_one() const { return isOne(a,ZpD); }
bool is_one() const { return isOne(a,ZpD); }
bool is_bit() const { return is_zero() or is_one(); }
bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); }
bool operator==(const gfp& y) const { return equal(y); }

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Zp_Data.h"
#include "modp.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _Modp
#define _Modp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* operations.h
@@ -17,15 +17,15 @@ T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; }
template <class T, class U>
T operator+(const T& x, const U& y) { T res; res.add(x, y); return res; }
template <class T, class U>
T operator*(const T& x, const U& y) { T res; res.mul(x, y); return res; }
template <class T>
T operator*(const T& x, const T& y) { T res; res.mul(x, y); return res; }
template <class T, class U>
T operator-(const T& x, const U& y) { T res; res.sub(x, y); return res; }
template <class T, class U>
T& operator+=(T& x, const U& y) { x.add(y); return x; }
template <class T, class U>
T& operator*=(T& x, const U& y) { x.mul(y); return x; }
template <class T>
T& operator*=(T& x, const T& y) { x.mul(y); return x; }
template <class T, class U>
T& operator-=(T& x, const U& y) { x.sub(y); return x; }

View File

@@ -1,21 +1,37 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Player.h"
#include "Exceptions/Exceptions.h"
#include "Networking/STS.h"
#include <sys/select.h>
#include <utility>
// Use printf rather than cout so valgrind can detect thread issues
using namespace std;
CommsecKeysPackage::CommsecKeysPackage(vector<public_signing_key> playerpubs,
secret_signing_key mypriv,
public_signing_key mypub)
{
player_public_keys = playerpubs;
my_secret_key = mypriv;
my_public_key = mypub;
}
void Names::init(int player,int pnb,const char* servername)
{
{
player_no=player;
portnum_base=pnb;
setup_names(servername);
keys = NULL;
setup_server();
}
void Names::init(int player,int pnb,vector<string> Nms)
{
init(player, pnb, Nms);
}
void Names::init(int player,int pnb,vector<octet*> Nms)
{
@@ -23,18 +39,10 @@ void Names::init(int player,int pnb,vector<octet*> Nms)
portnum_base=pnb;
nplayers=Nms.size();
names.resize(nplayers);
for (int i=0; i<nplayers; i++)
{ names[i]=(char*)Nms[i]; }
setup_server();
}
void Names::init(int player,int pnb,vector<string> Nms)
{
player_no=player;
portnum_base=pnb;
nplayers=Nms.size();
names=Nms;
for (int i=0; i<nplayers; i++) {
names[i]=(char*)Nms[i];
}
keys = NULL;
setup_server();
}
@@ -51,6 +59,7 @@ void Names::init(int player, int _nplayers, int pnb, const string& filename)
player_no = player;
nplayers = _nplayers;
portnum_base = pnb;
keys = NULL;
string line;
while (getline(hostsfile, line))
{
@@ -65,6 +74,11 @@ void Names::init(int player, int _nplayers, int pnb, const string& filename)
setup_server();
}
void Names::set_keys( CommsecKeysPackage *keys )
{
this->keys = keys;
}
void Names::setup_names(const char *servername)
{
int socket_num;
@@ -79,11 +93,16 @@ void Names::setup_names(const char *servername)
// Send my name
octet my_name[512];
memset(my_name,0,512*sizeof(octet));
gethostname((char*)my_name,512);
sockaddr_in address;
socklen_t size = sizeof address;
getsockname(socket_num, (sockaddr*)&address, &size);
char* name = inet_ntoa(address.sin_addr);
// max length of IP address with ending 0
strncpy((char*)my_name, name, 16);
fprintf(stderr, "My Name = %s\n",my_name);
send(socket_num,my_name,512);
cerr << "My number = " << player_no << endl;
// Now get the set of names
int i;
receive(socket_num,nplayers);
@@ -102,6 +121,7 @@ void Names::setup_names(const char *servername)
void Names::setup_server()
{
server = new ServerSocket(portnum_base + player_no);
server->init();
}
@@ -113,6 +133,7 @@ Names::Names(const Names& other)
nplayers = other.nplayers;
portnum_base = other.portnum_base;
names = other.names;
keys = NULL;
server = 0;
}
@@ -135,7 +156,7 @@ Player::Player(const Names& Nms, int id) : PlayerBase(Nms), send_to_self_socket(
Player::~Player()
{
{
/* Close down the sockets */
for (int i=0; i<nplayers; i++)
close_client_socket(sockets[i]);
@@ -148,37 +169,42 @@ Player::~Player()
// Can also communicate with myself, but only with send_to and receive_from
void Player::setup_sockets(const vector<string>& names,int portnum_base,int id_base,ServerSocket& server)
{
sockets.resize(nplayers);
// Set up the client side
for (int i=player_no; i<nplayers; i++)
{ int pn=id_base+i*nplayers+player_no;
fprintf(stderr, "Setting up client to %s:%d with id 0x%x\n",names[i].c_str(),portnum_base+i,pn);
set_up_client_socket(sockets[i],names[i].c_str(),portnum_base+i);
send(sockets[i], (unsigned char*)&pn, sizeof(pn));
sockets.resize(nplayers);
// Set up the client side
for (int i=player_no; i<nplayers; i++) {
int pn=id_base+i*nplayers+player_no;
if (i==player_no) {
const char* localhost = "127.0.0.1";
fprintf(stderr, "Setting up send to self socket to %s:%d with id 0x%x\n",localhost,portnum_base+i,pn);
set_up_client_socket(sockets[i],localhost,portnum_base+i);
} else {
fprintf(stderr, "Setting up client to %s:%d with id 0x%x\n",names[i].c_str(),portnum_base+i,pn);
set_up_client_socket(sockets[i],names[i].c_str(),portnum_base+i);
}
send(sockets[i], (unsigned char*)&pn, sizeof(pn));
}
send_to_self_socket = sockets[player_no];
// Setting up the server side
for (int i=0; i<=player_no; i++)
{ int id=id_base+player_no*nplayers+i;
fprintf(stderr, "Setting up server with id 0x%x\n",id);
sockets[i] = server.get_connection_socket(id);
send_to_self_socket = sockets[player_no];
// Setting up the server side
for (int i=0; i<=player_no; i++) {
int id=id_base+player_no*nplayers+i;
fprintf(stderr, "Setting up server with id 0x%x\n",id);
sockets[i] = server.get_connection_socket(id);
}
for (int i = 0; i < nplayers; i++)
{
// timeout of 5 minutes
struct timeval tv;
tv.tv_sec = 300;
tv.tv_usec = 0;
int fl = setsockopt(sockets[i], SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(struct timeval));
if (fl<0) { error("set_up_socket:setsockopt"); }
socket_players[sockets[i]] = i;
for (int i = 0; i < nplayers; i++) {
// timeout of 5 minutes
struct timeval tv;
tv.tv_sec = 300;
tv.tv_usec = 0;
int fl = setsockopt(sockets[i], SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(struct timeval));
if (fl<0) { error("set_up_socket:setsockopt"); }
socket_players[sockets[i]] = i;
}
}
void Player::send_to(int player,const octetStream& o,bool donthash) const
{
{
int socket = socket_to_send(player);
o.Send(socket);
if (!donthash)
@@ -187,12 +213,15 @@ void Player::send_to(int player,const octetStream& o,bool donthash) const
void Player::send_all(const octetStream& o,bool donthash) const
{ for (int i=0; i<nplayers; i++)
{ if (i!=player_no)
{ o.Send(sockets[i]); }
}
if (!donthash)
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
{
for (int i=0; i<nplayers; i++) {
if (i!=player_no) {
o.Send(sockets[i]);
}
}
if (!donthash) {
blk_SHA1_Update(&ctx,o.get_data(),o.get_length());
}
}
@@ -211,17 +240,17 @@ void Player::Broadcast_Receive(vector<octetStream>& o,bool donthash) const
{ for (int i=0; i<nplayers; i++)
{ if (i>player_no)
{ o[player_no].Send(sockets[i]); }
else if (i<player_no)
else if (i<player_no)
{ o[i].reset_write_head();
o[i].Receive(sockets[i]);
o[i].Receive(sockets[i]);
}
}
for (int i=0; i<nplayers; i++)
{ if (i<player_no)
{ o[player_no].Send(sockets[i]); }
else if (i>player_no)
else if (i>player_no)
{ o[i].reset_write_head();
o[i].Receive(sockets[i]);
o[i].Receive(sockets[i]);
}
}
if (!donthash)
@@ -240,7 +269,7 @@ void Player::Check_Broadcast() const
Broadcast_Receive(h,true);
for (int i=0; i<nplayers; i++)
{ if (i!=player_no)
{ if (i!=player_no)
{ if (!h[i].equals(h[player_no]))
{ throw broadcast_invalid(); }
}
@@ -353,26 +382,97 @@ void ThreadPlayer::send_all(const octetStream& o,bool donthash) const
TwoPartyPlayer::TwoPartyPlayer(const Names& Nms, int other_player, int id) : PlayerBase(Nms), other_player(other_player)
{
is_server = Nms.my_num() > other_player;
setup_sockets(Nms.names[other_player].c_str(), *Nms.server, Nms.portnum_base + other_player, id);
setup_sockets(other_player, Nms, Nms.portnum_base + other_player, id);
}
TwoPartyPlayer::~TwoPartyPlayer()
{
{
for(size_t i=0; i < my_secret_key.size(); i++) {
my_secret_key[i] = 0;
}
close_client_socket(socket);
}
void TwoPartyPlayer::setup_sockets(const char* hostname, ServerSocket& server, int pn, int id)
static pair<keyinfo,keyinfo> sts_initiator(int socket, CommsecKeysPackage *keys, int other_player)
{
if (is_server)
{
fprintf(stderr, "Setting up server with id %d\n",id);
socket = server.get_connection_socket(id);
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
octetStream socket_stream;
// Start Station to Station Protocol
STS ke(&keys->player_public_keys[other_player][0], &keys->my_public_key[0], &keys->my_secret_key[0]);
m1 = ke.send_msg1();
socket_stream.reset_write_head();
socket_stream.append(m1.bytes, sizeof m1.bytes);
socket_stream.Send(socket);
socket_stream.Receive(socket);
socket_stream.consume(m2.pubkey, sizeof m2.pubkey);
socket_stream.consume(m2.sig, sizeof m2.sig);
m3 = ke.recv_msg2(m2);
socket_stream.reset_write_head();
socket_stream.append(m3.bytes, sizeof m3.bytes);
socket_stream.Send(socket);
// Use results of STS to generate send and receive keys.
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
keyinfo sendkeyinfo = make_pair(sendKey,0);
keyinfo recvkeyinfo = make_pair(recvKey,0);
return make_pair(sendkeyinfo,recvkeyinfo);
}
static pair<keyinfo,keyinfo> sts_responder(int socket, CommsecKeysPackage *keys, int other_player)
// secret_signing_key mykey, public_signing_key mypubkey, public_signing_key theirkey)
{
sts_msg1_t m1;
sts_msg2_t m2;
sts_msg3_t m3;
octetStream socket_stream;
// Start Station to Station Protocol for the responder
STS ke(&keys->player_public_keys[other_player][0], &keys->my_public_key[0], &keys->my_secret_key[0]);
socket_stream.Receive(socket);
socket_stream.consume(m1.bytes, sizeof m1.bytes);
m2 = ke.recv_msg1(m1);
socket_stream.reset_write_head();
socket_stream.append(m2.pubkey, sizeof m2.pubkey);
socket_stream.append(m2.sig, sizeof m2.sig);
socket_stream.Send(socket);
socket_stream.Receive(socket);
socket_stream.consume(m3.bytes, sizeof m3.bytes);
ke.recv_msg3(m3);
// Use results of STS to generate send and receive keys.
vector<unsigned char> recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
vector<unsigned char> sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES);
keyinfo sendkeyinfo = make_pair(sendKey,0);
keyinfo recvkeyinfo = make_pair(recvKey,0);
return make_pair(sendkeyinfo,recvkeyinfo);
}
void TwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, int id)
{
const char *hostname = nms.names[other_player].c_str();
ServerSocket *server = nms.server;
if (is_server) {
fprintf(stderr, "Setting up server with id %d\n",id);
socket = server->get_connection_socket(id);
if(NULL != nms.keys) {
pair<keyinfo,keyinfo> send_recv_pair = sts_responder(socket, nms.keys, other_player);
player_send_key = send_recv_pair.first;
player_recv_key = send_recv_pair.second;
}
}
else
{
fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, pn, id);
set_up_client_socket(socket, hostname, pn);
::send(socket, (unsigned char*)&id, sizeof(id));
else {
fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, portNum, id);
set_up_client_socket(socket, hostname, portNum);
::send(socket, (unsigned char*)&id, sizeof(id));
if(NULL != nms.keys) {
pair<keyinfo,keyinfo> send_recv_pair = sts_initiator(socket, nms.keys, other_player);
player_send_key = send_recv_pair.first;
player_recv_key = send_recv_pair.second;
}
}
}
@@ -381,31 +481,37 @@ int TwoPartyPlayer::other_player_num() const
return other_player;
}
void TwoPartyPlayer::send(octetStream& o) const
void TwoPartyPlayer::send(octetStream& o)
{
if(p2pcommsec) {
o.encrypt_sequence(&player_send_key.first[0], player_send_key.second);
player_send_key.second++;
}
o.Send(socket);
}
void TwoPartyPlayer::receive(octetStream& o) const
void TwoPartyPlayer::receive(octetStream& o)
{
o.reset_write_head();
o.Receive(socket);
if(p2pcommsec) {
o.decrypt_sequence(&player_recv_key.first[0], player_recv_key.second);
player_recv_key.second++;
}
}
void TwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
void TwoPartyPlayer::send_receive_player(vector<octetStream>& o)
{
{
if (is_server)
{
o[0].Send(socket);
o[1].reset_write_head();
o[1].Receive(socket);
send(o[0]);
receive(o[1]);
}
else
{
o[1].reset_write_head();
o[1].Receive(socket);
o[0].Send(socket);
receive(o[1]);
send(o[0]);
}
}
}

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _Player
#define _Player
@@ -23,6 +23,23 @@ using namespace std;
#include "Networking/Receiver.h"
#include "Networking/Sender.h"
typedef vector<octet> public_signing_key;
typedef vector<octet> secret_signing_key;
typedef vector<octet> chachakey;
typedef pair< chachakey, uint64_t > keyinfo;
class CommsecKeysPackage {
public:
vector<public_signing_key> player_public_keys;
secret_signing_key my_secret_key;
public_signing_key my_public_key;
CommsecKeysPackage(vector<public_signing_key> playerpubs,
secret_signing_key mypriv,
public_signing_key mypub);
~CommsecKeysPackage();
};
/* Class to get the names off the server */
class Names
{
@@ -31,6 +48,8 @@ class Names
int portnum_base;
int player_no;
CommsecKeysPackage *keys;
void setup_names(const char *servername);
void setup_server();
@@ -39,7 +58,6 @@ class Names
mutable ServerSocket* server;
// Usual setup names
void init(int player,int pnb,const char* servername);
Names(int player,int pnb,const char* servername)
{ init(player,pnb,servername); }
@@ -50,11 +68,10 @@ class Names
void init(int player,int pnb,vector<string> Nms);
Names(int player,int pnb,vector<string> Nms)
{ init(player,pnb,Nms); }
// Set up names from file -- reads the first nplayers names in the file
void init(int player, int nplayers, int pnb, const string& hostsfile);
Names(int player, int nplayers, int pnb, const string& hostsfile)
{ init(player, nplayers, pnb, hostsfile); }
void set_keys( CommsecKeysPackage *keys );
Names() : nplayers(-1), portnum_base(-1), player_no(-1), server(0) { ; }
Names(const Names& other);
@@ -81,7 +98,6 @@ public:
int my_num() const { return player_no; }
};
class Player : public PlayerBase
{
protected:
@@ -161,25 +177,31 @@ class TwoPartyPlayer : public PlayerBase
{
private:
// setup sockets for comm. with only one other player
void setup_sockets(const char* hostname, ServerSocket& server, int pn, int id);
void setup_sockets(int other_player, const Names &nms, int portNum, int id);
int socket;
bool is_server;
int other_player;
bool p2pcommsec;
secret_signing_key my_secret_key;
map<int,public_signing_key> player_public_keys;
keyinfo player_send_key;
keyinfo player_recv_key;
public:
TwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0);
~TwoPartyPlayer();
void send(octetStream& o) const;
void receive(octetStream& o) const;
void send(octetStream& o);
void receive(octetStream& o);
int other_player_num() const;
/* Send and receive to/from the other player
* - o[0] contains my data, received data put in o[1]
*/
void send_receive_player(vector<octetStream>& o) const;
void send_receive_player(vector<octetStream>& o);
};
#endif

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Receiver.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Receiver.h

230
Networking/STS.cpp Normal file
View File

@@ -0,0 +1,230 @@
// (C) 2017 University of Bristol. See License.txt
#include "Networking/STS.h"
#include <sodium.h>
#include <string>
#include <string.h>
#include <unistd.h>
#include <stdio.h>
#include <iomanip>
#include <fcntl.h>
void STS::kdf_block(unsigned char *block)
{
crypto_hash_sha512_state state;
crypto_hash_sha512_init(&state);
unsigned char ctrbytes[sizeof kdf_counter];
kdf_counter++;
// Little endian serialization
for(size_t i=0; i<sizeof(kdf_counter); i++) {
ctrbytes[i] = (unsigned char)((kdf_counter >> i*8) & 0xFF);
}
crypto_hash_sha512_update(&state,ctrbytes,sizeof ctrbytes);
crypto_hash_sha512_update(&state,raw_secret,crypto_hash_sha512_BYTES);
crypto_hash_sha512_final(&state, block);
}
vector<unsigned char> STS::unsafe_derive_secret(size_t sz)
{
// KDF ~ H(cnt || raw_secret)
vector<unsigned char> resultSecret(sz + crypto_hash_sha512_BYTES - (sz % crypto_hash_sha512_BYTES));
size_t total=0;
while(total < sz) {
unsigned char *block = &resultSecret[total];
kdf_block(block);
total += crypto_hash_sha512_BYTES;
}
return resultSecret;
}
STS::STS()
{
phase = UNDEFINED;
}
void STS::init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES])
{
phase = UNKNOWN;
memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES);
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
kdf_counter = 0;
}
STS::STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES])
{
phase = UNKNOWN;
memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES);
memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES);
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
kdf_counter = 0;
}
STS::~STS()
{
memset(their_public_sign_key, 0, crypto_sign_PUBLICKEYBYTES);
memset(my_private_sign_key, 0, crypto_sign_SECRETKEYBYTES);
memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES);
memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES);
memset(raw_secret, 0, crypto_hash_sha512_BYTES);
kdf_counter = 0;
phase = UNKNOWN;
}
sts_msg1_t STS::send_msg1()
{
sts_msg1_t m;
if(UNKNOWN != phase) {
throw "STS BAD PHASE";
}
crypto_box_keypair(ephemeral_public_key, ephemeral_private_key);
memcpy(m.bytes,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
phase = SENT1;
return m;
}
// If the incoming signature is valid, compute:
// shared secret = H(DH(pubB,privA) || pubA || pubB)
// msg = Sign_{privED-A} (pubA || pubB )
//
sts_msg3_t STS::recv_msg2(sts_msg2_t msg2)
{
unsigned char *theirPublicKey = msg2.pubkey;
unsigned char *theirSig = msg2.sig;
unsigned char theirSigDec[crypto_sign_BYTES];
unsigned char scalar_result[crypto_scalarmult_SCALARBYTES];
const unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
int ret;
crypto_hash_sha512_state state;
sts_msg3_t msg;
if(SENT1 != phase) {
throw "STS BAD PHASE";
}
ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey);
if(0 != ret) {
throw "crypto_scalarmult failed";
}
crypto_hash_sha512_init(&state);
crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES);
crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_final(&state,raw_secret);
vector<unsigned char> keKey = unsafe_derive_secret(crypto_stream_KEYBYTES);
vector<unsigned char> expectedMessage;
expectedMessage.insert(expectedMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey[0]);
int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key);
if(badSig) {
throw "Bad signature received in message 2.";
} else {
unsigned char *mySigEnc = msg.bytes;
unsigned char mySig[crypto_sign_BYTES];
vector<unsigned char> signMessage;
signMessage.insert(signMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
signMessage.insert(signMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
if(0 != crypto_sign_detached(mySig, NULL, &signMessage[0], signMessage.size(), my_private_sign_key)) {
throw "Signing failed.";
}
vector<unsigned char> keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES);
crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey2[0]);
phase = FINISHED;
return msg;
}
}
sts_msg2_t STS::recv_msg1(sts_msg1_t msg1)
{
unsigned char *theirPublicKey = msg1.bytes;
unsigned char scalar_result[crypto_scalarmult_SCALARBYTES];
crypto_hash_sha512_state state;
sts_msg2_t m;
int ret;
if(UNKNOWN != phase) {
throw "recv_msg1 called on non-unknown phase";
}
memcpy(their_ephemeral_public_key, theirPublicKey, crypto_box_PUBLICKEYBYTES);
crypto_box_keypair(ephemeral_public_key, ephemeral_private_key);
memcpy(m.pubkey,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey);
if(0 != ret) {
throw "crypto_scalarmult failed when processing message 1";
}
crypto_hash_sha512_init(&state);
crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES);
crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES);
crypto_hash_sha512_final(&state,raw_secret);
vector<unsigned char> livenessProof;
livenessProof.insert(livenessProof.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
livenessProof.insert(livenessProof.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES);
unsigned char mySig[crypto_sign_BYTES];
unsigned char *mySigEnc = m.sig;
vector<unsigned char> keKey = unsafe_derive_secret(crypto_stream_KEYBYTES);
unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
if(0 != crypto_sign_detached(mySig, NULL, &livenessProof[0], livenessProof.size(), my_private_sign_key)) {
throw "Signing failed.";
}
crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey[0]);
phase = SENT2;
return m;
}
void STS::recv_msg3(sts_msg3_t msg3)
{
unsigned char *theirSig=msg3.bytes;
unsigned char theirSigDec[crypto_sign_BYTES];
vector<unsigned char> expectedMessage;
if(SENT2 != phase) {
throw "recv_msg3 called out of order";
}
expectedMessage.insert(expectedMessage.end(), their_ephemeral_public_key , their_ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES);
unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0};
vector<unsigned char> keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES);
crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey2[0]);
int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key);
if(badSig) {
throw "Bad signature received in message 3.";
} else {
phase = FINISHED;
}
}
vector<unsigned char> STS::derive_secret(size_t sz)
{
if(phase != FINISHED) {
throw "Can not derive secrets till the key exchange has completed.";
}
return unsafe_derive_secret(sz);
}

72
Networking/STS.h Normal file
View File

@@ -0,0 +1,72 @@
// (C) 2017 University of Bristol. See License.txt
#ifndef _NETWORK_STS
#define _NETWORK_STS
/* The Station to Station protocol
*/
#include <iostream>
#include <fstream>
#include <vector>
#include <sodium.h>
using namespace std;
typedef enum
{ UNKNOWN // Have not started the interaction or have cleared the memory
, SENT1 // Sent initial message
, SENT2 // Received 1, sent 2
, FINISHED // Done (received msg 2 & sent 3 or received msg 3)
, UNDEFINED // For arrays/vectors/etc of STS classes that are initialized later.
} phase_t;
struct msg1_st {
unsigned char bytes[crypto_box_PUBLICKEYBYTES];
};
typedef struct msg1_st sts_msg1_t;
struct msg2_st {
unsigned char pubkey[crypto_box_PUBLICKEYBYTES];
unsigned char sig[crypto_sign_BYTES];
};
typedef struct msg2_st sts_msg2_t;
struct msg3_st {
unsigned char bytes[crypto_sign_BYTES];
};
typedef struct msg3_st sts_msg3_t;
class STS
{
phase_t phase;
unsigned char their_public_sign_key[crypto_sign_PUBLICKEYBYTES];
unsigned char my_public_sign_key[crypto_sign_PUBLICKEYBYTES];
unsigned char my_private_sign_key[crypto_sign_SECRETKEYBYTES];
unsigned char ephemeral_private_key[crypto_box_SECRETKEYBYTES];
unsigned char ephemeral_public_key[crypto_box_PUBLICKEYBYTES];
unsigned char their_ephemeral_public_key[crypto_box_PUBLICKEYBYTES];
unsigned char raw_secret[crypto_hash_sha512_BYTES];
uint64_t kdf_counter;
public:
STS();
STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]);
~STS();
void init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPub[crypto_sign_PUBLICKEYBYTES]
, const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]);
sts_msg1_t send_msg1();
sts_msg3_t recv_msg2(sts_msg2_t msg2);
sts_msg2_t recv_msg1(sts_msg1_t msg1);
void recv_msg3(sts_msg3_t msg3);
vector<unsigned char> derive_secret(size_t);
private:
vector<unsigned char> unsafe_derive_secret(size_t);
void kdf_block(unsigned char *block);
};
#endif /* _NETWORK_STS */

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Sender.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Sender.h

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* ServerSocket.cpp
@@ -57,7 +57,7 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum)
sleep(1);
}
else
{ cerr << "Bound on port " << Portnum << endl; }
{ cerr << "ServerSocket is bound on port " << Portnum << endl; }
}
if (fl<0) { error("set_up_socket:bind"); }
@@ -65,6 +65,11 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum)
fl=listen(main_socket, 1000);
if (fl<0) { error("set_up_socket:listen"); }
// Note: must not call virtual init() method in constructor: http://www.aristeia.com/EC3E/3E_item9.pdf
}
void ServerSocket::init()
{
pthread_create(&thread, 0, accept_thread, this);
}
@@ -95,6 +100,15 @@ void ServerSocket::accept_clients()
}
}
int ServerSocket::get_connection_count()
{
data_signal.lock();
int connection_count = clients.size();
data_signal.unlock();
return connection_count;
}
int ServerSocket::get_connection_socket(int id)
{
data_signal.lock();
@@ -108,8 +122,60 @@ int ServerSocket::get_connection_socket(int id)
while (clients.find(id) == clients.end())
data_signal.wait();
int client = clients[id];
int client_socket = clients[id];
used.insert(id);
data_signal.unlock();
return client;
return client_socket;
}
void* anonymous_accept_thread(void* server_socket)
{
((AnonymousServerSocket*)server_socket)->accept_clients();
return 0;
}
int AnonymousServerSocket::global_client_socket_count = 0;
void AnonymousServerSocket::init()
{
pthread_create(&thread, 0, anonymous_accept_thread, this);
}
int AnonymousServerSocket::get_connection_count()
{
return num_accepted_clients;
}
void AnonymousServerSocket::accept_clients()
{
while (true)
{
struct sockaddr dest;
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
int socksize = sizeof(dest);
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
if (consocket<0) { error("set_up_socket:accept"); }
data_signal.lock();
client_connection_queue.push(consocket);
num_accepted_clients++;
data_signal.broadcast();
data_signal.unlock();
}
}
int AnonymousServerSocket::get_connection_socket(int& client_id)
{
data_signal.lock();
//while (clients.find(next_client_id) == clients.end())
while (client_connection_queue.empty())
data_signal.wait();
client_id = global_client_socket_count;
global_client_socket_count++;
int client_socket = client_connection_queue.front();
client_connection_queue.pop();
data_signal.unlock();
return client_socket;
}

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* ServerSocket.h
@@ -10,6 +10,7 @@
#include <map>
#include <set>
#include <queue>
using namespace std;
#include <pthread.h>
@@ -19,6 +20,7 @@ using namespace std;
class ServerSocket
{
protected:
int main_socket, portnum;
map<int,int> clients;
set<int> used;
@@ -28,17 +30,51 @@ class ServerSocket
// disable copying
ServerSocket(const ServerSocket& other);
// receive id from client
int assign_client_id(int consocket);
public:
ServerSocket(int Portnum);
~ServerSocket();
virtual ~ServerSocket();
void accept_clients();
virtual void init();
virtual void accept_clients();
// This depends on clients sending their id as int.
// Has to be thread-safe.
int get_connection_socket(int number);
// How many client connections have been made.
virtual int get_connection_count();
void close_socket();
};
/*
* ServerSocket where clients do not send any identifiers upon connecting.
*/
class AnonymousServerSocket : public ServerSocket
{
private:
// Global no. of client sockets that have been returned - used to create identifiers
static int global_client_socket_count;
// No. of accepted connections in this instance
int num_accepted_clients;
queue<int> client_connection_queue;
public:
AnonymousServerSocket(int Portnum) :
ServerSocket(Portnum), num_accepted_clients(0) { };
// override so clients do not send id
void accept_clients();
void init();
virtual int get_connection_count();
// Get socket for the last client who connected
// Writes a unique client identifier (i.e. a counter) to client_id
int get_connection_socket(int& client_id);
};
#endif /* NETWORKING_SERVERSOCKET_H_ */

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _Data
#define _Data

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "sockets.h"
@@ -28,8 +28,6 @@ void error(const char *str1,const char *str2)
throw bad_value();
}
void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum)
{
@@ -57,7 +55,7 @@ void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int
memset(my_name,0,512*sizeof(octet));
gethostname((char*)my_name,512);
/* bind serv information to mysocket
/* bind serv information to mysocket
* - Just assume it will eventually wake up
*/
fl=1;
@@ -82,21 +80,18 @@ void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int
}
void close_server_socket(int consocket,int main_socket)
{
if (close(consocket)) { error("close(socket)"); }
if (close(main_socket)) { error("close(main_socket"); };
}
void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
{
mysocket = socket(AF_INET, SOCK_STREAM, 0);
if (mysocket<0) { error("set_up_socket:socket"); }
/* disable Nagle's algorithm */
/* disable Nagle's algorithm */
int one=1;
int fl= setsockopt(mysocket, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int));
if (fl<0) { error("set_up_socket:setsockopt"); }
@@ -106,17 +101,8 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
struct sockaddr_in dest;
dest.sin_family = AF_INET;
dest.sin_port = htons(Portnum); // set destination port number
dest.sin_port = htons(Portnum); // set destination port number
/*
struct hostent *server;
server=gethostbyname(hostname);
if (server== NULL)
{ error("set_up_socket:gethostbyname"); }
bcopy((char *)server->h_addr,
(char *)&dest.sin_addr.s_addr,
server->h_length); // set destination IP number
*/
struct addrinfo hints, *ai=NULL,*rp;
memset (&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
@@ -140,13 +126,13 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
}
}
if (erp!=0)
{ error("set_up_socket:getaddrinfo"); }
{ error("set_up_socket:getaddrinfo"); }
for (rp=ai; rp!=NULL; rp=rp->ai_next)
{ const struct in_addr *addr4 = &((const struct sockaddr_in*)ai->ai_addr)->sin_addr;
if (ai->ai_family == AF_INET)
{ memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr));
{ memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr));
continue;
}
}
@@ -162,8 +148,6 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum)
if (fl<0) { error("set_up_socket:connect:",hostname); }
}
void close_client_socket(int socket)
{
if (close(socket))
@@ -174,8 +158,6 @@ void close_client_socket(int socket)
}
}
unsigned long long sent_amount = 0, sent_counter = 0;
@@ -195,7 +177,7 @@ void receive(int socket,int& a)
while (i==0)
{ i=recv(socket,msg,1,0);
if (i<0) { error("Receiving error - 2"); }
}
}
a=msg[0];
}

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _sockets
#define _sockets

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "OT/BaseOT.h"
#include "Tools/random.h"
@@ -34,7 +34,7 @@ OT_ROLE INV_ROLE(OT_ROLE role)
return BOTH;
}
void send_if_ot_sender(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
void send_if_ot_sender(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
{
if (role == SENDER)
{
@@ -51,7 +51,7 @@ void send_if_ot_sender(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE
}
}
void send_if_ot_receiver(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
void send_if_ot_receiver(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role)
{
if (role == RECEIVER)
{

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _BASE_OT
#define _BASE_OT
@@ -26,8 +26,8 @@ enum OT_ROLE
OT_ROLE INV_ROLE(OT_ROLE role);
const char* role_to_str(OT_ROLE role);
void send_if_ot_sender(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
void send_if_ot_receiver(const TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
void send_if_ot_sender(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
void send_if_ot_receiver(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE role);
class BaseOT
{

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* BitMatrix.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* BitMatrix.h

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "OT/BitVector.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _BITVECTOR
#define _BITVECTOR

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "NPartyTripleGenerator.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef OT_NPARTYTRIPLEGENERATOR_H_
#define OT_NPARTYTRIPLEGENERATOR_H_

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "OTExtension.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _OTEXTENSION
#define _OTEXTENSION

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* OTExtensionWithMatrix.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* OTExtensionWithMatrix.h

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Networking/Player.h"
#include "OT/OTExtension.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* OTMachine.h

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* OTMultiplier.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* OTMultiplier.h

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "OTTripleSetup.h"

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef OT_TRIPLESETUP_H_
#define OT_TRIPLESETUP_H_

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* OText_main.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* check.h

View File

@@ -1,9 +1,9 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Tools.h"
#include "Math/gf2nlong.h"
void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len)
void random_seed_commit(octet* seed, TwoPartyPlayer& player, int len)
{
PRNG G;
G.ReSeed();

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#ifndef _OTTOOLS
#define _OTTOOLS
@@ -12,7 +12,7 @@
/*
* Generate a secure, random seed between 2 parties via commitment
*/
void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len);
void random_seed_commit(octet* seed, TwoPartyPlayer& player, int len);
/*
* GF(2^128) multiplication using Intel instructions

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* TripleMachine.cpp

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* TripleMachine.h

View File

@@ -1,11 +1,14 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
#include "Processor/Machine.h"
#include "Math/Setup.h"
#include "Tools/ezOptionParser.h"
#include "Tools/Config.h"
#include <iostream>
#include <map>
#include <string>
#include <stdio.h>
using namespace std;
int main(int argc, const char** argv)
@@ -108,6 +111,15 @@ int main(int argc, const char** argv)
"-b", // Flag token.
"--max-broadcast" // Flag token.
);
opt.add(
"0", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Use communications security between SPDZ players", // Help description.
"-c", // Flag token.
"--player-to-player-commsec" // Flag token.
);
opt.parse(argc, argv);
@@ -156,6 +168,7 @@ int main(int argc, const char** argv)
string memtype, hostname;
int lg2, lgp, pnbase, opening_sum, max_broadcast;
int p2pcommsec;
opt.get("--portnumbase")->getInt(pnbase);
opt.get("--lgp")->getInt(lgp);
@@ -164,11 +177,25 @@ int main(int argc, const char** argv)
opt.get("--hostname")->getString(hostname);
opt.get("--opening-sum")->getInt(opening_sum);
opt.get("--max-broadcast")->getInt(max_broadcast);
opt.get("--player-to-player-commsec")->getInt(p2pcommsec);
int mynum;
sscanf((*allArgs[1]).c_str(), "%d", &mynum);
CommsecKeysPackage *keys = NULL;
if(p2pcommsec) {
vector<public_signing_key> pubkeys;
secret_signing_key mykey;
public_signing_key mypublickey;
string prep_data_prefix = get_prep_dir(2, lgp, lg2);
Config::read_player_config(prep_data_prefix,mynum,pubkeys,mykey,mypublickey);
keys = new CommsecKeysPackage(pubkeys,mykey,mypublickey);
}
Machine(playerno, pnbase, hostname, progname, memtype, lgp, lg2,
opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet,
opt.get("--threads")->isSet, max_broadcast).run();
opt.get("--threads")->isSet, max_broadcast, keys).run();
cerr << "Command line:";
for (int i = 0; i < argc; i++)

View File

@@ -0,0 +1,72 @@
// (C) 2017 University of Bristol. See License.txt
#include "Processor/Binary_File_IO.h"
#include "Math/gfp.h"
/*
* Provides generalised file read and write methods for arrays of shares.
* Stateless and not optimised for multiple reads from file.
* Intended for application specific file IO.
*/
template<class T>
void Binary_File_IO::write_to_file(const string filename, const vector< Share<T> >& buffer)
{
ofstream outf;
outf.open(filename, ios::out | ios::binary | ios::app);
if (outf.fail()) { throw file_error(filename); }
for (unsigned int i = 0; i < buffer.size(); i++)
{
buffer[i].output(outf, false);
}
outf.close();
}
template<class T>
void Binary_File_IO::read_from_file(const string filename, vector< Share<T> >& buffer, const int start_posn, int &end_posn)
{
ifstream inf;
inf.open(filename, ios::in | ios::binary);
if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); }
int size_in_bytes = Share<T>::size() * buffer.size();
int n_read = 0;
char * read_buffer = new char[size_in_bytes];
inf.seekg(start_posn);
do
{
inf.read(read_buffer + n_read, size_in_bytes - n_read);
n_read += inf.gcount();
if (inf.eof())
{
stringstream ss;
ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes).";
throw file_error(ss.str());
}
if (inf.fail())
{
stringstream ss;
ss << "IO problem when reading from disk";
throw file_error(ss.str());
}
}
while (n_read < size_in_bytes);
end_posn = inf.tellg();
//Check if at end of file by getting 1 more char.
inf.get();
if (inf.eof())
end_posn = -1;
inf.close();
for (unsigned int i = 0; i < buffer.size(); i++)
buffer[i].assign(&read_buffer[i*Share<T>::size()]);
}
template void Binary_File_IO::write_to_file(const string filename, const vector< Share<gfp> >& buffer);
template void Binary_File_IO::read_from_file(const string filename, vector< Share<gfp> >& buffer, const int start_posn, int &end_posn);

View File

@@ -0,0 +1,43 @@
// (C) 2017 University of Bristol. See License.txt
#ifndef _FILE_IO_HEADER
#define _FILE_IO_HEADER
#include "Exceptions/Exceptions.h"
#include "Math/Share.h"
#include <string>
#include <sstream>
#include <fstream>
#include <vector>
using namespace std;
/*
* Provides generalised file read and write methods for arrays of numeric data types.
* Stateless and not optimised for multiple reads from file.
* Intended for MPC application specific file IO.
*/
class Binary_File_IO
{
public:
/*
* Append the buffer values as binary to the filename.
* Throws file_error.
*/
template <class T>
void write_to_file(const string filename, const vector< Share<T> >& buffer);
/*
* Read from posn in the filename the binary values until the buffer is full.
* Assumes file holds binary that maps into the type passed in.
* Returns the current posn in the file or -1 if at eof.
* Throws file_error.
*/
template <class T>
void read_from_file(const string filename, vector< Share<T> >& buffer, const int start_posn, int &end_posn);
};
#endif

View File

@@ -1,4 +1,4 @@
// (C) 2016 University of Bristol. See License.txt
// (C) 2017 University of Bristol. See License.txt
/*
* Buffer.cpp

Some files were not shown because too many files have changed in this diff Show More