From 987a78286f5055900002644fe8b826fe54289de2 Mon Sep 17 00:00:00 2001 From: Jonathan Evans Date: Thu, 14 Sep 2017 10:35:01 +0100 Subject: [PATCH] Release to add compiler instructions for external client I/O. --- .gitignore | 39 +- Auth/MAC_Check.cpp | 2 +- Auth/MAC_Check.h | 2 +- Auth/Subroutines.cpp | 2 +- Auth/Subroutines.h | 2 +- Auth/Summer.cpp | 2 +- Auth/Summer.h | 2 +- Auth/fake-stuff.cpp | 2 +- Auth/fake-stuff.h | 2 +- CHANGELOG.md | 56 +++ CONFIG | 6 +- Check-Offline.cpp | 68 ++-- Compiler/__init__.py | 2 +- Compiler/allocator.py | 28 +- Compiler/comparison.py | 10 +- Compiler/compilerLib.py | 2 +- Compiler/config.py | 4 +- Compiler/dijkstra.py | 2 +- Compiler/exceptions.py | 4 +- Compiler/floatingpoint.py | 6 +- Compiler/graph.py | 2 +- Compiler/gs.py | 2 +- Compiler/instructions.py | 167 ++++++-- Compiler/instructions_base.py | 50 ++- Compiler/library.py | 8 +- Compiler/oram.py | 83 ++-- Compiler/path_oram.py | 7 +- Compiler/permutation.py | 2 +- Compiler/program.py | 43 ++- Compiler/tools.py | 2 +- Compiler/types.py | 246 ++++++++++-- Compiler/util.py | 33 +- Exceptions/Exceptions.h | 41 +- ExternalIO/README.md | 105 +++++ ExternalIO/bankers-bonus-client.cpp | 198 ++++++++++ ExternalIO/bankers-bonus-commsec-client.cpp | 407 ++++++++++++++++++++ Fake-Offline.cpp | 4 +- License.txt | 2 +- Makefile | 17 +- Math/Integer.cpp | 2 +- Math/Integer.h | 30 +- Math/Setup.cpp | 8 +- Math/Setup.h | 2 +- Math/Share.cpp | 19 +- Math/Share.h | 20 +- Math/Zp_Data.cpp | 4 +- Math/Zp_Data.h | 4 +- Math/bigint.cpp | 2 +- Math/bigint.h | 2 +- Math/field_types.h | 2 +- Math/gf2n.cpp | 18 +- Math/gf2n.h | 17 +- Math/gf2nlong.cpp | 2 +- Math/gf2nlong.h | 2 +- Math/gfp.cpp | 30 +- Math/gfp.h | 6 +- Math/modp.cpp | 2 +- Math/modp.h | 2 +- Math/operators.h | 10 +- Networking/Player.cpp | 252 ++++++++---- Networking/Player.h | 40 +- Networking/Receiver.cpp | 2 +- Networking/Receiver.h | 2 +- Networking/STS.cpp | 230 +++++++++++ Networking/STS.h | 72 ++++ Networking/Sender.cpp | 2 +- Networking/Sender.h | 2 +- Networking/ServerSocket.cpp | 74 +++- Networking/ServerSocket.h | 42 +- Networking/data.h | 2 +- Networking/sockets.cpp | 36 +- Networking/sockets.h | 2 +- OT/BaseOT.cpp | 6 +- OT/BaseOT.h | 6 +- OT/BitMatrix.cpp | 2 +- OT/BitMatrix.h | 2 +- OT/BitVector.cpp | 2 +- OT/BitVector.h | 2 +- OT/NPartyTripleGenerator.cpp | 2 +- OT/NPartyTripleGenerator.h | 2 +- OT/OTExtension.cpp | 2 +- OT/OTExtension.h | 2 +- OT/OTExtensionWithMatrix.cpp | 2 +- OT/OTExtensionWithMatrix.h | 2 +- OT/OTMachine.cpp | 2 +- OT/OTMachine.h | 2 +- OT/OTMultiplier.cpp | 2 +- OT/OTMultiplier.h | 2 +- OT/OTTripleSetup.cpp | 2 +- OT/OTTripleSetup.h | 2 +- OT/OText_main.cpp | 2 +- OT/OutputCheck.h | 2 +- OT/Tools.cpp | 4 +- OT/Tools.h | 4 +- OT/TripleMachine.cpp | 2 +- OT/TripleMachine.h | 2 +- Player-Online.cpp | 31 +- Processor/Binary_File_IO.cpp | 72 ++++ Processor/Binary_File_IO.h | 43 +++ Processor/Buffer.cpp | 2 +- Processor/Buffer.h | 2 +- Processor/Data_Files.cpp | 4 +- Processor/Data_Files.h | 6 +- Processor/ExternalClients.cpp | 179 +++++++++ Processor/ExternalClients.h | 65 ++++ Processor/Input.cpp | 2 +- Processor/Input.h | 2 +- Processor/InputTuple.h | 2 +- Processor/Instruction.cpp | 237 ++++++++---- Processor/Instruction.h | 51 ++- Processor/Machine.cpp | 72 ++-- Processor/Machine.h | 19 +- Processor/Memory.cpp | 21 +- Processor/Memory.h | 5 +- Processor/Online-Thread.cpp | 18 +- Processor/Online-Thread.h | 2 +- Processor/PrivateOutput.cpp | 2 +- Processor/PrivateOutput.h | 2 +- Processor/Processor.cpp | 393 ++++++++++++++++--- Processor/Processor.h | 96 +++-- Processor/Program.cpp | 26 +- Processor/Program.h | 23 +- Programs/Source/aes.mpc | 2 +- Programs/Source/bankers_bonus.mpc | 110 ++++++ Programs/Source/bankers_bonus_commsec.mpc | 117 ++++++ Programs/Source/dijkstra_tutorial.mpc | 2 +- Programs/Source/fixed_point_tutorial.mpc | 2 +- Programs/Source/gale-shapley_tutorial.mpc | 2 +- Programs/Source/oram_tutorial.mpc | 2 +- Programs/Source/tpmpc_tutorial.mpc | 2 +- Programs/Source/tutorial.mpc | 2 +- Programs/Source/vickrey.mpc | 2 +- README.md | 6 +- Scripts/gen_input_f2n.cpp | 2 +- Scripts/gen_input_fp.cpp | 2 +- Scripts/run-common.sh | 11 +- Scripts/run-online.sh | 2 +- Scripts/setup-online.sh | 2 +- Server.cpp | 48 ++- Tools/Commit.cpp | 2 +- Tools/Commit.h | 2 +- Tools/Config.cpp | 183 +++++++++ Tools/Config.h | 20 + Tools/Lock.cpp | 2 +- Tools/Lock.h | 2 +- Tools/MMO.cpp | 2 +- Tools/MMO.h | 2 +- Tools/Signal.cpp | 2 +- Tools/Signal.h | 2 +- Tools/WaitQueue.h | 2 +- Tools/aes-ni.cpp | 2 +- Tools/aes.cpp | 2 +- Tools/aes.h | 2 +- Tools/ezOptionParser.h | 150 ++++---- Tools/int.h | 2 +- Tools/mkpath.cpp | 2 +- Tools/mkpath.h | 2 +- Tools/octetStream.cpp | 120 +++++- Tools/octetStream.h | 54 ++- Tools/parse.h | 49 +++ Tools/pprint.h | 13 + Tools/random.cpp | 7 +- Tools/random.h | 2 +- Tools/sha1.cpp | 2 +- Tools/sha1.h | 2 +- Tools/time-func.cpp | 2 +- Tools/time-func.h | 2 +- check-passive.cpp | 2 +- client-setup.cpp | 180 +++++++++ compile.py | 4 +- ot-offline.cpp | 2 +- tutorial.md | 3 +- 172 files changed, 4325 insertions(+), 859 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 ExternalIO/README.md create mode 100644 ExternalIO/bankers-bonus-client.cpp create mode 100644 ExternalIO/bankers-bonus-commsec-client.cpp create mode 100644 Networking/STS.cpp create mode 100644 Networking/STS.h create mode 100644 Processor/Binary_File_IO.cpp create mode 100644 Processor/Binary_File_IO.h create mode 100644 Processor/ExternalClients.cpp create mode 100644 Processor/ExternalClients.h create mode 100644 Programs/Source/bankers_bonus.mpc create mode 100644 Programs/Source/bankers_bonus_commsec.mpc create mode 100644 Tools/Config.cpp create mode 100644 Tools/Config.h create mode 100644 Tools/parse.h create mode 100644 Tools/pprint.h create mode 100644 client-setup.cpp diff --git a/.gitignore b/.gitignore index c6c59de3..2d9ea905 100644 --- a/.gitignore +++ b/.gitignore @@ -4,10 +4,42 @@ Player-Data/* Prep-Data/* logs/* Language-Definition/main.pdf +keys/* # Personal CONFIG file # ############################## CONFIG.mine +config_mine.py + +# Temporary files # +################### +*.bak +*.orig +*.rej +*.tmp +callgrind.out.* + +# Vim +.*.swp +tags + +# Eclipse # +########### +.project +.cproject +.settings + +# VS Code IDE # +############### +.vscode/** + +# Temporary files # +################### +*.bak +*.orig +*.rej +*.tmp +callgrind.out.* # Compiled source # ################### @@ -25,6 +57,8 @@ Programs/Public-Input/* *.bc *.sch *.a +*.static +*.d # Packages # ############ @@ -59,6 +93,8 @@ Programs/Public-Input/* *.log *.sql *.sqlite +*.data +Persistence/* # OS generated files # ###################### @@ -69,4 +105,5 @@ Programs/Public-Input/* .Spotlight-V100 .Trashes ehthumbs.db -Thumbs.db \ No newline at end of file +Thumbs.db +**/*.x.dSYM/** diff --git a/Auth/MAC_Check.cpp b/Auth/MAC_Check.cpp index 7bff41a3..e529c9be 100644 --- a/Auth/MAC_Check.cpp +++ b/Auth/MAC_Check.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Auth/MAC_Check.h" diff --git a/Auth/MAC_Check.h b/Auth/MAC_Check.h index b61cc0c7..e290ee54 100644 --- a/Auth/MAC_Check.h +++ b/Auth/MAC_Check.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _MAC_Check #define _MAC_Check diff --git a/Auth/Subroutines.cpp b/Auth/Subroutines.cpp index efaff68f..2e633662 100644 --- a/Auth/Subroutines.cpp +++ b/Auth/Subroutines.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Auth/Subroutines.h" diff --git a/Auth/Subroutines.h b/Auth/Subroutines.h index 6680718f..50028945 100644 --- a/Auth/Subroutines.h +++ b/Auth/Subroutines.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Subroutines #define _Subroutines diff --git a/Auth/Summer.cpp b/Auth/Summer.cpp index b2de05c8..691f36b6 100644 --- a/Auth/Summer.cpp +++ b/Auth/Summer.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Summer.cpp diff --git a/Auth/Summer.h b/Auth/Summer.h index 4a7c5404..c3a9df13 100644 --- a/Auth/Summer.h +++ b/Auth/Summer.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Summer.h diff --git a/Auth/fake-stuff.cpp b/Auth/fake-stuff.cpp index 40b0f426..7cef045a 100644 --- a/Auth/fake-stuff.cpp +++ b/Auth/fake-stuff.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Math/gf2n.h" diff --git a/Auth/fake-stuff.h b/Auth/fake-stuff.h index 3fc6783d..0695f760 100644 --- a/Auth/fake-stuff.h +++ b/Auth/fake-stuff.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _fake_stuff diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..f6336c16 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,56 @@ +The changelog explains changes pulled through from the private development repository. Bug fixes and small enchancements are committed between releases and not documented here. + +## 0.0.2 (Sep 13, 2017) + +### Support sockets based external client input and output to a SPDZ MPC program. + +See the [ExternalIO directory](./ExternalIO/README.md) for more details and examples. + +Note that [libsodium](https://download.libsodium.org/doc/) is now a dependency on the SPDZ build. + +Added compiler instructions: + +* LISTEN +* ACCEPTCLIENTCONNECTION +* CONNECTIPV4 +* WRITESOCKETSHARE +* WRITESOCKETINT + +Removed instructions: + +* OPENSOCKET +* CLOSESOCKET + +Modified instructions: + +* READSOCKETC +* READSOCKETS +* READSOCKETINT +* WRITESOCKETC +* WRITESOCKETS + +Support secure external client input and output with new instructions: + +* READCLIENTPUBLICKEY +* INITSECURESOCKET +* RESPSECURESOCKET + +### Read/Write secret shares to disk to support persistence in a SPDZ MPC program. + +Added compiler instructions: + +* READFILESHARE +* WRITEFILESHARE + +### Other instructions + +Added compiler instructions: + +* DIGESTC - Clear truncated hash computation +* PRINTINT - Print register value + +## 0.0.1 (Sep 2, 2016) + +### Initial Release + +* See `README.md` and `tutorial.md`. diff --git a/CONFIG b/CONFIG index 8c763a37..32f940fb 100644 --- a/CONFIG +++ b/CONFIG @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt ROOT = . @@ -28,7 +28,7 @@ endif # Default is 3, which suffices for 128-bit p # MOD = -DMAX_MOD_SZ=3 -LDLIBS = -lmpirxx -lmpir $(MY_LDLIBS) -lm -lpthread +LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) -lm -lpthread ifeq ($(USE_NTL),1) LDLIBS := -lntl $(LDLIBS) @@ -40,7 +40,7 @@ LDLIBS += -lrt endif CXX = g++ -CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH) +CFLAGS = $(MY_CFLAGS) -g -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(MEMPROTECT) $(GF2N_LONG) $(PREP_DIR) -maes -mpclmul -msse4.1 $(ARCH) --std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = g++ diff --git a/Check-Offline.cpp b/Check-Offline.cpp index 44f5f75d..f28e86ba 100644 --- a/Check-Offline.cpp +++ b/Check-Offline.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Check-Offline.cpp @@ -62,21 +62,27 @@ void check_bits(const T& key,int N,vector& dataF,DataFieldType fiel vector > Sa(N),Sb(N),Sc(N); int n = 0; - while (!dataF[0]->eof(DATA_BIT)) - { - for (int i = 0; i < N; i++) - dataF[i]->get_one(field_type, DATA_BIT, Sa[i]); - check_share(Sa, a, mac, N, key); - - if (!(a.is_zero() || a.is_one())) + try { + while (!dataF[0]->eof(DATA_BIT)) { - cout << n << ": " << a << " neither 0 or 1" << endl; - throw bad_value(); - } - n++; - } + for (int i = 0; i < N; i++) + dataF[i]->get_one(field_type, DATA_BIT, Sa[i]); + check_share(Sa, a, mac, N, key); - cout << n << " bits of type " << T::type_string() << endl; + if (!(a.is_zero() || a.is_one())) + { + cout << n << ": " << a << " neither 0 or 1" << endl; + throw bad_value(); + } + n++; + } + + cout << n << " bits of type " << T::type_string() << endl; + } + catch (exception& e) + { + cout << "Error with bits of type " << T::type_string() << endl; + } } template @@ -85,20 +91,26 @@ void check_inputs(const T& key,int N,vector& dataF) T a, mac, x; vector< Share > Sa(N); - for (int player = 0; player < N; player++) - { - int n = 0; - while (!dataF[0]->input_eof(player)) - { - for (int i = 0; i < N; i++) - dataF[i]->get_input(Sa[i], x, player); - check_share(Sa, a, mac, N, key); - if (!a.equal(x)) - throw bad_value(); - n++; - } - cout << n << " input masks for player " << player << " of type " << T::type_string() << endl; - } + try { + for (int player = 0; player < N; player++) + { + int n = 0; + while (!dataF[0]->input_eof(player)) + { + for (int i = 0; i < N; i++) + dataF[i]->get_input(Sa[i], x, player); + check_share(Sa, a, mac, N, key); + if (!a.equal(x)) + throw bad_value(); + n++; + } + cout << n << " input masks for player " << player << " of type " << T::type_string() << endl; + } + } + catch (exception& e) + { + cout << "Error with inputs of type " << T::type_string() << endl; + } } int main(int argc, const char** argv) diff --git a/Compiler/__init__.py b/Compiler/__init__.py index 417a7bff..2879d6cb 100644 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import compilerLib, program, instructions, types, library, floatingpoint import inspect diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 5d07e844..91095dd9 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import itertools, time from collections import defaultdict, deque @@ -11,20 +11,20 @@ import Compiler.graph import Compiler.program import heapq, itertools import operator +import sys class StraightlineAllocator: """Allocate variables in a straightline program using n registers. It is based on the precondition that every register is only defined once.""" def __init__(self, n): - self.free = defaultdict(set) self.alloc = {} self.usage = Compiler.program.RegType.create_dict(lambda: 0) self.defined = {} self.dealloc = set() self.n = n - def alloc_reg(self, reg, persistent_allocation): + def alloc_reg(self, reg, free): base = reg.vectorbase if base in self.alloc: # already allocated @@ -32,8 +32,8 @@ class StraightlineAllocator: reg_type = reg.reg_type size = base.size - if not persistent_allocation and self.free[reg_type, size]: - res = self.free[reg_type, size].pop() + if free[reg_type, size]: + res = free[reg_type, size].pop() else: if self.usage[reg_type] < self.n: res = self.usage[reg_type] @@ -48,7 +48,7 @@ class StraightlineAllocator: else: base.i = self.alloc[base] - def dealloc_reg(self, reg, inst): + def dealloc_reg(self, reg, inst, free): self.dealloc.add(reg) base = reg.vectorbase @@ -57,14 +57,14 @@ class StraightlineAllocator: if i not in self.dealloc: # not all vector elements ready for deallocation return - self.free[reg.reg_type, base.size].add(self.alloc[base]) + free[reg.reg_type, base.size].add(self.alloc[base]) if inst.is_vec() and base.vector: for i in base.vector: self.defined[i] = inst else: self.defined[reg] = inst - def process(self, program, persistent_allocation=False): + def process(self, program, alloc_pool): for k,i in enumerate(reversed(program)): unused_regs = [] for j in i.get_def(): @@ -75,7 +75,7 @@ class StraightlineAllocator: (j,i,format_trace(i.caller))) else: # unused register - self.alloc_reg(j, persistent_allocation) + self.alloc_reg(j, alloc_pool) unused_regs.append(j) if unused_regs and len(unused_regs) == len(i.get_def()): # only report if all assigned registers are unused @@ -83,9 +83,9 @@ class StraightlineAllocator: (unused_regs,i,format_trace(i.caller)) for j in i.get_used(): - self.alloc_reg(j, persistent_allocation) + self.alloc_reg(j, alloc_pool) for j in i.get_def(): - self.dealloc_reg(j, i) + self.dealloc_reg(j, i, alloc_pool) if k % 1000000 == 0 and k > 0: print "Allocated registers for %d instructions at" % k, time.asctime() @@ -98,7 +98,7 @@ class StraightlineAllocator: return self.usage -def determine_scope(block): +def determine_scope(block, options): last_def = defaultdict(lambda: -1) used_from_scope = set() @@ -120,12 +120,16 @@ def determine_scope(block): print '\tline %d: %s' % (n, instr) print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t') print '\tregister trace: %s' % format_trace(reg.caller, '\t\t') + if options.stop: + sys.exit(1) def write(reg, n): if last_def[reg] != -1: print 'Warning: double write at register', reg print '\tline %d: %s' % (n, instr) print '\ttrace: %s' % format_trace(instr.caller, '\t\t') + if options.stop: + sys.exit(1) last_def[reg] = n for n,instr in enumerate(block.instructions): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index b3df11b4..8c6be393 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt """ Functions for secure comparison of GF(p) types. @@ -68,10 +68,10 @@ def divide_by_two(res, x): """ Faster clear division by two using a cached value of 2^-1 mod p """ from program import Program import types - tape = Program.prog.curr_block - if tape not in inverse_of_two: - inverse_of_two[tape] = types.cint(1) / 2 - mulc(res, x, inverse_of_two[tape]) + block = Program.prog.curr_block + if len(inverse_of_two) == 0 or block not in inverse_of_two: + inverse_of_two[block] = types.cint(1) / 2 + mulc(res, x, inverse_of_two[block]) def LTZ(s, a, k, kappa): """ diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 3df53c67..32d7573b 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.program import Program from Compiler.config import * diff --git a/Compiler/config.py b/Compiler/config.py index 42248b1a..3a237cc5 100644 --- a/Compiler/config.py +++ b/Compiler/config.py @@ -1,5 +1,5 @@ -# (C) 2016 University of Bristol. See License.txt - +# (C) 2017 University of Bristol. See License.txt + from collections import defaultdict #INIT_REG_MAX = 655360 diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index b6ad1980..2ca13df7 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.oram import * diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py index eb1bf519..8373b1d6 100644 --- a/Compiler/exceptions.py +++ b/Compiler/exceptions.py @@ -1,5 +1,5 @@ -# (C) 2016 University of Bristol. See License.txt - +# (C) 2017 University of Bristol. See License.txt + class CompilerError(Exception): """Base class for compiler exceptions.""" pass diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 9ee2a868..0e575925 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from math import log, floor, ceil from Compiler.instructions import * @@ -404,8 +404,8 @@ def TruncPr(a, k, m, kappa=None): return shift_two(a, m) if kappa is None: - kappa = 40 - + kappa = 40 + b = two_power(k-1) + a r_prime, r_dprime = types.sint(), types.sint() comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)], diff --git a/Compiler/graph.py b/Compiler/graph.py index 97c152af..7f8e7f20 100644 --- a/Compiler/graph.py +++ b/Compiler/graph.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import heapq from Compiler.exceptions import * diff --git a/Compiler/gs.py b/Compiler/gs.py index 8823c23b..510f27c9 100644 --- a/Compiler/gs.py +++ b/Compiler/gs.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import sys import math diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 63ad1173..1453152e 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt """ This module is for classes of actual assembly instructions. @@ -446,6 +446,13 @@ class legendrec(base.Instruction): code = base.opcodes['LEGENDREC'] arg_format = ['cw','c'] +@base.vectorize +class digestc(base.Instruction): + r""" Clear truncated hash computation, $c_i = H(c_j)[bytes]$. """ + __slots__ = [] + code = base.opcodes['DIGESTC'] + arg_format = ['cw','c','int'] + ### ### Bitwise operations ### @@ -915,6 +922,11 @@ class print_float_plain(base.IOInstruction): code = base.opcodes['PRINTFLOATPLAIN'] arg_format = ['c', 'c', 'c', 'c'] +class print_int(base.IOInstruction): + r""" Print only the value of register \verb|ci| to stdout. """ + __slots__ = [] + code = base.opcodes['PRINTINT'] + arg_format = ['ci'] class print_char(base.IOInstruction): r""" Print a single character to stdout. """ @@ -952,43 +964,156 @@ class pubinput(base.PublicFileIOInstruction): @base.vectorize class readsocketc(base.IOInstruction): - """Read an int from socket and store in register""" + """Read a variable number of clear GF(p) values from socket for a specified client id and store in registers""" __slots__ = [] code = base.opcodes['READSOCKETC'] - arg_format = ['ciw', 'int'] + arg_format = tools.chain(['ci'], itertools.repeat('cw')) + + def has_var_args(self): + return True @base.vectorize class readsockets(base.IOInstruction): - """Read a secret share + MAC from socket and store in register""" + """Read a variable number of secret shares + MACs from socket for a client id and store in registers""" __slots__ = [] code = base.opcodes['READSOCKETS'] - arg_format = ['sw', 'int'] + arg_format = tools.chain(['ci'], itertools.repeat('sw')) + + def has_var_args(self): + return True + +@base.vectorize +class readsocketint(base.IOInstruction): + """Read variable number of 32-bit int from socket for a client id and store in registers""" + __slots__ = [] + code = base.opcodes['READSOCKETINT'] + arg_format = tools.chain(['ci'], itertools.repeat('ciw')) + + def has_var_args(self): + return True @base.vectorize class writesocketc(base.IOInstruction): - """Write int from register into socket""" + """ + Write a variable number of clear GF(p) values from registers into socket + for a specified client id, message_type + """ __slots__ = [] code = base.opcodes['WRITESOCKETC'] - arg_format = ['ci', 'int'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('c')) + + def has_var_args(self): + return True @base.vectorize class writesockets(base.IOInstruction): - """Write secret share + MAC from register into socket""" + """ + Write a variable number of secret shares + MACs from registers into a socket + for a specified client id, message_type + """ __slots__ = [] code = base.opcodes['WRITESOCKETS'] - arg_format = ['s', 'int'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('s')) -class opensocket(base.IOInstruction): - """Open a server socket connection at the given port number""" + def has_var_args(self): + return True + +@base.vectorize +class writesocketshare(base.IOInstruction): + """ + Write a variable number of secret shares (without MACs) from registers into socket + for a specified client id, message_type + """ __slots__ = [] - code = base.opcodes['OPENSOCKET'] + code = base.opcodes['WRITESOCKETSHARE'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('s')) + + def has_var_args(self): + return True + +@base.vectorize +class writesocketint(base.IOInstruction): + """ + Write a variable number of 32-bit ints from registers into socket + for a specified client id, message_type + """ + __slots__ = [] + code = base.opcodes['WRITESOCKETINT'] + arg_format = tools.chain(['ci', 'int'], itertools.repeat('ci')) + + def has_var_args(self): + return True + +class listen(base.IOInstruction): + """Open a server socket on a party specific port number and listen for client connections (non-blocking)""" + __slots__ = [] + code = base.opcodes['LISTEN'] arg_format = ['int'] -class closesocket(base.IOInstruction): - """Close a server socket connection""" +class acceptclientconnection(base.IOInstruction): + """Wait for a connection at the given port and write socket handle to register """ __slots__ = [] - code = base.opcodes['CLOSESOCKET'] - arg_format = [] + code = base.opcodes['ACCEPTCLIENTCONNECTION'] + arg_format = ['ciw', 'int'] + +class connectipv4(base.IOInstruction): + """Connect to server at IPv4 address in register \verb|cj| at given port. Write socket handle to register \verb|ci|""" + __slots__ = [] + code = base.opcodes['CONNECTIPV4'] + arg_format = ['ciw', 'ci', 'int'] + +class readclientpublickey(base.IOInstruction): + """Read a client public key as 8 32-bit ints for a specified client id""" + __slots__ = [] + code = base.opcodes['READCLIENTPUBLICKEY'] + arg_format = tools.chain(['ci'], itertools.repeat('ci')) + + def has_var_args(self): + return True + +class initsecuresocket(base.IOInstruction): + """Read a client public key as 8 32-bit ints for a specified client id, + negotiate a shared key via STS and use it for replay resistant comms""" + __slots__ = [] + code = base.opcodes['INITSECURESOCKET'] + arg_format = tools.chain(['ci'], itertools.repeat('ci')) + + def has_var_args(self): + return True + +class respsecuresocket(base.IOInstruction): + """Read a client public key as 8 32-bit ints for a specified client id, + negotiate a shared key via STS and use it for replay resistant comms""" + __slots__ = [] + code = base.opcodes['RESPSECURESOCKET'] + arg_format = tools.chain(['ci'], itertools.repeat('ci')) + + def has_var_args(self): + return True + +class writesharestofile(base.IOInstruction): + """Write shares to a file""" + __slots__ = [] + code = base.opcodes['WRITEFILESHARE'] + arg_format = itertools.repeat('s') + + def has_var_args(self): + return True + +class readsharesfromfile(base.IOInstruction): + """ + Read shares from a file. Pass in start posn, return finish posn, shares. + Finish posn will return: + -2 file not found + -1 eof reached + position in file after read finished + """ + __slots__ = [] + code = base.opcodes['READFILESHARE'] + arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw')) + + def has_var_args(self): + return True @base.gf2n @base.vectorize @@ -1173,7 +1298,7 @@ class gconvgf2n(base.Instruction): @base.gf2n @base.vectorize -class startopen(base.Instruction): +class startopen(base.VarArgsInstruction): """ Start opening secret register $s_i$. """ __slots__ = [] code = base.opcodes['STARTOPEN'] @@ -1183,12 +1308,9 @@ class startopen(base.Instruction): for arg in self.args[::-1]: program.curr_block.open_queue.append(arg.value) - def has_var_args(self): - return True - @base.gf2n @base.vectorize -class stopopen(base.Instruction): +class stopopen(base.VarArgsInstruction): """ Store previous opened value in $c_i$. """ __slots__ = [] code = base.opcodes['STOPOPEN'] @@ -1198,9 +1320,6 @@ class stopopen(base.Instruction): for arg in self.args: arg.value = program.curr_block.open_queue.pop() - def has_var_args(self): - return True - ### ### CISC-style instructions ### diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 88ab56b1..7fe003a2 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import itertools from random import randint @@ -78,6 +78,7 @@ opcodes = dict( MODC = 0x36, MODCI = 0x37, LEGENDREC = 0x38, + DIGESTC = 0x39, GMULBITC = 0x136, GMULBITM = 0x137, # Open @@ -95,13 +96,18 @@ opcodes = dict( # Input INPUT = 0x60, STARTINPUT = 0x61, - STOPINPUT = 0x62, + STOPINPUT = 0x62, READSOCKETC = 0x63, READSOCKETS = 0x64, WRITESOCKETC = 0x65, WRITESOCKETS = 0x66, - OPENSOCKET = 0x67, - CLOSESOCKET = 0x68, + READSOCKETINT = 0x69, + WRITESOCKETINT = 0x6a, + WRITESOCKETSHARE = 0x6b, + LISTEN = 0x6c, + ACCEPTCLIENTCONNECTION = 0x6d, + CONNECTIPV4 = 0x6e, + READCLIENTPUBLICKEY = 0x6f, # Bitwise logic ANDC = 0x70, XORC = 0x71, @@ -131,6 +137,7 @@ opcodes = dict( SUBINT = 0x9C, MULINT = 0x9D, DIVINT = 0x9E, + PRINTINT = 0x9F, # Conversion CONVINT = 0xC0, CONVMODP = 0xC1, @@ -149,8 +156,13 @@ opcodes = dict( PRINTCHRINT = 0xBA, PRINTSTRINT = 0xBB, PRINTFLOATPLAIN = 0xBC, + WRITEFILESHARE = 0xBD, + READFILESHARE = 0xBE, GBITDEC = 0x184, GBITCOM = 0x185, + # Secure socket + INITSECURESOCKET = 0x1BA, + RESPSECURESOCKET = 0x1BB ) @@ -329,13 +341,11 @@ class RegType(object): @staticmethod def create_dict(init_value_fn): """ Create a dictionary with all the RegTypes as keys """ - return { - RegType.ClearModp : init_value_fn(), - RegType.SecretModp : init_value_fn(), - RegType.ClearGF2N : init_value_fn(), - RegType.SecretGF2N : init_value_fn(), - RegType.ClearInt : init_value_fn(), - } + res = defaultdict(init_value_fn) + # initialization for legacy + for t in RegType.Types: + res[t] + return res class ArgFormat(object): @classmethod @@ -481,7 +491,7 @@ class Instruction(object): def get_encoding(self): enc = int_to_bytes(self.get_code()) - # add the number of registers to a start/stop open instruction + # add the number of registers if instruction flagged as has var args if self.has_var_args(): enc += int_to_bytes(len(self.args)) for arg,format in zip(self.args, self.arg_format): @@ -508,6 +518,8 @@ class Instruction(object): except ArgumentError as e: raise CompilerError('Invalid argument "%s" to instruction: %s' % (e.arg, self) + '\n' + e.msg) + except KeyError as e: + raise CompilerError('Incorrect number of arguments for instruction %s' % (self)) def get_used(self): """ Return the set of registers that are read in this instruction. """ @@ -537,8 +549,15 @@ class Instruction(object): def add_usage(self, req_node): pass + # String version of instruction attempting to replicate encoded version def __str__(self): - return self.__class__.__name__ + ' ' + self.get_pre_arg() + ', '.join(str(a) for a in self.args) + + if self.has_var_args(): + varargCount = str(len(self.args)) + ', ' + else: + varargCount = '' + + return self.__class__.__name__ + ' ' + self.get_pre_arg() + varargCount + ', '.join(str(a) for a in self.args) def __repr__(self): return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' @@ -725,6 +744,11 @@ class JumpInstruction(Instruction): return self.args[self.jump_arg] +class VarArgsInstruction(Instruction): + def has_var_args(self): + return True + + class CISC(Instruction): """ Base class for a CISC instruction. diff --git a/Compiler/library.py b/Compiler/library.py index 3b99e496..7e03620a 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat from Compiler.instructions import * @@ -72,9 +72,7 @@ def print_str(s, *args): else: val = args[i] if isinstance(val, program.Tape.Register): - if val.reg_type == 'ci': - cint(val).print_reg_plain() - elif val.is_clear: + if val.is_clear: val.print_reg_plain() else: raise CompilerError('Cannot print secret value:', args[i]) @@ -355,7 +353,7 @@ class FunctionBlock(Function): parent_node = get_tape().req_node get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name) block = get_tape().active_basicblock - block.persistent_allocation = True + block.alloc_pool = defaultdict(set) del parent_node.children[-1] self.node = get_tape().req_node print 'Compiling function', self.name diff --git a/Compiler/oram.py b/Compiler/oram.py index 33537329..99dccc81 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import random import math @@ -15,8 +15,6 @@ from Compiler import floatingpoint,comparison,permutation from Compiler.util import * -sys.setrecursionlimit(1000000) - print_access = False sint_bit_length = 6 max_demux_bits = 3 @@ -40,12 +38,6 @@ def maybe_stop_timer(n): if detailed_timing: stop_timer(n) -def reveal(a): - try: - return a.reveal() - except AttributeError: - return a - class Block(object): def __init__(self, value, lengths): self.value = self.value_type.hard_conv(value) @@ -53,8 +45,7 @@ class Block(object): def get_slice(self): res = [] for length,start in zip(self.lengths, series(self.lengths)): - res.append(sum(b << i for i,b in \ - enumerate(self.bits[start:start+length]))) + res.append(util.bit_compose((self.bits[start:start+length]))) return res def __repr__(self): return '<' + str(self.value) + '>' @@ -150,11 +141,17 @@ class gf2nBlock(Block): self.value = self.lower + value * self.adjust + upper return self +block_types = { sint: intBlock, + sgf2n: gf2nBlock, +} + def get_block(x, y, *args): - if isinstance(x, sgf2n) or isinstance(y, sgf2n): - return gf2nBlock(x, y, *args) - else: - return intBlock(x, y, *args) + for t in block_types: + if isinstance(x, t): + return block_types[t](x, y, *args) + elif isinstance(y, t): + return block_types[t](x, y, *args) + raise CompilerError('appropiate block type not found') def get_bit(x, index, bit_length): if isinstance(x, sgf2n): @@ -242,14 +239,14 @@ class Value(object): return (1 - self.empty) * (other == self.value) return (1 - self.empty) * self.value.equal(other, length) def reveal(self): - return Value(self.value.reveal(), self.empty.reveal()) + return Value(reveal(self.value), reveal(self.empty)) def output(self): - @if_e(self.empty) - def f(): - print_str('<>') - @else_ - def f(): - print_str('<%s>', self.value) + # @if_e(self.empty) + # def f(): + # print_str('<>') + # @else_ + # def f(): + print_str('<%s:%s>', self.empty, self.value) def __index__(self): return int(self.value) def __repr__(self): @@ -344,12 +341,13 @@ class Entry(object): def reveal(self): return Entry(x.reveal() for x in self) def output(self): - @if_e(self.is_empty) - def f(): - print_str('{empty=%s}', self.is_empty) - @else_ - def f(): - print_str('{%s: %s}', self.v, self.x) + # @if_e(self.is_empty) + # def f(): + # print_str('{empty=%s}', self.is_empty) + # @else_ + # def f(): + # print_str('{%s: %s}', self.v, self.x)\ + print_str('{%s: %s,empty=%s}', self.v, self.x, self.is_empty) class RefRAM(object): """ RAM reference. """ @@ -362,8 +360,8 @@ class RefRAM(object): crash() self.size = oram.bucket_size self.entry_type = oram.entry_type - self.l = [Array(self.size, t, array.address + \ - index * oram.bucket_size) \ + self.l = [t.dynamic_array(self.size, t, array.address + \ + index * oram.bucket_size) \ for t,array in zip(self.entry_type,oram.ram.l)] self.index = index def init_mem(self, empty_entry): @@ -410,7 +408,7 @@ class RefRAM(object): Program.prog.curr_tape.start_new_basicblock() return res def output(self): - self.reveal().print_reg() + print_ln('%s', [x.reveal() for x in self]) def print_reg(self): print_ln('listing of RAM at index %s', self.index) Program.prog.curr_tape.start_new_basicblock() @@ -428,7 +426,7 @@ class RAM(RefRAM): #print_reg(cint(0), 'r in') self.size = size self.entry_type = entry_type - self.l = [Array(self.size, t) for t in entry_type] + self.l = [t.dynamic_array(self.size, t) for t in entry_type] self.index = index class AbstractORAM(object): @@ -902,7 +900,7 @@ class List(object): def __init__(self, size, value_type, value_length=1, init_rounds=None): self.value_type = value_type self.value_length = value_length - self.l = [Array(size, value_type) \ + self.l = [value_type.dynamic_array(size, value_type) \ for i in range(value_length)] for l in self.l: l.assign_all(0) @@ -1322,8 +1320,10 @@ def get_value_size(value_type): """ Return element size. """ if value_type == sgf2n: return Program.prog.galois_length - else: + elif value_type == sint: return 127 - Program.prog.security + else: + return value_type.max_length def get_parallel(index_size, value_type, value_length): """ Returning the number of parallel readings feasible, based on @@ -1410,7 +1410,7 @@ class PackedIndexStructure(object): else: self.l[i] = [0] * self.elements_per_block time() - print_ln('packed ORAM init %s/%s', cint(i), real_init_rounds) + print_ln('packed ORAM init %s/%s', i, real_init_rounds) print 'index initialized, size', size def translate_index(self, index): """ Bit slicing *index* according parameters. Output is tuple @@ -1425,18 +1425,17 @@ class PackedIndexStructure(object): return 0, b, c else: return (index - rem) / self.entries_per_block, b, c - elif self.value_type == sgf2n: + else: index_bits = bit_decompose(index, log2(self.size)) l1 = self.log_entries_per_element l2 = self.log_entries_per_block - c = sum(bit << i for i,bit in enumerate(index_bits[:l1])) - b = sum(bit << i for i,bit in enumerate(index_bits[l1:l2])) + c = bit_compose(index_bits[:l1]) + b = bit_compose(index_bits[l1:l2]) if self.small: return 0, b, c else: - a = sum(bit << i for i,bit in enumerate(index_bits[l2:])) + a = bit_compose(index_bits[l2:]) return a, b, c - else: raise CompilerError('Cannot process indices of type', self.value_type) class Slicer(object): def __init__(self, pack, index): @@ -1624,11 +1623,11 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty): def test_oram(oram_type, N, value_type=sint, iterations=100): oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0) print 'initialized' - print_reg(cint(0), 'init') + print_ln('initialized') stop_timer() # synchronize Program.prog.curr_tape.start_new_basicblock(name='sync') - sint(0).reveal() + value_type(0).reveal() Program.prog.curr_tape.start_new_basicblock(name='sync') start_timer() #oram[value_type(0)] = -1 diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index 3e02d14a..a6cc52b7 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt if '_Array' not in dir(): from oram import * @@ -76,7 +76,10 @@ def XOR(a, b): elif isinstance(a, sgf2n) or isinstance(b, sgf2n): return a + b else: - return a + b - 2*a*b + try: + return a ^ b + except TypeError: + return a + b - 2*a*b def pow2_eq(a, i, bit_length=40): """ Test for equality with 2**i, when a is a power of 2 (gf2n only)""" diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 7c98becc..394980f2 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from random import randint import math diff --git a/Compiler/program.py b/Compiler/program.py index 109772a3..843de14a 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.config import * from Compiler.exceptions import * @@ -65,6 +65,7 @@ class Program(object): self.n_threads = 1 self.free_threads = set() self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w') + self.types = {} Program.prog = self self.reset_values() @@ -230,7 +231,7 @@ class Program(object): # runtime doesn't support 'new-style' parallelism yet old_style = True - nonempty_tapes = [t for t in self.tapes if not t.is_empty()] + nonempty_tapes = [t for t in self.tapes] sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name sch_file = open(sch_filename, 'w') @@ -327,12 +328,15 @@ class Program(object): """ The basic block that is currently being created. """ return self.curr_tape.active_basicblock - def malloc(self, size, mem_type): + def malloc(self, size, mem_type, reg_type=None): """ Allocate memory from the top """ if size == 0: return if isinstance(mem_type, type): + self.types[mem_type.reg_type] = mem_type mem_type = mem_type.reg_type + elif reg_type is not None: + self.types[mem_type] = reg_type key = size, mem_type if self.free_mem_blocks[key]: addr = self.free_mem_blocks[key].pop() @@ -346,7 +350,8 @@ class Program(object): def free(self, addr, mem_type): """ Free memory """ - if self.curr_block.persistent_allocation: + if self.curr_block.alloc_pool \ + is not self.curr_tape.basicblocks[0].alloc_pool: raise CompilerError('Cannot free memory within function block') size = self.allocated_mem_blocks.pop((addr,mem_type)) self.free_mem_blocks[size,mem_type].add(addr) @@ -354,10 +359,15 @@ class Program(object): def finalize_memory(self): import library self.curr_tape.start_new_basicblock(None, 'memory-usage') + # reset register counter to 0 + self.curr_tape.init_registers() for mem_type,size in self.allocated_mem.items(): if size: #print "Memory of type '%s' of size %d" % (mem_type, size) - library.load_mem(size - 1, mem_type) + if mem_type in self.types: + self.types[mem_type].load_mem(size - 1, mem_type) + else: + library.load_mem(size - 1, mem_type) def public_input(self, x): self.public_input_file.write('%s\n' % str(x)) @@ -407,9 +417,9 @@ class Tape: self.children = [] if scope is not None: scope.children.append(self) - self.persistent_allocation = scope.persistent_allocation + self.alloc_pool = scope.alloc_pool else: - self.persistent_allocation = False + self.alloc_pool = defaultdict(set) def new_reg(self, reg_type, size=None): return self.parent.new_reg(reg_type, size=size) @@ -511,7 +521,7 @@ class Tape: print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks) for block in self.basicblocks: - al.determine_scope(block) + al.determine_scope(block, options) # merge open instructions # need to do this if there are several blocks @@ -563,15 +573,15 @@ class Tape: # allocate registers reg_counts = self.count_regs() - if filter(lambda n: n > REG_MAX, reg_counts) and not options.noreallocate: - print 'Tape register usage:' + if not options.noreallocate: + print 'Tape register usage:', reg_counts print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) print 'Re-allocating...' allocator = al.StraightlineAllocator(REG_MAX) def alloc_loop(block): for reg in block.used_from_scope: - allocator.alloc_reg(reg, block.persistent_allocation) + allocator.alloc_reg(reg, block.alloc_pool) for child in block.children: if child.instructions: alloc_loop(child) @@ -584,7 +594,7 @@ class Tape: if isinstance(jump, (int,long)) and jump < 0 and \ block.exit_block.scope is not None: alloc_loop(block.exit_block.scope) - allocator.process(block.instructions, block.persistent_allocation) + allocator.process(block.instructions, block.alloc_pool) # offline data requirements print 'Compile offline data requirements...' @@ -614,10 +624,11 @@ class Tape: if not self.is_empty(): # bit length requirement - self.basicblocks[-1].instructions.append( - Compiler.instructions.reqbl(self.req_bit_length['p'], add_to_prog=False)) - self.basicblocks[-1].instructions.append( - Compiler.instructions.greqbl(self.req_bit_length['2'], add_to_prog=False)) + for x in ('p', '2'): + if self.req_bit_length['p']: + self.basicblocks[-1].instructions.append( + Compiler.instructions.reqbl(self.req_bit_length['p'], + add_to_prog=False)) print 'Tape requires prime bit length', self.req_bit_length['p'] print 'Tape requires galois bit length', self.req_bit_length['2'] diff --git a/Compiler/tools.py b/Compiler/tools.py index d30891be..b36ede6b 100644 --- a/Compiler/tools.py +++ b/Compiler/tools.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import itertools diff --git a/Compiler/types.py b/Compiler/types.py index 1f86e321..8ac97b53 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler.program import Tape from Compiler.exceptions import * @@ -11,6 +11,20 @@ import util import operator +class ClientMessageType: + """ Enum to define type of message sent to external client. Each may be array of length n.""" + # No client message type to be sent, for backwards compatibility - virtual machine relies on this value + NoType = 0 + # 3 x sint x n + TripleShares = 1 + # 1 x cint x n + ClearModpInt = 2 + # 1 x regint x n + Int32 = 3 + # 1 x cint (fixed point left shifted by precision) x n + ClearModpFix = 4 + + class MPCThread(object): def __init__(self, target, name, args = [], runtime_arg = None): """ Create a thread from a callable object. """ @@ -97,6 +111,10 @@ def read_mem_value(operation): class _number(object): + @staticmethod + def bit_compose(bits): + return sum(b << i for i,b in enumerate(bits)) + def square(self): return self * self @@ -152,7 +170,6 @@ class _gf2n(object): else: return tuple(t.conv(r) for r in res) - class _register(Tape.Register, _number): MemValue = staticmethod(lambda value: MemValue(value)) @@ -340,6 +357,9 @@ class _clear(_register): __rxor__ = __xor__ __ror__ = __or__ + def reveal(self): + return self + class cint(_clear, _int): " Clear mod p integer type. """ @@ -348,7 +368,25 @@ class cint(_clear, _int): reg_type = 'c' @vectorized_classmethod - def load_mem(cls, address): + def read_from_socket(cls, client_id, n=1): + res = [cls() for i in range(n)] + readsocketc(client_id, *res) + if n == 1: + return res[0] + else: + return res + + @vectorize + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + writesocketc(client_id, message_type, self) + + @vectorized_classmethod + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of modp integers to socket """ + writesocketc(client_id, message_type, *values) + + @vectorized_classmethod + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, ldmc, ldmci) def store_in_mem(self, address): @@ -464,6 +502,13 @@ class cint(_clear, _int): legendrec(res, self) return res + def digest(self, num_bytes): + res = cint() + digestc(res, self, num_bytes) + return res + + + class cgf2n(_clear, _gf2n): __slots__ = [] @@ -478,7 +523,7 @@ class cgf2n(_clear, _gf2n): return res @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, gldmc, gldmci) def store_in_mem(self, address): @@ -560,7 +605,7 @@ class regint(_register, _int): protectmemint(regint(start), regint(end)) @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, ldmint, ldminti) def store_in_mem(self, address): @@ -581,14 +626,40 @@ class regint(_register, _int): return res @vectorized_classmethod - def read_from_socket(cls): - res = cls() - readsocketc(res,0) + def read_from_socket(cls, client_id, n=1): + """ Receive n register values from socket """ + res = [cls() for i in range(n)] + readsocketint(client_id, *res) + if n == 1: + return res[0] + else: + return res + + @vectorized_classmethod + def read_client_public_key(cls, client_id): + """ Receive 8 register values from socket containing client public key.""" + res = [cls() for i in range(8)] + readclientpublickey(client_id, *res) return res + @vectorized_classmethod + def init_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8): + """ Use 8 register values containing client public key.""" + initsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8) + + @vectorized_classmethod + def resp_secure_socket(cls, client_id, w1, w2, w3, w4, w5, w6, w7, w8): + """ Receive 8 register values from socket containing client public key.""" + respsecuresocket(client_id, w1, w2, w3, w4, w5, w6, w7, w8) + @vectorize - def write_to_socket(self): - writesocketc(self,0) + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + writesocketint(client_id, message_type, self) + + @vectorized_classmethod + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of integers to socket """ + writesocketint(client_id, message_type, *values) @vectorize_init def __init__(self, val=None, size=None): @@ -614,7 +685,11 @@ class regint(_register, _int): elif isinstance(val, regint): addint(self, val, regint(0)) else: - raise CompilerError("Cannot convert '%s' to integer" % type(val)) + try: + val.to_regint(self) + except AttributeError: + raise CompilerError("Cannot convert '%s' to integer" % \ + type(val)) @vectorize @read_mem_value @@ -652,10 +727,10 @@ class regint(_register, _int): return self.int_op(other, divint, True) def __mod__(self, other): - return cint(self) % other + return self - (self / other) * other def __rmod__(self, other): - return other % cint(self) + return regint(other) % self def __rpow__(self, other): return other**cint(self) @@ -679,10 +754,16 @@ class regint(_register, _int): return 1 - (self < other) def __lshift__(self, other): - return regint(cint(self) << other) + if isinstance(other, (int, long)): + return self * 2**other + else: + return regint(cint(self) << other) def __rshift__(self, other): - return regint(cint(self) >> other) + if isinstance(other, (int, long)): + return self / 2**other + else: + return regint(cint(self) >> other) def __rlshift__(self, other): return regint(other << cint(self)) @@ -706,6 +787,31 @@ class regint(_register, _int): def mod2m(self, *args, **kwargs): return cint(self).mod2m(*args, **kwargs) + def bit_decompose(self, bit_length=None): + res = [] + x = self + two = regint(2) + for i in range(bit_length or program.bit_length): + y = x / two + res.append(x - two * y) + x = y + return res + + @staticmethod + def bit_compose(bits): + two = regint(2) + res = 0 + for bit in reversed(bits): + res *= two + res += bit + return res + + def reveal(self): + return self + + def print_reg_plain(self): + print_int(self) + class _secret(_register): __slots__ = [] @@ -875,18 +981,54 @@ class sint(_secret, _int): stopinput(player, res) return res + @classmethod + def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): + """ Securely obtain shares of n values input by a client """ + # send shares of a triple to client + triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) + sint.write_shares_to_socket(client_id, triples, message_type) + + received = cint.read_from_socket(client_id, n) + y = [0] * n + for i in range(n): + y[i] = received[i] - triples[i * 3] + return y + @vectorized_classmethod - def read_from_socket(cls): - res = cls() - readsockets(res,0) - return res + def read_from_socket(cls, client_id, n=1): + """ Receive n shares and MAC shares from socket """ + res = [cls() for i in range(n)] + readsockets(client_id, *res) + if n == 1: + return res[0] + else: + return res @vectorize - def write_to_socket(self): - writesockets(self,0) + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + """ Send share and MAC share to socket """ + writesockets(client_id, message_type, self) @vectorized_classmethod - def load_mem(cls, address): + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of shares and MAC shares to socket """ + writesockets(client_id, message_type, *values) + + @vectorize + def write_share_to_socket(self, client_id, message_type=ClientMessageType.NoType): + """ Send only share to socket """ + writesocketshare(client_id, message_type, self) + + @vectorized_classmethod + def write_shares_to_socket(cls, client_id, values, message_type=ClientMessageType.NoType, include_macs=False): + """ Send shares of a list of values to a specified client socket """ + if include_macs: + writesockets(client_id, message_type, *values) + else: + writesocketshare(client_id, message_type, *values) + + @vectorized_classmethod + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, ldms, ldmsi) def store_in_mem(self, address): @@ -1035,7 +1177,7 @@ class sgf2n(_secret, _gf2n): return super(sgf2n, self).mul(other) @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): return cls._load_mem(address, gldms, gldmsi) def store_in_mem(self, address): @@ -1100,9 +1242,10 @@ class sgf2n(_secret, _gf2n): bit_length = bit_length or program.galois_length random_bits = [self.get_random_bit() \ for i in range(0, bit_length, step)] + one = cgf2n(1) masked = sum([b * (one << (i * step)) for i,b in enumerate(random_bits)], self).reveal() - masked_bits = masked.bit_decompose(bit_length) + masked_bits = masked.bit_decompose(bit_length,step=step) return [m + r for m,r in zip(masked_bits, random_bits)] @vectorize @@ -1456,6 +1599,29 @@ class cfix(_number): res.append(cint.load_mem(address)) return cfix(*res) + @vectorized_classmethod + def read_from_socket(cls, client_id, n=1): + """ Read one or more cfix values from a socket. + Sender will have already bit shifted and sent as cints.""" + cint_input = cint.read_from_socket(client_id, n) + if n == 1: + return cfix(cint_inputs) + else: + return map(cfix, cint_inputs) + + @vectorize + def write_to_socket(self, client_id, message_type=ClientMessageType.NoType): + """ Send cfix to socket. Value is sent as bit shifted cint. """ + writesocketc(client_id, message_type, cint(self.v)) + + @vectorized_classmethod + def write_to_socket(self, client_id, values, message_type=ClientMessageType.NoType): + """ Send a list of cfix values to socket. Values are sent as bit shifted cints. """ + def cfix_to_cint(fix_val): + return cint(fix_val.v) + cint_values = map(cfix_to_cint, values) + writesocketc(client_id, message_type, *cint_values) + @vectorize_init def __init__(self, v=None, size=None): f = self.f @@ -1613,6 +1779,13 @@ class sfix(_number): else: cls.k = k + @classmethod + def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): + """ Securely obtain shares of n values input by a client. + Assumes client has already run bit shift to convert fixed point to integer.""" + sint_inputs = sint.receive_from_client(n, client_id, ClientMessageType.TripleShares) + return map(sfix, sint_inputs) + @vectorized_classmethod def load_mem(cls, address, mem_type=None): res = [] @@ -1787,7 +1960,7 @@ class sfloat(_number): error = 0 @vectorized_classmethod - def load_mem(cls, address): + def load_mem(cls, address, mem_type=None): res = [] for i in range(4): res.append(sint.load_mem(address + i * get_global_vector_size())) @@ -2075,10 +2248,13 @@ class Array(object): if value_type in _types: value_type = _types[value_type] self.address = address - if address is None: - self.address = program.malloc(length, value_type.reg_type) self.length = length self.value_type = value_type + if address is None: + self.address = self._malloc() + + def _malloc(self): + return program.malloc(self.length, self.value_type) def delete(self): if program: @@ -2106,7 +2282,7 @@ class Array(object): def f(i): res[i] = self[start+i*step] return res - return self.value_type.load_mem(self.get_address(index)) + return self._load(self.get_address(index)) def __setitem__(self, index, value): if isinstance(index, slice): @@ -2117,7 +2293,13 @@ class Array(object): self[i] = value[source_index] source_index.iadd(1) return - self.value_type.conv(value).store_in_mem(self.get_address(index)) + self._store(self.value_type.conv(value), self.get_address(index)) + + def _load(self, address): + return self.value_type.load_mem(address) + + def _store(self, value, address): + value.store_in_mem(address) def __len__(self): return self.length @@ -2149,6 +2331,8 @@ class Array(object): self[i] = mem_value return self +sint.dynamic_array = Array +sgf2n.dynamic_array = Array class Matrix(object): def __init__(self, rows, columns, value_type, address=None): @@ -2309,7 +2493,7 @@ class MemValue(_mem): else: self.value_type = type(value) self.reg_type = self.value_type.reg_type - self.address = program.malloc(1, self.reg_type) + self.address = program.malloc(1, self.value_type) self.deleted = False self.write(value) @@ -2339,7 +2523,7 @@ class MemValue(_mem): if not isinstance(self.register, self.value_type): raise CompilerError('Mismatch in register type, cannot write \ %s to %s' % (type(self.register), self.value_type)) - library.store_in_mem(self.register, self.address) + self.register.store_in_mem(self.address) self.last_write_block = program.curr_block return self diff --git a/Compiler/util.py b/Compiler/util.py index a4f9e3fd..f44c0d71 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import math import operator @@ -54,7 +54,14 @@ def bit_decompose(a, bits): return a.bit_decompose(bits) def bit_compose(bits): - return sum(b << i for i,b in enumerate(bits)) + bits = list(bits) + try: + if bits: + return bits[0].bit_compose(bits) + else: + return 0 + except AttributeError: + return sum(b << i for i,b in enumerate(bits)) def series(a): sum = 0 @@ -103,3 +110,25 @@ OR = or_op def pow2(bits): powers = [b.if_else(2**2**i, 1) for i,b in enumerate(bits)] return tree_reduce(operator.mul, powers) + +def irepeat(l, n): + return reduce(operator.add, ([i] * n for i in l)) + +def int_len(x): + return len(bin(x)) - 2 + +def reveal(x): + if isinstance(x, str): + return x + try: + return x.reveal() + except AttributeError: + pass + try: + return [reveal(y) for y in x] + except TypeError: + pass + return x + +def is_constant(x): + return isinstance(x, (int, long, bool)) diff --git a/Exceptions/Exceptions.h b/Exceptions/Exceptions.h index 3c20a4ba..c96e16d2 100644 --- a/Exceptions/Exceptions.h +++ b/Exceptions/Exceptions.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Exceptions #define _Exceptions @@ -121,9 +121,39 @@ class file_error: public exception } }; class end_of_file: public exception - { virtual const char* what() const throw() - { return "End of file reached"; } + { string filename, context, ans; + public: + end_of_file(string pfilename="no filename", string pcontext="") : + filename(pfilename), context(pcontext) + { + ans="End of file when reading "; + ans+=filename; + ans+=" "; + ans+=context; + } + ~end_of_file()throw() { } + virtual const char* what() const throw() + { + return ans.c_str(); + } }; +class file_missing: public exception + { string filename, context, ans; + public: + file_missing(string pfilename="no filename", string pcontext="") : + filename(pfilename), context(pcontext) + { + ans="File missing : "; + ans+=filename; + ans+=" "; + ans+=context; + } + ~file_missing()throw() { } + virtual const char* what() const throw() + { + return ans.c_str(); + } + }; class Processor_Error: public exception { string msg; public: @@ -137,6 +167,11 @@ class Processor_Error: public exception return msg.c_str(); } }; +class Invalid_Instruction : public Processor_Error + { + public: + Invalid_Instruction(string m) : Processor_Error(m) {} + }; class max_mod_sz_too_small : public exception { int len; public: diff --git a/ExternalIO/README.md b/ExternalIO/README.md new file mode 100644 index 00000000..07b5b775 --- /dev/null +++ b/ExternalIO/README.md @@ -0,0 +1,105 @@ +(C) 2017 University of Bristol. See License.txt. + +The ExternalIO directory contains examples of managing I/O between external client processes and SPDZ parties running SPDZ engines. These instructions assume that SPDZ has been built as per the [project readme](../README.md). + +## I/O MPC Instructions + +### Connection Setup + +**listen**(*int port_num*) + +Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background. + +*port_num* - the port number to listen on. + +**acceptclientconnection**(*regint client_socket_id*, *int port_num*) + +Picks the first available client socket connection. Blocks if none available. + +*client_socket_id* - an identifier used to refer to the client socket. + +*port_num* - the port number identifies the socket server to accept connections on. + +### Data Exchange + +Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py). + +*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*) + +Read a share of an input from a client, blocking on the client send. + +*client_socket_id* - an identifier used to refer to the client socket. + +*number_of_inputs* - the number of inputs expected + +*[inputs]* - returned list of shares of private input. + +**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*) + +Write shares of values including macs to an external client. + +*client_socket_id* - an identifier used to refer to the client socket. + +*[values]* - list of shares of values to send to client. + +*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. + +See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message. + +*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*) + +Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf) + +*number_of_inputs* - the number of inputs expected + +*client_socket_id* - an identifier used to refer to the client socket. + +*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. + +*[inputs]* - returned list of shares of private input. + + +## Securing communications + +Two cryptographic protocols have been implemented for use in particular applications and are included here for completeness: + +1. Communication security using a Station to Station key agreement and libsodium Secret Box using a nonce counter for message ordering. +2. Authenticated Diffie-Hellman without message ordering. + + Please note these are **NOT** required to allow external client I/O. Your mileage may vary, for example in a web setting TLS may be sufficient to secure communications between processes. + +[client-setup.cpp](../client-setup.cpp) is a utility which is run to generate the key material for both the external clients and SPDZ parties for both protocols. + +#### MPC instructions + +**regint.init_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*) + +STS protocol initiator. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec. + +*client_socket_id* - an identifier used to refer to the client socket. + +*public_signing_key* - client public key supplied as list of 8 32-bit ints. + +**regint.resp_secure_socket**(*regint client_socket_id*, *[regint] public_signing_key*) + +STS protocol responder. Read a client public key for a specified client connection and negotiate a shared key via STS. All subsequent write_socket / read_socket instructions are encrypted / decrypted for replay resistant commsec. + +*client_socket_id* - an identifier used to refer to the client socket. + +*public_signing_key* - client public key supplied as list of 8 32-bit ints. + +*[regint public_key]* **regint.read_client_public_key**(*regint client_socket_id*) + +Instruction to read the client public key and run setup for the authenticated Diffie-Hellman encryption. All subsequent write_socket instructions are encrypted. Only the sint.read_from_socket instruction is encrypted. + +*client_socket_id* - an identifier used to refer to the client socket. + +*public_key* - client public key made available to mpc programs as list of 8 32-bit ints. + +## Working Examples + +See [bankers-bonus-client.cpp](./bankers-bonus-client.cpp) which acts as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output with no communications security. + +See [bankers-bonus-commsec-client.cpp](./bankers-bonus-commsec-client.cpp) which acts as a client to [bankers_bonus_commsec.mpc](../Programs/Source/bankers_bonus_commsec.mpc) which runs the same algorithm but includes both the available crypto protocols. + +More instructions on how to run these are provided in the *-client files. diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp new file mode 100644 index 00000000..661d24f8 --- /dev/null +++ b/ExternalIO/bankers-bonus-client.cpp @@ -0,0 +1,198 @@ +/* + * (C) 2017 University of Bristol. See License.txt + * + * Demonstrate external client inputing and receiving outputs from a SPDZ process, + * following the protocol described in https://eprint.iacr.org/2015/1006.pdf. + * + * Provides a client to bankers_bonus.mpc program to calculate which banker pays for lunch based on + * the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running + * the bankers_bonus.mpc program. + * + * Each connecting client: + * - sends a unique id to identify the client + * - sends an integer input (bonus value to compare) + * - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result). + * + * The result is returned authenticated with a share of a random value: + * - share of winning unique id [y] + * - share of random value [r] + * - share of winning unique id * random value [w] + * winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w] + * + * No communications security is used. + * + * To run with 2 parties / SPDZ engines: + * ./Scripts/setup-online.sh to create triple shares for each party (spdz engine). + * ./compile.py bankers_bonus + * ./Scripts/run-online bankers_bonus to run the engines. + * + * ./bankers-bonus-client.x 123 2 100 0 + * ./bankers-bonus-client.x 456 2 200 0 + * ./bankers-bonus-client.x 789 2 50 1 + * + * Expect winner to be second client with id 456. + */ + +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Networking/sockets.h" +#include "Tools/int.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" + +#include +#include +#include +#include + +// Send the private inputs masked with a random value. +// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid. +// Add the private input value to triple[0] and send to each spdz engine. +void send_private_inputs(vector& values, vector& sockets, int nparties) +{ + int num_inputs = values.size(); + octetStream os; + vector< vector > triples(num_inputs, vector(3)); + vector triple_shares(3); + + // Receive num_inputs triples from SPDZ + for (int j = 0; j < nparties; j++) + { + os.reset_write_head(); + os.Receive(sockets[j]); + + for (int j = 0; j < num_inputs; j++) + { + for (int k = 0; k < 3; k++) + { + triple_shares[k].unpack(os); + triples[j][k] += triple_shares[k]; + } + } + } + + // Check triple relations (is a party cheating?) + for (int i = 0; i < num_inputs; i++) + { + if (triples[i][0] * triples[i][1] != triples[i][2]) + { + cerr << "Incorrect triple at " << i << ", aborting\n"; + exit(1); + } + } + // Send inputs + triple[0], so SPDZ can compute shares of each value + os.reset_write_head(); + for (int i = 0; i < num_inputs; i++) + { + gfp y = values[i] + triples[i][0]; + y.pack(os); + } + for (int j = 0; j < nparties; j++) + os.Send(sockets[j]); +} + +// Assumes that Scripts/setup-online.sh has been run to compute prime +void initialise_fields(const string& dir_prefix) +{ + int lg2; + bigint p; + + string filename = dir_prefix + "Params-Data"; + cout << "loading params from: " << filename << endl; + + ifstream inpf(filename.c_str()); + if (inpf.fail()) { throw file_error(filename.c_str()); } + inpf >> p; + inpf >> lg2; + + inpf.close(); + + gfp::init_field(p); + gf2n::init_field(lg2); +} + + +// Receive shares of the result and sum together. +// Also receive authenticating values. +gfp receive_result(vector& sockets, int nparties) +{ + vector output_values(3); + octetStream os; + for (int i = 0; i < nparties; i++) + { + os.reset_write_head(); + os.Receive(sockets[i]); + for (unsigned int j = 0; j < 3; j++) + { + gfp value; + value.unpack(os); + output_values[j] += value; + } + } + + if (output_values[0] * output_values[1] != output_values[2]) + { + cerr << "Unable to authenticate output value as correct, aborting." << endl; + exit(1); + } + return output_values[0]; +} + + +int main(int argc, char** argv) +{ + int my_client_id; + int nparties; + int salary_value; + int finish; + int port_base = 14000; + string host_name = "localhost"; + + if (argc < 5) { + cout << "Usage is bankers-bonus-client " + << " " + << "" << endl; + exit(0); + } + + my_client_id = atoi(argv[1]); + nparties = atoi(argv[2]); + salary_value = atoi(argv[3]); + finish = atoi(argv[4]); + if (argc > 5) + host_name = argv[5]; + if (argc > 6) + port_base = atoi(argv[6]); + + // init static gfp + string prep_data_prefix = get_prep_dir(nparties, 128, 40); + initialise_fields(prep_data_prefix); + + // Setup connections from this client to each party socket + vector sockets(nparties); + for (int i = 0; i < nparties; i++) + { + set_up_client_socket(sockets[i], host_name.c_str(), port_base + i); + } + cout << "Finish setup socket connections to SPDZ engines." << endl; + + // Map inputs into gfp + vector input_values_gfp(3); + input_values_gfp[0].assign(my_client_id); + input_values_gfp[1].assign(salary_value); + input_values_gfp[2].assign(finish); + + // Run the commputation + send_private_inputs(input_values_gfp, sockets, nparties); + cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; + + // Get the result back (client_id of winning client) + gfp result = receive_result(sockets, nparties); + + cout << "Winning client id is : " << result << endl; + + for (unsigned int i = 0; i < sockets.size(); i++) + close_client_socket(sockets[i]); + + return 0; +} diff --git a/ExternalIO/bankers-bonus-commsec-client.cpp b/ExternalIO/bankers-bonus-commsec-client.cpp new file mode 100644 index 00000000..9b99ad1e --- /dev/null +++ b/ExternalIO/bankers-bonus-commsec-client.cpp @@ -0,0 +1,407 @@ +/* + * (C) 2017 University of Bristol. See License.txt + * + * Demonstrate external client inputing and receiving outputs from a SPDZ process, + * following the protocol described in https://eprint.iacr.org/2015/1006.pdf. + * Uses SPDZ implemented encryption for external client communication, see bankers-bonus-client.cpp + * for a simpler client with no crypto. + * + * Provides a client to bankers_bonus_commsec.mpc program to calculate which banker pays for lunch based on + * the private value annual bonus. Up to 8 clients can connect to the SPDZ engines running + * the bankers_bonus.mpc program. + * + * Each connecting client: + * - runs crypto setup to demonstrate both DH Auth Encryption and STS protocol for comms security. + * - sends a unique id to identify the client + * - sends an integer input (bonus value to compare) + * - sends an integer (0 meaining more players will join this round or 1 meaning stop the round and calc the result). + * + * The result is returned authenticated with a share of a random value: + * - share of winning unique id [y] + * - share of random value [r] + * - share of winning unique id * random value [w] + * winning unique id is valid if ∑ [y] * ∑ [r] = ∑ [w] + * + * To run with 2 parties (SPDZ engines) and 3 external clients: + * ./Scripts/setup-online.sh to create triple shares for each party (spdz engine). + * ./client-setup.x 2 -nc 3 to create the crypto key material for both parties and clients. + * ./compile.py bankers_bonus_commsec + * ./Scripts/run-online bankers_bonus_commsec to run the engines. + * + * ./bankers-bonus-commsec-client.x 0 2 100 0 + * ./bankers-bonus-commsec-client.x 1 2 200 0 + * ./bankers-bonus-commsec-client.x 2 2 50 1 + * + * Expect winner to be second client with id 1. + * Note here client id must match id used in generating client key material, Client-Keys-C. + */ + +#include "Math/gfp.h" +#include "Math/gf2n.h" +#include "Networking/sockets.h" +#include "Networking/STS.h" +#include "Tools/int.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" + +#include +#include +#include +#include +#include + +typedef pair< vector, vector > keypair_t; // A pair of send/recv keys for talking to SPDZ +typedef vector< keypair_t > commsec_t; // A database of send/recv keys indexed by server +typedef struct { + unsigned char client_secretkey[crypto_sign_SECRETKEYBYTES]; + unsigned char client_publickey[crypto_sign_PUBLICKEYBYTES]; + vector client_publickey_ints; + vector< vector >server_publickey; +} sign_key_container_t; + +keypair_t sts_response_role_exceptions(sign_key_container_t keys, vector& sockets, int server_id) +{ + STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey); + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream os; + + os.Receive(sockets[server_id]); + os.consume(m1.bytes, sizeof m1.bytes); + m2 = ke.recv_msg1(m1); + os.reset_write_head(); + os.append(m2.pubkey, sizeof m2.pubkey); + os.append(m2.sig, sizeof m2.sig); + os.Send(sockets[server_id]); + os.Receive(sockets[server_id]); + os.consume(m3.bytes, sizeof m3.bytes); + ke.recv_msg3(m3); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + return make_pair(sendKey,recvKey); +} + +keypair_t sts_initiator_role_exceptions(sign_key_container_t keys, vector& sockets, int server_id) +{ + STS ke(&keys.server_publickey[server_id][0], keys.client_publickey, keys.client_secretkey); + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream os; + + m1 = ke.send_msg1(); + cout << "m1: "; + for (unsigned int j = 0; j < 32; j++) + cout << setfill('0') << setw(2) << hex << (int) m1.bytes[j]; + cout << dec << endl; + os.reset_write_head(); + os.append(m1.bytes, sizeof m1.bytes); + os.Send(sockets[server_id]); + + os.reset_write_head(); + os.Receive(sockets[server_id]); + os.consume(m2.pubkey, sizeof m2.pubkey); + os.consume(m2.sig, sizeof m2.sig); + m3 = ke.recv_msg2(m2); + + os.reset_write_head(); + os.append(m3.bytes, sizeof m3.bytes); + os.Send(sockets[server_id]); + + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + return make_pair(sendKey,recvKey); +} + +pair< vector, vector > sts_response_role(sign_key_container_t keys, vector& sockets, int server_id) +{ + pair< vector, vector > res; + try { + res = sts_response_role_exceptions(keys, sockets, server_id); + } catch(char const *e) { + cerr << "Error in STS: " << e << endl; + exit(1); + } + return res; +} + +pair< vector, vector > sts_initiator_role(sign_key_container_t keys, vector& sockets, int server_id) +{ + pair< vector, vector > res; + try { + res = sts_initiator_role_exceptions(keys, sockets, server_id); + } catch(char const *e) { + cerr << "Error in STS: " << e << endl; + exit(1); + } + return res; +} + +// Send the private inputs masked with a random value. +// Receive shares of a preprocessed triple from each SPDZ engine, combine and check the triples are valid. +// Add the private input value to triple[0] and send to each spdz engine. +void send_private_inputs(vector& values, vector& sockets, int nparties, + commsec_t commsec, vector& keys) +{ + int num_inputs = values.size(); + octetStream os; + vector< vector > triples(num_inputs, vector(3)); + vector triple_shares(3); + + // Receive num_inputs triples from SPDZ + for (int j = 0; j < nparties; j++) + { + os.reset_write_head(); + os.Receive(sockets[j]); + os.decrypt_sequence(&commsec[j].second[0],0); + os.decrypt(keys[j]); + + for (int j = 0; j < num_inputs; j++) + { + for (int k = 0; k < 3; k++) + { + triple_shares[k].unpack(os); + triples[j][k] += triple_shares[k]; + } + } + } + // Check triple relations + for (int i = 0; i < num_inputs; i++) + { + if (triples[i][0] * triples[i][1] != triples[i][2]) + { + cerr << "Incorrect triple at " << i << ", aborting\n"; + exit(1); + } + } + // Send inputs + triple[0], so SPDZ can compute shares of each value + os.reset_write_head(); + for (int i = 0; i < num_inputs; i++) + { + gfp y = values[i] + triples[i][0]; + y.pack(os); + } + for (int j = 0; j < nparties; j++) { + os.encrypt_sequence(&commsec[j].first[0],0); + os.Send(sockets[j]); + } +} + +// Send public key in clear to each SPDZ engine. +void send_public_key(vector& pubkey, int socket) +{ + octetStream os; + os.reset_write_head(); + + for (unsigned int i = 0; i < pubkey.size(); i++) + { + os.store(pubkey[i]); + } + + os.Send(socket); +} + +// Assumes that Scripts/setup-online.sh has been run to compute prime +void initialise_fields(const string& dir_prefix) +{ + int lg2; + bigint p; + + string filename = dir_prefix + "Params-Data"; + cout << "loading params from: " << filename << endl; + + ifstream inpf(filename.c_str()); + if (inpf.fail()) { throw file_error(filename.c_str()); } + inpf >> p; + inpf >> lg2; + + inpf.close(); + + gfp::init_field(p); + gf2n::init_field(lg2); +} + +// Assumes that client-setup has been run to create key pairs for clients and parties +void generate_symmetric_keys(vector& keys, vector& client_public_key_ints, + sign_key_container_t *sts_key, const string& dir_prefix, int client_no) +{ + unsigned char client_publickey[crypto_box_PUBLICKEYBYTES]; + unsigned char client_secretkey[crypto_box_SECRETKEYBYTES]; + unsigned char server_publickey[crypto_box_PUBLICKEYBYTES]; + unsigned char scalarmult_q[crypto_scalarmult_BYTES]; + crypto_generichash_state h; + + // read client public/secret keys + SPDZ server public keys + ifstream keyfile; + stringstream client_filename; + client_filename << dir_prefix << "Client-Keys-C" << client_no; + keyfile.open(client_filename.str().c_str()); + if (keyfile.fail()) + throw file_error(client_filename.str()); + keyfile.read((char*)client_publickey, sizeof client_publickey); + if (keyfile.eof()) + throw end_of_file(client_filename.str(), "client public key" ); + + // Convert client public key unsigned char to int, reverse endianness + for(unsigned int j = 0; j < client_public_key_ints.size(); j++) { + int keybyte = 0; + for(unsigned int k = 0; k < 4; k++) { + keybyte = keybyte + (((int)client_publickey[j*4+k]) << ((3-k) * 8)); + } + client_public_key_ints[j] = keybyte; + } + + keyfile.read((char*)client_secretkey, sizeof client_secretkey); + if (keyfile.eof()) { + throw end_of_file(client_filename.str(), "client private key" ); + } + + keyfile.read((char*)sts_key->client_publickey, crypto_sign_PUBLICKEYBYTES); + keyfile.read((char*)sts_key->client_secretkey, crypto_sign_SECRETKEYBYTES); + // Convert client public key unsigned char to int, reverse endianness + sts_key->client_publickey_ints.resize(8); + for(unsigned int j = 0; j < sts_key->client_publickey_ints.size(); j++) { + int keybyte = 0; + for(unsigned int k = 0; k < 4; k++) { + keybyte = keybyte + (((int)sts_key->client_publickey[j*4+k]) << ((3-k) * 8)); + } + sts_key->client_publickey_ints[j] = keybyte; + } + + for (unsigned int i = 0; i < keys.size(); i++) + { + keys[i] = new octet[crypto_generichash_BYTES]; + keyfile.read((char*)server_publickey, crypto_box_PUBLICKEYBYTES); + if (keyfile.eof()) + throw end_of_file(client_filename.str(), "server public key for party " + i); + keyfile.read((char*)(&sts_key->server_publickey[i][0]), crypto_sign_PUBLICKEYBYTES); + if (keyfile.eof()) + throw end_of_file(client_filename.str(), "server public signing key for party " + i); + + // Derive a shared key from this server's secret key and the client's public key + // shared key = h(q || client_secretkey || server_publickey) + if (crypto_scalarmult(scalarmult_q, client_secretkey, server_publickey) != 0) { + cerr << "Scalar mult failed\n"; + exit(1); + } + crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES); + crypto_generichash_update(&h, scalarmult_q, sizeof scalarmult_q); + crypto_generichash_update(&h, client_publickey, sizeof client_publickey); + crypto_generichash_update(&h, server_publickey, sizeof server_publickey); + crypto_generichash_final(&h, keys[i], crypto_generichash_BYTES); + } + keyfile.close(); + + cout << "My public key is: "; + for (unsigned int j = 0; j < 32; j++) + cout << setfill('0') << setw(2) << hex << (int) client_publickey[j]; + cout << dec << endl; +} + + +// Receive shares of the result and sum together. +// Also receive authenticating values. +gfp receive_result(vector& sockets, int nparties, commsec_t commsec, vector& keys) +{ + vector output_values(3); + octetStream os; + for (int i = 0; i < nparties; i++) + { + os.reset_write_head(); + os.Receive(sockets[i]); + + os.decrypt_sequence(&commsec[i].second[0],1); + os.decrypt(keys[i]); + + for (unsigned int j = 0; j < 3; j++) + { + gfp value; + value.unpack(os); + output_values[j] += value; + } + } + + if (output_values[0] * output_values[1] != output_values[2]) + { + cerr << "Unable to authenticate output value as correct, aborting." << endl; + exit(1); + } + return output_values[0]; +} + + +int main(int argc, char** argv) +{ + int my_client_id; + int nparties; + int salary_value; + int finish; + int port_base = 14000; + sign_key_container_t sts_key; + string host_name = "localhost"; + + if (argc < 5) { + cout << "Usage is external-client " + << " " + << "" << endl; + exit(0); + } + + my_client_id = atoi(argv[1]); + nparties = atoi(argv[2]); + salary_value = atoi(argv[3]); + finish = atoi(argv[4]); + if (argc > 5) + host_name = argv[5]; + if (argc > 6) + port_base = atoi(argv[6]); + + sts_key.server_publickey.resize(nparties); + for(int i = 0 ; i < nparties; i++) { + sts_key.server_publickey[i].resize(crypto_sign_PUBLICKEYBYTES); + } + + // init static gfp + string prep_data_prefix = get_prep_dir(nparties, 128, 40); + initialise_fields(prep_data_prefix); + + // Generate session keys to decrypt data sent from each spdz engine (party) + vector session_keys(nparties); + vector client_public_key_ints(8); + + generate_symmetric_keys(session_keys, client_public_key_ints, &sts_key, prep_data_prefix, my_client_id); + + // Setup connections from this client to each party socket and send the client public keys + vector sockets(nparties); + // vector< pair , vector > > commseckey(nparties); + commsec_t commseckey(nparties); + for (int i = 0; i < nparties; i++) + { + set_up_client_socket(sockets[i], host_name.c_str(), port_base + i); + send_public_key(sts_key.client_publickey_ints, sockets[i]); + send_public_key(client_public_key_ints, sockets[i]); + commseckey[i] = sts_initiator_role(sts_key, sockets, i); + } + cout << "Finish setup socket connections to SPDZ engines." << endl; + + // Map inputs into gfp + vector input_values_gfp(3); + input_values_gfp[0].assign(my_client_id); + input_values_gfp[1].assign(salary_value); + input_values_gfp[2].assign(finish); + + // Send the inputs to the SPDZ Engines + send_private_inputs(input_values_gfp, sockets, nparties, commseckey, session_keys); + cout << "Sent private inputs to each SPDZ engine, waiting for result..." << endl; + + // Get the result back + gfp result = receive_result(sockets, nparties, commseckey, session_keys); + + cout << "Winning client id is : " << result << endl; + + for (unsigned int i = 0; i < sockets.size(); i++) + close_client_socket(sockets[i]); + + return 0; +} diff --git a/Fake-Offline.cpp b/Fake-Offline.cpp index 750598d1..cfac46ec 100644 --- a/Fake-Offline.cpp +++ b/Fake-Offline.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Math/gf2n.h" @@ -490,8 +490,6 @@ int main(int argc, const char** argv) bigint p; generate_online_setup(outf, prep_data_prefix, p, lgp, lg2); - generate_keys(prep_data_prefix, nplayers); - /* Find number players and MAC keys etc*/ gfp keyp,pp; keyp.assign_zero(); gf2n key2,p2; key2.assign_zero(); diff --git a/License.txt b/License.txt index e30817d3..ed66199f 100644 --- a/License.txt +++ b/License.txt @@ -1,6 +1,6 @@ University of Bristol : Open Access Software Licence -Copyright (c) 2016, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom. +Copyright (c) 2017, The University of Bristol, a chartered corporation having Royal Charter number RC000648 and a charity (number X1121) and its place of administration being at Senate House, Tyndall Avenue, Bristol, BS8 1TH, United Kingdom. All rights reserved diff --git a/Makefile b/Makefile index eff45dac..2b89126d 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt include CONFIG @@ -26,7 +26,7 @@ LIB = libSPDZ.a LIBSIMPLEOT = SimpleOT/libsimpleot.a -all: gen_input online offline +all: gen_input online offline externalIO online: Fake-Offline.x Server.x Player-Online.x Check-Offline.x @@ -34,6 +34,8 @@ offline: $(OT_EXE) Check-Offline.x gen_input: gen_input_f2n.x gen_input_fp.x +externalIO: client-setup.x bankers-bonus-client.x bankers-bonus-commsec-client.x + Fake-Offline.x: Fake-Offline.cpp $(COMMON) $(PROCESSOR) $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) @@ -69,7 +71,14 @@ gen_input_f2n.x: Scripts/gen_input_f2n.cpp $(COMMON) gen_input_fp.x: Scripts/gen_input_fp.cpp $(COMMON) $(CXX) $(CFLAGS) Scripts/gen_input_fp.cpp -o gen_input_fp.x $(COMMON) $(LDLIBS) +client-setup.x: client-setup.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +bankers-bonus-client.x: ExternalIO/bankers-bonus-client.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) + +bankers-bonus-commsec-client.x: ExternalIO/bankers-bonus-commsec-client.cpp $(COMMON) $(PROCESSOR) + $(CXX) $(CFLAGS) -o $@ $^ $(LDLIBS) clean: - -rm */*.o *.o *.x core.* *.a gmon.out - + -rm */*.o *.o */*.d *.d *.x core.* *.a gmon.out diff --git a/Math/Integer.cpp b/Math/Integer.cpp index b6dc06e7..184952d9 100644 --- a/Math/Integer.cpp +++ b/Math/Integer.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Integer.cpp diff --git a/Math/Integer.h b/Math/Integer.h index b1730b34..89e00571 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Integer.h @@ -15,10 +15,13 @@ using namespace std; class Integer { +protected: long a; public: + static string type_string() { return "integer"; } + Integer() { a = 0; } Integer(long a) : a(a) {} @@ -26,6 +29,31 @@ class Integer void assign_zero() { a = 0; } + long operator+(const Integer& other) const { return a + other.a; } + long operator-(const Integer& other) const { return a - other.a; } + long operator*(const Integer& other) const { return a * other.a; } + long operator/(const Integer& other) const { return a / other.a; } + + long operator>>(const Integer& other) const { return a >> other.a; } + long operator<<(const Integer& other) const { return a << other.a; } + + long operator^(const Integer& other) const { return a ^ other.a; } + long operator&(const Integer& other) const { return a ^ other.a; } + long operator|(const Integer& other) const { return a ^ other.a; } + + bool operator==(const Integer& other) const { return a == other.a; } + bool operator!=(const Integer& other) const { return a != other.a; } + bool operator<(const Integer& other) const { return a < other.a; } + bool operator<=(const Integer& other) const { return a <= other.a; } + bool operator>(const Integer& other) const { return a > other.a; } + bool operator>=(const Integer& other) const { return a >= other.a; } + + long operator^=(const Integer& other) { return a ^= other.a; } + + friend unsigned int& operator+=(unsigned int& x, const Integer& other) { return x += other.a; } + + friend ostream& operator<<(ostream& s, const Integer& x) { x.output(s, true); return s; } + void output(ostream& s,bool human) const; void input(istream& s,bool human); diff --git a/Math/Setup.cpp b/Math/Setup.cpp index 29c8002d..c13704b1 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Math/Setup.h" #include "Math/gfp.h" @@ -111,8 +111,8 @@ void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, i } string get_prep_dir(int nparties, int lg2p, int gf2ndegree) -{ - if (gf2ndegree == 0) +{ + if (gf2ndegree == 0) gf2ndegree = gf2n::default_length(); stringstream ss; ss << PREP_DIR << nparties << "-" << lg2p << "-" << gf2ndegree << "/"; diff --git a/Math/Setup.h b/Math/Setup.h index a3813f1d..db82a511 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Setup.h diff --git a/Math/Share.cpp b/Math/Share.cpp index 065c69a8..57b138a7 100644 --- a/Math/Share.cpp +++ b/Math/Share.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Share.h" @@ -99,6 +99,23 @@ T combine(const vector< Share >& S) } + + +template +inline void Share::pack(octetStream& os) const +{ + a.pack(os); + mac.pack(os); +} + +template +inline void Share::unpack(octetStream& os) +{ + a.unpack(os); + mac.unpack(os); +} + + template bool check_macs(const vector< Share >& S,const T& key) { diff --git a/Math/Share.h b/Math/Share.h index 95382c11..0320029c 100644 --- a/Math/Share.h +++ b/Math/Share.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Share @@ -69,6 +69,19 @@ class Share void sub(const Share& S1,const Share& S2); void add(const Share& S1) { add(*this,S1); } + Share operator+(const Share& x) const + { Share res; res.add(*this, x); return res; } + template + Share operator*(const U& x) const + { Share res; res.mul(*this, x); return res; } + + Share& operator+=(const Share& x) { add(x); return *this; } + template + Share& operator*=(const U& x) { mul(*this, x); return *this; } + + Share operator<<(int i) { return this->operator*(T(1) << i); } + Share& operator<<=(int i) { return *this = *this << i; } + // Input and output from a stream // - Can do in human or machine only format (later should be faster) void output(ostream& s,bool human) const @@ -80,6 +93,11 @@ class Share mac.input(s,human); } + friend ostream& operator<<(ostream& s, const Share& x) { x.output(s, true); return s; } + + void pack(octetStream& os) const; + void unpack(octetStream& os); + /* Takes a vector of shares, one from each player and * determines the shared value * - i.e. Partially open the shares diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index d5ef56cc..7b43591e 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Zp_Data.h" diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 0ddb14e9..954d0d4f 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #ifndef _Zp_Data #define _Zp_Data diff --git a/Math/bigint.cpp b/Math/bigint.cpp index 238616da..a6b38a96 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "bigint.h" diff --git a/Math/bigint.h b/Math/bigint.h index 99c6f407..364f8b8c 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _bigint #define _bigint diff --git a/Math/field_types.h b/Math/field_types.h index f1f9fb6d..bfeb519c 100644 --- a/Math/field_types.h +++ b/Math/field_types.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * types.h diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 967d0629..1e8a7698 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Math/gf2n.h" @@ -58,12 +58,12 @@ void gf2n_short::init_tables() void gf2n_short::init_field(int nn) { - if (nn == 0) - { - nn = default_length(); - cerr << "Using GF(2^" << nn << ")" << endl; - } - + if (nn == 0) + { + nn = default_length(); + cerr << "Using GF(2^" << nn << ")" << endl; + } + gf2n_short::init_tables(); int i,j=-1; for (i=0; i(aa)&mask; } void assign(const char* buffer) { a = *(word*)buffer; } @@ -93,8 +94,10 @@ class gf2n_short } gf2n_short() { a=0; } - gf2n_short(const gf2n_short& g) { assign(g); } - gf2n_short(int g) { assign(g); } + gf2n_short(word a) { assign(a); } + gf2n_short(long a) { assign(a); } + gf2n_short(int a) { assign(a); } + gf2n_short(const char* a) { assign(a); } ~gf2n_short() { ; } gf2n_short& operator=(const gf2n_short& g) @@ -167,7 +170,7 @@ class gf2n_short void input(istream& s,bool human); friend ostream& operator<<(ostream& s,const gf2n_short& x) - { s << hex << "0x" << x.a << dec; + { s << hex << showbase << x.a << dec; return s; } friend istream& operator>>(istream& s,gf2n_short& x) diff --git a/Math/gf2nlong.cpp b/Math/gf2nlong.cpp index 84972667..f465911c 100644 --- a/Math/gf2nlong.cpp +++ b/Math/gf2nlong.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * gf2n_longlong.cpp diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 4211d8d7..cd935064 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * gf2nlong.h diff --git a/Math/gfp.cpp b/Math/gfp.cpp index da20448d..f0c8baa8 100644 --- a/Math/gfp.cpp +++ b/Math/gfp.cpp @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #include "Math/gfp.h" @@ -71,10 +71,15 @@ void gfp::SHL(const gfp& x,int n) { if (!x.is_zero()) { - bigint bi; - to_bigint(bi,x,false); - mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); - to_gfp(*this, bi); + if (n != 0) + { + bigint bi; + to_bigint(bi,x,false); + mpn_lshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); + to_gfp(*this, bi); + } + else + assign(x); } else { @@ -87,10 +92,15 @@ void gfp::SHR(const gfp& x,int n) { if (!x.is_zero()) { - bigint bi; - to_bigint(bi,x); - mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); - to_gfp(*this, bi); + if (n != 0) + { + bigint bi; + to_bigint(bi,x); + mpn_rshift(bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_d, bi.get_mpz_t()->_mp_size,n); + to_gfp(*this, bi); + } + else + assign(x); } else { diff --git a/Math/gfp.h b/Math/gfp.h index 5ca3afb5..c5e8d945 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -1,5 +1,5 @@ -// (C) 2016 University of Bristol. See License.txt - +// (C) 2017 University of Bristol. See License.txt + #ifndef _gfp #define _gfp @@ -93,7 +93,7 @@ class gfp bool is_zero() const { return isZero(a,ZpD); } - bool is_one() const { return isOne(a,ZpD); } + bool is_one() const { return isOne(a,ZpD); } bool is_bit() const { return is_zero() or is_one(); } bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); } bool operator==(const gfp& y) const { return equal(y); } diff --git a/Math/modp.cpp b/Math/modp.cpp index 033edda3..604ce363 100644 --- a/Math/modp.cpp +++ b/Math/modp.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Zp_Data.h" #include "modp.h" diff --git a/Math/modp.h b/Math/modp.h index b384c3a1..bbab5991 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Modp #define _Modp diff --git a/Math/operators.h b/Math/operators.h index d4d34b4a..b2910cb3 100644 --- a/Math/operators.h +++ b/Math/operators.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * operations.h @@ -17,15 +17,15 @@ T& operator*=(const T& y, const bool& x) { y = x ? y : T(); return y; } template T operator+(const T& x, const U& y) { T res; res.add(x, y); return res; } -template -T operator*(const T& x, const U& y) { T res; res.mul(x, y); return res; } +template +T operator*(const T& x, const T& y) { T res; res.mul(x, y); return res; } template T operator-(const T& x, const U& y) { T res; res.sub(x, y); return res; } template T& operator+=(T& x, const U& y) { x.add(y); return x; } -template -T& operator*=(T& x, const U& y) { x.mul(y); return x; } +template +T& operator*=(T& x, const T& y) { x.mul(y); return x; } template T& operator-=(T& x, const U& y) { x.sub(y); return x; } diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 2325cec5..fd081742 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -1,21 +1,37 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Player.h" #include "Exceptions/Exceptions.h" +#include "Networking/STS.h" #include +#include -// Use printf rather than cout so valgrind can detect thread issues +using namespace std; + +CommsecKeysPackage::CommsecKeysPackage(vector playerpubs, + secret_signing_key mypriv, + public_signing_key mypub) +{ + player_public_keys = playerpubs; + my_secret_key = mypriv; + my_public_key = mypub; +} void Names::init(int player,int pnb,const char* servername) -{ +{ player_no=player; portnum_base=pnb; setup_names(servername); + keys = NULL; setup_server(); } +void Names::init(int player,int pnb,vector Nms) +{ + init(player, pnb, Nms); +} void Names::init(int player,int pnb,vector Nms) { @@ -23,18 +39,10 @@ void Names::init(int player,int pnb,vector Nms) portnum_base=pnb; nplayers=Nms.size(); names.resize(nplayers); - for (int i=0; i Nms) -{ - player_no=player; - portnum_base=pnb; - nplayers=Nms.size(); - names=Nms; + for (int i=0; ikeys = keys; +} + void Names::setup_names(const char *servername) { int socket_num; @@ -79,11 +93,16 @@ void Names::setup_names(const char *servername) // Send my name octet my_name[512]; memset(my_name,0,512*sizeof(octet)); - gethostname((char*)my_name,512); + sockaddr_in address; + socklen_t size = sizeof address; + getsockname(socket_num, (sockaddr*)&address, &size); + char* name = inet_ntoa(address.sin_addr); + // max length of IP address with ending 0 + strncpy((char*)my_name, name, 16); fprintf(stderr, "My Name = %s\n",my_name); send(socket_num,my_name,512); cerr << "My number = " << player_no << endl; - + // Now get the set of names int i; receive(socket_num,nplayers); @@ -102,6 +121,7 @@ void Names::setup_names(const char *servername) void Names::setup_server() { server = new ServerSocket(portnum_base + player_no); + server->init(); } @@ -113,6 +133,7 @@ Names::Names(const Names& other) nplayers = other.nplayers; portnum_base = other.portnum_base; names = other.names; + keys = NULL; server = 0; } @@ -135,7 +156,7 @@ Player::Player(const Names& Nms, int id) : PlayerBase(Nms), send_to_self_socket( Player::~Player() -{ +{ /* Close down the sockets */ for (int i=0; i& names,int portnum_base,int id_base,ServerSocket& server) { - sockets.resize(nplayers); - // Set up the client side - for (int i=player_no; i& o,bool donthash) const { for (int i=0; iplayer_no) { o[player_no].Send(sockets[i]); } - else if (iplayer_no) + else if (i>player_no) { o[i].reset_write_head(); - o[i].Receive(sockets[i]); + o[i].Receive(sockets[i]); } } if (!donthash) @@ -240,7 +269,7 @@ void Player::Check_Broadcast() const Broadcast_Receive(h,true); for (int i=0; i other_player; - setup_sockets(Nms.names[other_player].c_str(), *Nms.server, Nms.portnum_base + other_player, id); + setup_sockets(other_player, Nms, Nms.portnum_base + other_player, id); } TwoPartyPlayer::~TwoPartyPlayer() -{ +{ + for(size_t i=0; i < my_secret_key.size(); i++) { + my_secret_key[i] = 0; + } close_client_socket(socket); } -void TwoPartyPlayer::setup_sockets(const char* hostname, ServerSocket& server, int pn, int id) +static pair sts_initiator(int socket, CommsecKeysPackage *keys, int other_player) { - if (is_server) - { - fprintf(stderr, "Setting up server with id %d\n",id); - socket = server.get_connection_socket(id); + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream socket_stream; + + // Start Station to Station Protocol + STS ke(&keys->player_public_keys[other_player][0], &keys->my_public_key[0], &keys->my_secret_key[0]); + m1 = ke.send_msg1(); + socket_stream.reset_write_head(); + socket_stream.append(m1.bytes, sizeof m1.bytes); + socket_stream.Send(socket); + socket_stream.Receive(socket); + socket_stream.consume(m2.pubkey, sizeof m2.pubkey); + socket_stream.consume(m2.sig, sizeof m2.sig); + m3 = ke.recv_msg2(m2); + socket_stream.reset_write_head(); + socket_stream.append(m3.bytes, sizeof m3.bytes); + socket_stream.Send(socket); + + // Use results of STS to generate send and receive keys. + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + keyinfo sendkeyinfo = make_pair(sendKey,0); + keyinfo recvkeyinfo = make_pair(recvKey,0); + return make_pair(sendkeyinfo,recvkeyinfo); +} + +static pair sts_responder(int socket, CommsecKeysPackage *keys, int other_player) + // secret_signing_key mykey, public_signing_key mypubkey, public_signing_key theirkey) +{ + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + octetStream socket_stream; + + // Start Station to Station Protocol for the responder + STS ke(&keys->player_public_keys[other_player][0], &keys->my_public_key[0], &keys->my_secret_key[0]); + socket_stream.Receive(socket); + socket_stream.consume(m1.bytes, sizeof m1.bytes); + m2 = ke.recv_msg1(m1); + socket_stream.reset_write_head(); + socket_stream.append(m2.pubkey, sizeof m2.pubkey); + socket_stream.append(m2.sig, sizeof m2.sig); + socket_stream.Send(socket); + socket_stream.Receive(socket); + socket_stream.consume(m3.bytes, sizeof m3.bytes); + ke.recv_msg3(m3); + + // Use results of STS to generate send and receive keys. + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + keyinfo sendkeyinfo = make_pair(sendKey,0); + keyinfo recvkeyinfo = make_pair(recvKey,0); + return make_pair(sendkeyinfo,recvkeyinfo); +} + +void TwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, int id) +{ + const char *hostname = nms.names[other_player].c_str(); + ServerSocket *server = nms.server; + if (is_server) { + fprintf(stderr, "Setting up server with id %d\n",id); + socket = server->get_connection_socket(id); + if(NULL != nms.keys) { + pair send_recv_pair = sts_responder(socket, nms.keys, other_player); + player_send_key = send_recv_pair.first; + player_recv_key = send_recv_pair.second; + } } - else - { - fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, pn, id); - set_up_client_socket(socket, hostname, pn); - ::send(socket, (unsigned char*)&id, sizeof(id)); + else { + fprintf(stderr, "Setting up client to %s:%d with id %d\n", hostname, portNum, id); + set_up_client_socket(socket, hostname, portNum); + ::send(socket, (unsigned char*)&id, sizeof(id)); + if(NULL != nms.keys) { + pair send_recv_pair = sts_initiator(socket, nms.keys, other_player); + player_send_key = send_recv_pair.first; + player_recv_key = send_recv_pair.second; + } } } @@ -381,31 +481,37 @@ int TwoPartyPlayer::other_player_num() const return other_player; } -void TwoPartyPlayer::send(octetStream& o) const +void TwoPartyPlayer::send(octetStream& o) { + if(p2pcommsec) { + o.encrypt_sequence(&player_send_key.first[0], player_send_key.second); + player_send_key.second++; + } o.Send(socket); } -void TwoPartyPlayer::receive(octetStream& o) const +void TwoPartyPlayer::receive(octetStream& o) { o.reset_write_head(); o.Receive(socket); + if(p2pcommsec) { + o.decrypt_sequence(&player_recv_key.first[0], player_recv_key.second); + player_recv_key.second++; + } } -void TwoPartyPlayer::send_receive_player(vector& o) const +void TwoPartyPlayer::send_receive_player(vector& o) { { if (is_server) { - o[0].Send(socket); - o[1].reset_write_head(); - o[1].Receive(socket); + send(o[0]); + receive(o[1]); } else { - o[1].reset_write_head(); - o[1].Receive(socket); - o[0].Send(socket); + receive(o[1]); + send(o[0]); } } } diff --git a/Networking/Player.h b/Networking/Player.h index 50006119..7784ce8b 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Player #define _Player @@ -23,6 +23,23 @@ using namespace std; #include "Networking/Receiver.h" #include "Networking/Sender.h" +typedef vector public_signing_key; +typedef vector secret_signing_key; +typedef vector chachakey; +typedef pair< chachakey, uint64_t > keyinfo; + +class CommsecKeysPackage { +public: + vector player_public_keys; + secret_signing_key my_secret_key; + public_signing_key my_public_key; + + CommsecKeysPackage(vector playerpubs, + secret_signing_key mypriv, + public_signing_key mypub); + ~CommsecKeysPackage(); +}; + /* Class to get the names off the server */ class Names { @@ -31,6 +48,8 @@ class Names int portnum_base; int player_no; + CommsecKeysPackage *keys; + void setup_names(const char *servername); void setup_server(); @@ -39,7 +58,6 @@ class Names mutable ServerSocket* server; - // Usual setup names void init(int player,int pnb,const char* servername); Names(int player,int pnb,const char* servername) { init(player,pnb,servername); } @@ -50,11 +68,10 @@ class Names void init(int player,int pnb,vector Nms); Names(int player,int pnb,vector Nms) { init(player,pnb,Nms); } - // Set up names from file -- reads the first nplayers names in the file void init(int player, int nplayers, int pnb, const string& hostsfile); Names(int player, int nplayers, int pnb, const string& hostsfile) { init(player, nplayers, pnb, hostsfile); } - + void set_keys( CommsecKeysPackage *keys ); Names() : nplayers(-1), portnum_base(-1), player_no(-1), server(0) { ; } Names(const Names& other); @@ -81,7 +98,6 @@ public: int my_num() const { return player_no; } }; - class Player : public PlayerBase { protected: @@ -161,25 +177,31 @@ class TwoPartyPlayer : public PlayerBase { private: // setup sockets for comm. with only one other player - void setup_sockets(const char* hostname, ServerSocket& server, int pn, int id); + void setup_sockets(int other_player, const Names &nms, int portNum, int id); int socket; bool is_server; int other_player; + bool p2pcommsec; + + secret_signing_key my_secret_key; + map player_public_keys; + keyinfo player_send_key; + keyinfo player_recv_key; public: TwoPartyPlayer(const Names& Nms, int other_player, int pn_offset=0); ~TwoPartyPlayer(); - void send(octetStream& o) const; - void receive(octetStream& o) const; + void send(octetStream& o); + void receive(octetStream& o); int other_player_num() const; /* Send and receive to/from the other player * - o[0] contains my data, received data put in o[1] */ - void send_receive_player(vector& o) const; + void send_receive_player(vector& o); }; #endif diff --git a/Networking/Receiver.cpp b/Networking/Receiver.cpp index c7527f3c..b4353af2 100644 --- a/Networking/Receiver.cpp +++ b/Networking/Receiver.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Receiver.cpp diff --git a/Networking/Receiver.h b/Networking/Receiver.h index f7d62d07..f81912ba 100644 --- a/Networking/Receiver.h +++ b/Networking/Receiver.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Receiver.h diff --git a/Networking/STS.cpp b/Networking/STS.cpp new file mode 100644 index 00000000..c6310a3b --- /dev/null +++ b/Networking/STS.cpp @@ -0,0 +1,230 @@ +// (C) 2017 University of Bristol. See License.txt + +#include "Networking/STS.h" +#include +#include +#include +#include +#include +#include +#include + +void STS::kdf_block(unsigned char *block) +{ + crypto_hash_sha512_state state; + crypto_hash_sha512_init(&state); + unsigned char ctrbytes[sizeof kdf_counter]; + kdf_counter++; + + // Little endian serialization + for(size_t i=0; i> i*8) & 0xFF); + } + crypto_hash_sha512_update(&state,ctrbytes,sizeof ctrbytes); + crypto_hash_sha512_update(&state,raw_secret,crypto_hash_sha512_BYTES); + crypto_hash_sha512_final(&state, block); +} + +vector STS::unsafe_derive_secret(size_t sz) +{ + // KDF ~ H(cnt || raw_secret) + vector resultSecret(sz + crypto_hash_sha512_BYTES - (sz % crypto_hash_sha512_BYTES)); + size_t total=0; + while(total < sz) { + unsigned char *block = &resultSecret[total]; + kdf_block(block); + total += crypto_hash_sha512_BYTES; + } + return resultSecret; +} + +STS::STS() +{ + phase = UNDEFINED; +} + +void STS::init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]) +{ + phase = UNKNOWN; + memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES); + memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); + kdf_counter = 0; +} + +STS::STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]) +{ + phase = UNKNOWN; + memcpy(their_public_sign_key, theirPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_public_sign_key, myPub, crypto_sign_PUBLICKEYBYTES); + memcpy(my_private_sign_key, myPriv, crypto_sign_SECRETKEYBYTES); + memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); + kdf_counter = 0; +} + +STS::~STS() +{ + memset(their_public_sign_key, 0, crypto_sign_PUBLICKEYBYTES); + memset(my_private_sign_key, 0, crypto_sign_SECRETKEYBYTES); + memset(ephemeral_private_key, 0, crypto_box_SECRETKEYBYTES); + memset(ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(their_ephemeral_public_key, 0, crypto_box_PUBLICKEYBYTES); + memset(raw_secret, 0, crypto_hash_sha512_BYTES); + kdf_counter = 0; + phase = UNKNOWN; +} + +sts_msg1_t STS::send_msg1() +{ + sts_msg1_t m; + if(UNKNOWN != phase) { + throw "STS BAD PHASE"; + } + + crypto_box_keypair(ephemeral_public_key, ephemeral_private_key); + memcpy(m.bytes,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + phase = SENT1; + return m; +} + +// If the incoming signature is valid, compute: +// shared secret = H(DH(pubB,privA) || pubA || pubB) +// msg = Sign_{privED-A} (pubA || pubB ) +// +sts_msg3_t STS::recv_msg2(sts_msg2_t msg2) +{ + unsigned char *theirPublicKey = msg2.pubkey; + unsigned char *theirSig = msg2.sig; + unsigned char theirSigDec[crypto_sign_BYTES]; + unsigned char scalar_result[crypto_scalarmult_SCALARBYTES]; + const unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; + int ret; + crypto_hash_sha512_state state; + sts_msg3_t msg; + + if(SENT1 != phase) { + throw "STS BAD PHASE"; + } + ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey); + if(0 != ret) { + throw "crypto_scalarmult failed"; + } + + crypto_hash_sha512_init(&state); + crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES); + crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_final(&state,raw_secret); + + vector keKey = unsafe_derive_secret(crypto_stream_KEYBYTES); + vector expectedMessage; + expectedMessage.insert(expectedMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); + expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + + crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey[0]); + + int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key); + + if(badSig) { + throw "Bad signature received in message 2."; + } else { + unsigned char *mySigEnc = msg.bytes; + unsigned char mySig[crypto_sign_BYTES]; + vector signMessage; + signMessage.insert(signMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + signMessage.insert(signMessage.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); + if(0 != crypto_sign_detached(mySig, NULL, &signMessage[0], signMessage.size(), my_private_sign_key)) { + throw "Signing failed."; + } + vector keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES); + crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey2[0]); + + phase = FINISHED; + return msg; + } +} + +sts_msg2_t STS::recv_msg1(sts_msg1_t msg1) +{ + unsigned char *theirPublicKey = msg1.bytes; + unsigned char scalar_result[crypto_scalarmult_SCALARBYTES]; + crypto_hash_sha512_state state; + sts_msg2_t m; + int ret; + + if(UNKNOWN != phase) { + throw "recv_msg1 called on non-unknown phase"; + } + + memcpy(their_ephemeral_public_key, theirPublicKey, crypto_box_PUBLICKEYBYTES); + + crypto_box_keypair(ephemeral_public_key, ephemeral_private_key); + memcpy(m.pubkey,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + ret = crypto_scalarmult(scalar_result, ephemeral_private_key, theirPublicKey); + if(0 != ret) { + throw "crypto_scalarmult failed when processing message 1"; + } + + crypto_hash_sha512_init(&state); + crypto_hash_sha512_update(&state,scalar_result,crypto_scalarmult_SCALARBYTES); + crypto_hash_sha512_update(&state,theirPublicKey,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_update(&state,ephemeral_public_key,crypto_box_PUBLICKEYBYTES); + crypto_hash_sha512_final(&state,raw_secret); + + vector livenessProof; + livenessProof.insert(livenessProof.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + livenessProof.insert(livenessProof.end(), theirPublicKey , theirPublicKey + crypto_box_PUBLICKEYBYTES); + unsigned char mySig[crypto_sign_BYTES]; + unsigned char *mySigEnc = m.sig; + vector keKey = unsafe_derive_secret(crypto_stream_KEYBYTES); + + unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; + if(0 != crypto_sign_detached(mySig, NULL, &livenessProof[0], livenessProof.size(), my_private_sign_key)) { + throw "Signing failed."; + } + crypto_stream_xor(mySigEnc, mySig, crypto_sign_BYTES, zeroNonce, &keKey[0]); + + phase = SENT2; + return m; +} + +void STS::recv_msg3(sts_msg3_t msg3) +{ + unsigned char *theirSig=msg3.bytes; + unsigned char theirSigDec[crypto_sign_BYTES]; + vector expectedMessage; + if(SENT2 != phase) { + throw "recv_msg3 called out of order"; + } + + expectedMessage.insert(expectedMessage.end(), their_ephemeral_public_key , their_ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + expectedMessage.insert(expectedMessage.end(), ephemeral_public_key, ephemeral_public_key + crypto_box_PUBLICKEYBYTES); + unsigned char zeroNonce[crypto_stream_NONCEBYTES] = {0}; + vector keKey2 = unsafe_derive_secret(crypto_stream_KEYBYTES); + + crypto_stream_xor(theirSigDec, theirSig, crypto_sign_BYTES, zeroNonce, &keKey2[0]); + int badSig = crypto_sign_verify_detached(theirSigDec, &expectedMessage[0], expectedMessage.size(), their_public_sign_key); + + if(badSig) { + throw "Bad signature received in message 3."; + } else { + phase = FINISHED; + } +} + +vector STS::derive_secret(size_t sz) +{ + if(phase != FINISHED) { + throw "Can not derive secrets till the key exchange has completed."; + } + return unsafe_derive_secret(sz); +} diff --git a/Networking/STS.h b/Networking/STS.h new file mode 100644 index 00000000..d588d6c9 --- /dev/null +++ b/Networking/STS.h @@ -0,0 +1,72 @@ +// (C) 2017 University of Bristol. See License.txt + +#ifndef _NETWORK_STS +#define _NETWORK_STS + +/* The Station to Station protocol + */ + +#include +#include +#include +#include + +using namespace std; + +typedef enum + { UNKNOWN // Have not started the interaction or have cleared the memory + , SENT1 // Sent initial message + , SENT2 // Received 1, sent 2 + , FINISHED // Done (received msg 2 & sent 3 or received msg 3) + , UNDEFINED // For arrays/vectors/etc of STS classes that are initialized later. +} phase_t; + +struct msg1_st { + unsigned char bytes[crypto_box_PUBLICKEYBYTES]; +}; +typedef struct msg1_st sts_msg1_t; +struct msg2_st { + unsigned char pubkey[crypto_box_PUBLICKEYBYTES]; + unsigned char sig[crypto_sign_BYTES]; +}; +typedef struct msg2_st sts_msg2_t; +struct msg3_st { + unsigned char bytes[crypto_sign_BYTES]; +}; +typedef struct msg3_st sts_msg3_t; + +class STS +{ + phase_t phase; + unsigned char their_public_sign_key[crypto_sign_PUBLICKEYBYTES]; + unsigned char my_public_sign_key[crypto_sign_PUBLICKEYBYTES]; + unsigned char my_private_sign_key[crypto_sign_SECRETKEYBYTES]; + unsigned char ephemeral_private_key[crypto_box_SECRETKEYBYTES]; + unsigned char ephemeral_public_key[crypto_box_PUBLICKEYBYTES]; + unsigned char their_ephemeral_public_key[crypto_box_PUBLICKEYBYTES]; + unsigned char raw_secret[crypto_hash_sha512_BYTES]; + uint64_t kdf_counter; + public: + STS(); + STS( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]); + ~STS(); + + void init( const unsigned char theirPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPub[crypto_sign_PUBLICKEYBYTES] + , const unsigned char myPriv[crypto_sign_SECRETKEYBYTES]); + + sts_msg1_t send_msg1(); + sts_msg3_t recv_msg2(sts_msg2_t msg2); + + sts_msg2_t recv_msg1(sts_msg1_t msg1); + void recv_msg3(sts_msg3_t msg3); + + vector derive_secret(size_t); + private: + vector unsafe_derive_secret(size_t); + void kdf_block(unsigned char *block); +}; + +#endif /* _NETWORK_STS */ diff --git a/Networking/Sender.cpp b/Networking/Sender.cpp index 89dc0eed..afeb2c49 100644 --- a/Networking/Sender.cpp +++ b/Networking/Sender.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Sender.cpp diff --git a/Networking/Sender.h b/Networking/Sender.h index ff95c8bb..f07fcec8 100644 --- a/Networking/Sender.h +++ b/Networking/Sender.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Sender.h diff --git a/Networking/ServerSocket.cpp b/Networking/ServerSocket.cpp index 653068a1..d890ae8f 100644 --- a/Networking/ServerSocket.cpp +++ b/Networking/ServerSocket.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * ServerSocket.cpp @@ -57,7 +57,7 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum) sleep(1); } else - { cerr << "Bound on port " << Portnum << endl; } + { cerr << "ServerSocket is bound on port " << Portnum << endl; } } if (fl<0) { error("set_up_socket:bind"); } @@ -65,6 +65,11 @@ ServerSocket::ServerSocket(int Portnum) : portnum(Portnum) fl=listen(main_socket, 1000); if (fl<0) { error("set_up_socket:listen"); } + // Note: must not call virtual init() method in constructor: http://www.aristeia.com/EC3E/3E_item9.pdf +} + +void ServerSocket::init() +{ pthread_create(&thread, 0, accept_thread, this); } @@ -95,6 +100,15 @@ void ServerSocket::accept_clients() } } +int ServerSocket::get_connection_count() +{ + data_signal.lock(); + int connection_count = clients.size(); + data_signal.unlock(); + return connection_count; +} + + int ServerSocket::get_connection_socket(int id) { data_signal.lock(); @@ -108,8 +122,60 @@ int ServerSocket::get_connection_socket(int id) while (clients.find(id) == clients.end()) data_signal.wait(); - int client = clients[id]; + int client_socket = clients[id]; used.insert(id); data_signal.unlock(); - return client; + return client_socket; +} + +void* anonymous_accept_thread(void* server_socket) +{ + ((AnonymousServerSocket*)server_socket)->accept_clients(); + return 0; +} + +int AnonymousServerSocket::global_client_socket_count = 0; + +void AnonymousServerSocket::init() +{ + pthread_create(&thread, 0, anonymous_accept_thread, this); +} + +int AnonymousServerSocket::get_connection_count() +{ + return num_accepted_clients; +} + +void AnonymousServerSocket::accept_clients() +{ + while (true) + { + struct sockaddr dest; + memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */ + int socksize = sizeof(dest); + int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize); + if (consocket<0) { error("set_up_socket:accept"); } + + data_signal.lock(); + client_connection_queue.push(consocket); + num_accepted_clients++; + data_signal.broadcast(); + data_signal.unlock(); + } +} + +int AnonymousServerSocket::get_connection_socket(int& client_id) +{ + data_signal.lock(); + + //while (clients.find(next_client_id) == clients.end()) + while (client_connection_queue.empty()) + data_signal.wait(); + + client_id = global_client_socket_count; + global_client_socket_count++; + int client_socket = client_connection_queue.front(); + client_connection_queue.pop(); + data_signal.unlock(); + return client_socket; } diff --git a/Networking/ServerSocket.h b/Networking/ServerSocket.h index 08a9cbc4..27c388ee 100644 --- a/Networking/ServerSocket.h +++ b/Networking/ServerSocket.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * ServerSocket.h @@ -10,6 +10,7 @@ #include #include + #include using namespace std; #include @@ -19,6 +20,7 @@ using namespace std; class ServerSocket { +protected: int main_socket, portnum; map clients; set used; @@ -28,17 +30,51 @@ class ServerSocket // disable copying ServerSocket(const ServerSocket& other); + // receive id from client + int assign_client_id(int consocket); + public: ServerSocket(int Portnum); - ~ServerSocket(); + virtual ~ServerSocket(); - void accept_clients(); + virtual void init(); + + virtual void accept_clients(); // This depends on clients sending their id as int. // Has to be thread-safe. int get_connection_socket(int number); + // How many client connections have been made. + virtual int get_connection_count(); + void close_socket(); }; +/* + * ServerSocket where clients do not send any identifiers upon connecting. + */ +class AnonymousServerSocket : public ServerSocket +{ +private: + // Global no. of client sockets that have been returned - used to create identifiers + static int global_client_socket_count; + // No. of accepted connections in this instance + int num_accepted_clients; + queue client_connection_queue; + +public: + AnonymousServerSocket(int Portnum) : + ServerSocket(Portnum), num_accepted_clients(0) { }; + // override so clients do not send id + void accept_clients(); + void init(); + + virtual int get_connection_count(); + + // Get socket for the last client who connected + // Writes a unique client identifier (i.e. a counter) to client_id + int get_connection_socket(int& client_id); +}; + #endif /* NETWORKING_SERVERSOCKET_H_ */ diff --git a/Networking/data.h b/Networking/data.h index d131a6a9..564cd61b 100644 --- a/Networking/data.h +++ b/Networking/data.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Data #define _Data diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index f225a13d..2553fd58 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "sockets.h" @@ -28,8 +28,6 @@ void error(const char *str1,const char *str2) throw bad_value(); } - - void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int Portnum) { @@ -57,7 +55,7 @@ void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int memset(my_name,0,512*sizeof(octet)); gethostname((char*)my_name,512); - /* bind serv information to mysocket + /* bind serv information to mysocket * - Just assume it will eventually wake up */ fl=1; @@ -82,21 +80,18 @@ void set_up_server_socket(sockaddr_in& dest,int& consocket,int& main_socket,int } - void close_server_socket(int consocket,int main_socket) { if (close(consocket)) { error("close(socket)"); } if (close(main_socket)) { error("close(main_socket"); }; } - - void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) { mysocket = socket(AF_INET, SOCK_STREAM, 0); if (mysocket<0) { error("set_up_socket:socket"); } - - /* disable Nagle's algorithm */ + + /* disable Nagle's algorithm */ int one=1; int fl= setsockopt(mysocket, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)); if (fl<0) { error("set_up_socket:setsockopt"); } @@ -106,17 +101,8 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) struct sockaddr_in dest; dest.sin_family = AF_INET; - dest.sin_port = htons(Portnum); // set destination port number + dest.sin_port = htons(Portnum); // set destination port number - /* - struct hostent *server; - server=gethostbyname(hostname); - if (server== NULL) - { error("set_up_socket:gethostbyname"); } - bcopy((char *)server->h_addr, - (char *)&dest.sin_addr.s_addr, - server->h_length); // set destination IP number - */ struct addrinfo hints, *ai=NULL,*rp; memset (&hints, 0, sizeof(hints)); hints.ai_family = AF_INET; @@ -140,13 +126,13 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) } } if (erp!=0) - { error("set_up_socket:getaddrinfo"); } + { error("set_up_socket:getaddrinfo"); } for (rp=ai; rp!=NULL; rp=rp->ai_next) { const struct in_addr *addr4 = &((const struct sockaddr_in*)ai->ai_addr)->sin_addr; - + if (ai->ai_family == AF_INET) - { memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr)); + { memcpy((char *)&dest.sin_addr.s_addr,addr4,sizeof(in_addr)); continue; } } @@ -162,8 +148,6 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) if (fl<0) { error("set_up_socket:connect:",hostname); } } - - void close_client_socket(int socket) { if (close(socket)) @@ -174,8 +158,6 @@ void close_client_socket(int socket) } } - - unsigned long long sent_amount = 0, sent_counter = 0; @@ -195,7 +177,7 @@ void receive(int socket,int& a) while (i==0) { i=recv(socket,msg,1,0); if (i<0) { error("Receiving error - 2"); } - } + } a=msg[0]; } diff --git a/Networking/sockets.h b/Networking/sockets.h index a7c38fb0..a0f1d945 100644 --- a/Networking/sockets.h +++ b/Networking/sockets.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _sockets #define _sockets diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 2788ed72..54428693 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OT/BaseOT.h" #include "Tools/random.h" @@ -34,7 +34,7 @@ OT_ROLE INV_ROLE(OT_ROLE role) return BOTH; } -void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE role) +void send_if_ot_sender(TwoPartyPlayer* P, vector& os, OT_ROLE role) { if (role == SENDER) { @@ -51,7 +51,7 @@ void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE } } -void send_if_ot_receiver(const TwoPartyPlayer* P, vector& os, OT_ROLE role) +void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE role) { if (role == RECEIVER) { diff --git a/OT/BaseOT.h b/OT/BaseOT.h index 188801c3..e7c2d13e 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _BASE_OT #define _BASE_OT @@ -26,8 +26,8 @@ enum OT_ROLE OT_ROLE INV_ROLE(OT_ROLE role); const char* role_to_str(OT_ROLE role); -void send_if_ot_sender(const TwoPartyPlayer* P, vector& os, OT_ROLE role); -void send_if_ot_receiver(const TwoPartyPlayer* P, vector& os, OT_ROLE role); +void send_if_ot_sender(TwoPartyPlayer* P, vector& os, OT_ROLE role); +void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE role); class BaseOT { diff --git a/OT/BitMatrix.cpp b/OT/BitMatrix.cpp index 012bcafc..797f9c31 100644 --- a/OT/BitMatrix.cpp +++ b/OT/BitMatrix.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * BitMatrix.cpp diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 3dae6070..409f1959 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * BitMatrix.h diff --git a/OT/BitVector.cpp b/OT/BitVector.cpp index 285dfd6c..64fe3eee 100644 --- a/OT/BitVector.cpp +++ b/OT/BitVector.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OT/BitVector.h" diff --git a/OT/BitVector.h b/OT/BitVector.h index 54eac5a4..671512b1 100644 --- a/OT/BitVector.h +++ b/OT/BitVector.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _BITVECTOR #define _BITVECTOR diff --git a/OT/NPartyTripleGenerator.cpp b/OT/NPartyTripleGenerator.cpp index 14a14f7b..23c795ed 100644 --- a/OT/NPartyTripleGenerator.cpp +++ b/OT/NPartyTripleGenerator.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "NPartyTripleGenerator.h" diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 91a90572..ae3881a0 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef OT_NPARTYTRIPLEGENERATOR_H_ #define OT_NPARTYTRIPLEGENERATOR_H_ diff --git a/OT/OTExtension.cpp b/OT/OTExtension.cpp index 8efd07af..d264ed2c 100644 --- a/OT/OTExtension.cpp +++ b/OT/OTExtension.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OTExtension.h" diff --git a/OT/OTExtension.h b/OT/OTExtension.h index 91724a6a..e07367c1 100644 --- a/OT/OTExtension.h +++ b/OT/OTExtension.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _OTEXTENSION #define _OTEXTENSION diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index 7894ae94..d96d6c24 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTExtensionWithMatrix.cpp diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index d9cf59bd..6bb52b57 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTExtensionWithMatrix.h diff --git a/OT/OTMachine.cpp b/OT/OTMachine.cpp index db5e2b6b..6d11af88 100644 --- a/OT/OTMachine.cpp +++ b/OT/OTMachine.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Networking/Player.h" #include "OT/OTExtension.h" diff --git a/OT/OTMachine.h b/OT/OTMachine.h index 9706e68f..9f80f3db 100644 --- a/OT/OTMachine.h +++ b/OT/OTMachine.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTMachine.h diff --git a/OT/OTMultiplier.cpp b/OT/OTMultiplier.cpp index 658ae650..ead5d3e0 100644 --- a/OT/OTMultiplier.cpp +++ b/OT/OTMultiplier.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTMultiplier.cpp diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 15d9530e..5149a24b 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OTMultiplier.h diff --git a/OT/OTTripleSetup.cpp b/OT/OTTripleSetup.cpp index 049eda82..4f10f08d 100644 --- a/OT/OTTripleSetup.cpp +++ b/OT/OTTripleSetup.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "OTTripleSetup.h" diff --git a/OT/OTTripleSetup.h b/OT/OTTripleSetup.h index 354cd21e..31809e58 100644 --- a/OT/OTTripleSetup.h +++ b/OT/OTTripleSetup.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef OT_TRIPLESETUP_H_ #define OT_TRIPLESETUP_H_ diff --git a/OT/OText_main.cpp b/OT/OText_main.cpp index fc2edaaf..39867547 100644 --- a/OT/OText_main.cpp +++ b/OT/OText_main.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OText_main.cpp diff --git a/OT/OutputCheck.h b/OT/OutputCheck.h index a598b893..01681e79 100644 --- a/OT/OutputCheck.h +++ b/OT/OutputCheck.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * check.h diff --git a/OT/Tools.cpp b/OT/Tools.cpp index 32e9208f..a2cc03d6 100644 --- a/OT/Tools.cpp +++ b/OT/Tools.cpp @@ -1,9 +1,9 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools.h" #include "Math/gf2nlong.h" -void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len) +void random_seed_commit(octet* seed, TwoPartyPlayer& player, int len) { PRNG G; G.ReSeed(); diff --git a/OT/Tools.h b/OT/Tools.h index 1038d643..53ec588b 100644 --- a/OT/Tools.h +++ b/OT/Tools.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _OTTOOLS #define _OTTOOLS @@ -12,7 +12,7 @@ /* * Generate a secure, random seed between 2 parties via commitment */ -void random_seed_commit(octet* seed, const TwoPartyPlayer& player, int len); +void random_seed_commit(octet* seed, TwoPartyPlayer& player, int len); /* * GF(2^128) multiplication using Intel instructions diff --git a/OT/TripleMachine.cpp b/OT/TripleMachine.cpp index b1f72c9d..5c1e6dbd 100644 --- a/OT/TripleMachine.cpp +++ b/OT/TripleMachine.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * TripleMachine.cpp diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index d377bca2..d318b814 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * TripleMachine.h diff --git a/Player-Online.cpp b/Player-Online.cpp index 246e845c..b0397d48 100644 --- a/Player-Online.cpp +++ b/Player-Online.cpp @@ -1,11 +1,14 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Machine.h" +#include "Math/Setup.h" #include "Tools/ezOptionParser.h" +#include "Tools/Config.h" #include #include #include +#include using namespace std; int main(int argc, const char** argv) @@ -108,6 +111,15 @@ int main(int argc, const char** argv) "-b", // Flag token. "--max-broadcast" // Flag token. ); + opt.add( + "0", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Use communications security between SPDZ players", // Help description. + "-c", // Flag token. + "--player-to-player-commsec" // Flag token. + ); opt.parse(argc, argv); @@ -156,6 +168,7 @@ int main(int argc, const char** argv) string memtype, hostname; int lg2, lgp, pnbase, opening_sum, max_broadcast; + int p2pcommsec; opt.get("--portnumbase")->getInt(pnbase); opt.get("--lgp")->getInt(lgp); @@ -164,11 +177,25 @@ int main(int argc, const char** argv) opt.get("--hostname")->getString(hostname); opt.get("--opening-sum")->getInt(opening_sum); opt.get("--max-broadcast")->getInt(max_broadcast); + opt.get("--player-to-player-commsec")->getInt(p2pcommsec); + int mynum; + sscanf((*allArgs[1]).c_str(), "%d", &mynum); + + CommsecKeysPackage *keys = NULL; + if(p2pcommsec) { + vector pubkeys; + secret_signing_key mykey; + public_signing_key mypublickey; + string prep_data_prefix = get_prep_dir(2, lgp, lg2); + Config::read_player_config(prep_data_prefix,mynum,pubkeys,mykey,mypublickey); + keys = new CommsecKeysPackage(pubkeys,mykey,mypublickey); + } + Machine(playerno, pnbase, hostname, progname, memtype, lgp, lg2, opt.get("--direct")->isSet, opening_sum, opt.get("--parallel")->isSet, - opt.get("--threads")->isSet, max_broadcast).run(); + opt.get("--threads")->isSet, max_broadcast, keys).run(); cerr << "Command line:"; for (int i = 0; i < argc; i++) diff --git a/Processor/Binary_File_IO.cpp b/Processor/Binary_File_IO.cpp new file mode 100644 index 00000000..5e4e4b8b --- /dev/null +++ b/Processor/Binary_File_IO.cpp @@ -0,0 +1,72 @@ +// (C) 2017 University of Bristol. See License.txt + +#include "Processor/Binary_File_IO.h" +#include "Math/gfp.h" + +/* + * Provides generalised file read and write methods for arrays of shares. + * Stateless and not optimised for multiple reads from file. + * Intended for application specific file IO. + */ + +template +void Binary_File_IO::write_to_file(const string filename, const vector< Share >& buffer) +{ + ofstream outf; + + outf.open(filename, ios::out | ios::binary | ios::app); + if (outf.fail()) { throw file_error(filename); } + + for (unsigned int i = 0; i < buffer.size(); i++) + { + buffer[i].output(outf, false); + } + + outf.close(); +} + +template +void Binary_File_IO::read_from_file(const string filename, vector< Share >& buffer, const int start_posn, int &end_posn) +{ + ifstream inf; + inf.open(filename, ios::in | ios::binary); + if (inf.fail()) { throw file_missing(filename, "Binary_File_IO.read_from_file expects this file to exist."); } + + int size_in_bytes = Share::size() * buffer.size(); + int n_read = 0; + char * read_buffer = new char[size_in_bytes]; + inf.seekg(start_posn); + do + { + inf.read(read_buffer + n_read, size_in_bytes - n_read); + n_read += inf.gcount(); + + if (inf.eof()) + { + stringstream ss; + ss << "Got to EOF when reading from disk (expecting " << size_in_bytes << " bytes)."; + throw file_error(ss.str()); + } + if (inf.fail()) + { + stringstream ss; + ss << "IO problem when reading from disk"; + throw file_error(ss.str()); + } + } + while (n_read < size_in_bytes); + + end_posn = inf.tellg(); + + //Check if at end of file by getting 1 more char. + inf.get(); + if (inf.eof()) + end_posn = -1; + inf.close(); + + for (unsigned int i = 0; i < buffer.size(); i++) + buffer[i].assign(&read_buffer[i*Share::size()]); +} + +template void Binary_File_IO::write_to_file(const string filename, const vector< Share >& buffer); +template void Binary_File_IO::read_from_file(const string filename, vector< Share >& buffer, const int start_posn, int &end_posn); diff --git a/Processor/Binary_File_IO.h b/Processor/Binary_File_IO.h new file mode 100644 index 00000000..b0c6ce0b --- /dev/null +++ b/Processor/Binary_File_IO.h @@ -0,0 +1,43 @@ +// (C) 2017 University of Bristol. See License.txt + +#ifndef _FILE_IO_HEADER +#define _FILE_IO_HEADER + +#include "Exceptions/Exceptions.h" +#include "Math/Share.h" + +#include +#include +#include +#include + +using namespace std; + +/* + * Provides generalised file read and write methods for arrays of numeric data types. + * Stateless and not optimised for multiple reads from file. + * Intended for MPC application specific file IO. + */ + +class Binary_File_IO +{ + public: + + /* + * Append the buffer values as binary to the filename. + * Throws file_error. + */ + template + void write_to_file(const string filename, const vector< Share >& buffer); + + /* + * Read from posn in the filename the binary values until the buffer is full. + * Assumes file holds binary that maps into the type passed in. + * Returns the current posn in the file or -1 if at eof. + * Throws file_error. + */ + template + void read_from_file(const string filename, vector< Share >& buffer, const int start_posn, int &end_posn); +}; + +#endif diff --git a/Processor/Buffer.cpp b/Processor/Buffer.cpp index a5ca4452..78878205 100644 --- a/Processor/Buffer.cpp +++ b/Processor/Buffer.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Buffer.cpp diff --git a/Processor/Buffer.h b/Processor/Buffer.h index a9936d1e..f8aae575 100644 --- a/Processor/Buffer.h +++ b/Processor/Buffer.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Buffer.h diff --git a/Processor/Data_Files.cpp b/Processor/Data_Files.cpp index 96d64d0a..eb0eed77 100644 --- a/Processor/Data_Files.cpp +++ b/Processor/Data_Files.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Data_Files.h" @@ -56,7 +56,7 @@ void DataPositions::print_cost() const file >> cost_per_item; if (cost_per_item < 0) break; - int items_used = files[i][j]; + long long items_used = files[i][j]; double cost = items_used * cost_per_item; total_cost += cost; cerr.fill(' '); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 2b558112..c9167df1 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Data_Files #define _Data_Files @@ -73,10 +73,10 @@ class Data_Files DataPositions usage; - const string prep_data_dir; - public: + const string prep_data_dir; + static const char* dtype_names[N_DTYPE]; static const char* field_names[N_DATA_FIELD_TYPE]; static const char* long_field_names[N_DATA_FIELD_TYPE]; diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp new file mode 100644 index 00000000..6a5df049 --- /dev/null +++ b/Processor/ExternalClients.cpp @@ -0,0 +1,179 @@ +// (C) 2017 University of Bristol. See License.txt + +#include "Processor/ExternalClients.h" +#include +#include +#include + +ExternalClients::ExternalClients(int party_num, const string& prep_data_dir): + party_num(party_num), prep_data_dir(prep_data_dir), server_connection_count(-1) +{ +} + +ExternalClients::~ExternalClients() +{ + // close client sockets + for (map::iterator it = external_client_sockets.begin(); + it != external_client_sockets.end(); it++) + { + if (close(it->second)) + { + error("failed to close external client connection socket)"); + } + } + for (map::iterator it = client_connection_servers.begin(); + it != client_connection_servers.end(); it++) + { + delete it->second; + } + for (map::iterator it = symmetric_client_keys.begin(); + it != symmetric_client_keys.end(); it++) + { + delete[] it->second; + } + for (map,uint64_t> >::iterator it_cs = symmetric_client_commsec_send_keys.begin(); + it_cs != symmetric_client_commsec_send_keys.end(); it_cs++) + { + memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size()); + } + for (map,uint64_t> >::iterator it_cs = symmetric_client_commsec_recv_keys.begin(); + it_cs != symmetric_client_commsec_recv_keys.end(); it_cs++) + { + memset(&(it_cs->second.first[0]), 0, it_cs->second.first.size()); + } +} + +void ExternalClients::start_listening(int portnum_base) +{ + client_connection_servers[portnum_base] = new AnonymousServerSocket(portnum_base + get_party_num()); + client_connection_servers[portnum_base]->init(); + cerr << "Start listening on thread " << this_thread::get_id() << endl; + cerr << "Party " << get_party_num() << " is listening on port " << (portnum_base + get_party_num()) + << " for external client connections." << endl; +} + +int ExternalClients::get_client_connection(int portnum_base) +{ + map::iterator it = client_connection_servers.find(portnum_base); + if (it == client_connection_servers.end()) + { + cerr << "Thread " << this_thread::get_id() << " didn't find server." << endl; + return -1; + } + cerr << "Thread " << this_thread::get_id() << " found server." << endl; + int client_id, socket; + socket = client_connection_servers[portnum_base]->get_connection_socket(client_id); + external_client_sockets[client_id] = socket; + cerr << "Party " << get_party_num() << " received external client connection from client id: " << dec << client_id << endl; + return client_id; +} + +int ExternalClients::connect_to_server(int portnum_base, int ipv4_address) +{ + struct in_addr addr = { (unsigned int)ipv4_address }; + int csocket; + const char* address_str = inet_ntoa(addr); + cerr << "Party " << get_party_num() << " connecting to server at " << address_str << " on port " << portnum_base + get_party_num() << endl; + set_up_client_socket(csocket, address_str, portnum_base + get_party_num()); + cerr << "Party " << get_party_num() << " connected to server at " << address_str << " on port " << portnum_base + get_party_num() << endl; + int server_id = server_connection_count; + // server identifiers are -1, -2, ... to avoid conflict with client identifiers + server_connection_count--; + external_client_sockets[server_id] = csocket; + return server_id; +} + +void ExternalClients::curve25519_ints_to_bytes(unsigned char *bytes, const vector& key_ints) +{ + for(unsigned int j = 0; j < key_ints.size(); j++) { + for(unsigned int k = 0; k < 4; k++) { + bytes[j*sizeof(int) + k] = (key_ints[j] >> ((3-k)*8)) & 0xFF; + } + } +} + +// Generate sesssion key for a newly connected client, store in symmetric_client_keys +// public_key is expected to be size 8 and contain integer values of public key bytes. +// Assumes load_server_keys has been run. +void ExternalClients::generate_session_key_for_client(int client_id, const vector& public_key) +{ + assert(public_key.size() * sizeof(int) == crypto_box_PUBLICKEYBYTES); + + load_server_keys_once(); + + unsigned char client_publickey[crypto_box_PUBLICKEYBYTES]; + + curve25519_ints_to_bytes(client_publickey, public_key); + + cerr << "Recevied client public key for client " << dec << client_id << " :"; + for (unsigned int j = 0; j < crypto_box_PUBLICKEYBYTES; j++) + cerr << hex << (int) client_publickey[j] << " "; + cerr << dec << endl; + + unsigned char scalarmult_q_by_server[crypto_scalarmult_BYTES]; + crypto_generichash_state h; + + symmetric_client_keys[client_id] = new octet[crypto_generichash_BYTES]; + + // Derive a shared key from this server's secret key and the client's public key + // shared key = h(q || server_secretkey || client_publickey) + if (crypto_scalarmult(scalarmult_q_by_server, server_secretkey, client_publickey) != 0) { + cerr << "Scalar mult failed\n"; + exit(1); + } + crypto_generichash_init(&h, NULL, 0U, crypto_generichash_BYTES); + crypto_generichash_update(&h, scalarmult_q_by_server, sizeof scalarmult_q_by_server); + crypto_generichash_update(&h, client_publickey, sizeof client_publickey); + crypto_generichash_update(&h, server_publickey, sizeof server_publickey); + crypto_generichash_final(&h, symmetric_client_keys[client_id], crypto_generichash_BYTES); +} + +// Read pre-computed server keys from client-setup for this SPDZ engine. +// Only needs to be done once per run, but is only necessary if an external connection +// is being requested. +void ExternalClients::load_server_keys_once() +{ + if (server_keys_loaded) { + return; + } + + ifstream keyfile; + stringstream filename; + filename << prep_data_dir << "Player-SPDZ-Keys-P" << get_party_num(); + keyfile.open(filename.str().c_str()); + if (keyfile.fail()) + throw file_error(filename.str().c_str()); + + keyfile.read((char*)server_publickey, sizeof server_publickey); + if (keyfile.eof()) + throw end_of_file(filename.str(), "server public key" ); + keyfile.read((char*)server_secretkey, sizeof server_secretkey); + if (keyfile.eof()) + throw end_of_file(filename.str(), "server private key" ); + + bool loaded_ed25519 = true; + + keyfile.read((char*)server_publickey_ed25519, sizeof server_publickey_ed25519); + if (keyfile.eof() || keyfile.bad()) + loaded_ed25519 = false; + keyfile.read((char*)server_secretkey_ed25519, sizeof server_secretkey_ed25519); + if (keyfile.eof() || keyfile.bad()) + loaded_ed25519 = false; + + keyfile.close(); + + ed25519_keys_loaded = loaded_ed25519; + server_keys_loaded = true; +} + +void ExternalClients::require_ed25519_keys() +{ + if (!ed25519_keys_loaded) + throw "Ed25519 keys required but not found in player key files"; +} + +int ExternalClients::get_party_num() +{ + return party_num; +} + diff --git a/Processor/ExternalClients.h b/Processor/ExternalClients.h new file mode 100644 index 00000000..12810e99 --- /dev/null +++ b/Processor/ExternalClients.h @@ -0,0 +1,65 @@ +// (C) 2017 University of Bristol. See License.txt + +#ifndef _ExternalClients +#define _ExternalClients + +#include "Networking/ServerSocket.h" +#include "Networking/sockets.h" +#include "Exceptions/Exceptions.h" +#include +#include +#include +#include +#include +#include + +/* + * Manage the reading and writing of data from/to external clients via Sockets. + * Generate the session keys for encryption/decryption of secret communication with external clients. + */ + +class ExternalClients +{ + map client_connection_servers; + + int party_num; + const string prep_data_dir; + int server_connection_count; + unsigned char server_publickey[crypto_box_PUBLICKEYBYTES]; + unsigned char server_secretkey[crypto_box_SECRETKEYBYTES]; + bool server_keys_loaded = false; + bool ed25519_keys_loaded = false; + + public: + + unsigned char server_publickey_ed25519[crypto_sign_ed25519_PUBLICKEYBYTES]; + unsigned char server_secretkey_ed25519[crypto_sign_ed25519_SECRETKEYBYTES]; + + // Maps holding per client values (indexed by unique 32-bit id) + std::map external_client_sockets; + std::map symmetric_client_keys; + std::map,uint64_t>> symmetric_client_commsec_send_keys; + std::map,uint64_t>> symmetric_client_commsec_recv_keys; + + ExternalClients(int party_num, const string& prep_data_dir); + ~ExternalClients(); + + void start_listening(int portnum_base); + + int get_client_connection(int portnum_base); + + int connect_to_server(int portnum_base, int ipv4_address); + + // return the socket for a given client or server identifier + int get_socket(int socket_id); + + void curve25519_ints_to_bytes(unsigned char bytes[crypto_box_PUBLICKEYBYTES], const vector& key_ints); + void generate_session_key_for_client(int client_id, const vector& public_key); + + void load_server_keys_once(); + + int get_party_num(); + void require_ed25519_keys(); +}; + +#endif diff --git a/Processor/Input.cpp b/Processor/Input.cpp index cc42ff40..5628ad25 100644 --- a/Processor/Input.cpp +++ b/Processor/Input.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Input.cpp diff --git a/Processor/Input.h b/Processor/Input.h index 69a87072..d324f898 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Input.h diff --git a/Processor/InputTuple.h b/Processor/InputTuple.h index 67c38a12..740f3211 100644 --- a/Processor/InputTuple.h +++ b/Processor/InputTuple.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * InputTuple.h diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 8f104fa6..bcf039ad 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Instruction.h" @@ -6,6 +6,7 @@ #include "Processor/Processor.h" #include "Exceptions/Exceptions.h" #include "Tools/time-func.h" +#include "Tools/parse.h" #include #include @@ -13,28 +14,6 @@ #include -// Read a byte -int get_val(istream& s) -{ - char cc; - s.get(cc); - int a=cc; - if (a<0) { a+=256; } - return a; -} - -// Read a 4-byte integer -int get_int(istream& s) -{ - int n = 0; - for (int i=0; i<4; i++) - { n<<=8; - int t=get_val(s); - n+=t; - } - return n; -} - // Convert modp to signed bigint of a given bit length void to_signed_bigint(bigint& bi, const gfp& x, int len) { @@ -57,18 +36,10 @@ void to_signed_bigint(bigint& bi, const gfp& x, int len) } -void get_vector(int m, vector& start, istream& s) -{ - start.resize(m); - for (int i = 0; i < m; i++) - start[i] = get_int(s); -} - - void Instruction::parse(istream& s) { n=0; start.resize(0); - r[0]=0; r[1]=0; r[2]=0; + r[0]=0; r[1]=0; r[2]=0; r[3]=0; int pos=s.tellg(); opcode=get_int(s); @@ -78,6 +49,13 @@ void Instruction::parse(istream& s) if (size==0) size=1; + parse_operands(s, pos); +} + + +void BaseInstruction::parse_operands(istream& s, int pos) +{ + int num_var_args = 0; switch (opcode) { // instructions with 3 register operands @@ -182,6 +160,7 @@ void Instruction::parse(istream& s) case GRAWOUTPUT: case PRINTCHRINT: case PRINTSTRINT: + case PRINTINT: r[0]=get_int(s); break; // instructions with 3 registers + 1 integer operand @@ -227,6 +206,8 @@ void Instruction::parse(istream& s) case RUN_TAPE: case STARTPRIVATEOUTPUT: case GSTARTPRIVATEOUTPUT: + case DIGESTC: + case CONNECTIPV4: // write socket handle, read IPv4 address, portnum r[0]=get_int(s); r[1]=get_int(s); n = get_int(s); @@ -259,10 +240,7 @@ void Instruction::parse(istream& s) case GSTOPPRIVATEOUTPUT: case INPUTMASK: case GINPUTMASK: - case READSOCKETC: - case READSOCKETS: - case WRITESOCKETC: - case WRITESOCKETS: + case ACCEPTCLIENTCONNECTION: r[0]=get_int(s); n = get_int(s); break; @@ -272,49 +250,84 @@ void Instruction::parse(istream& s) case JMP: case START: case STOP: - case OPENSOCKET: + case LISTEN: n = get_int(s); break; // instructions with no operand case TIME: case CRASH: - case CLOSESOCKET: - break; + break; // instructions with 4 register operands case PRINTFLOATPLAIN: get_vector(4, start, s); break; - // open instructions + // open instructions + read/write instructions with variable length args case STARTOPEN: case STOPOPEN: case GSTARTOPEN: case GSTOPOPEN: - int m; - m = get_int(s); - get_vector(m, start, s); + case WRITEFILESHARE: + num_var_args = get_int(s); + get_vector(num_var_args, start, s); + break; + + // read from file, input is opcode num_args, + // start_file_posn (read), end_file_posn(write) var1, var2, ... + case READFILESHARE: + num_var_args = get_int(s) - 2; + r[0] = get_int(s); + r[1] = get_int(s); + get_vector(num_var_args, start, s); + break; + + // read from external client, input is : opcode num_args, client_id, var1, var2 ... + case READSOCKETC: + case READSOCKETS: + case READSOCKETINT: + case READCLIENTPUBLICKEY: + num_var_args = get_int(s) - 1; + r[0] = get_int(s); + get_vector(num_var_args, start, s); + break; + + // write to external client, input is : opcode num_args, client_id, message_type, var1, var2 ... + case WRITESOCKETC: + case WRITESOCKETS: + case WRITESOCKETSHARE: + case WRITESOCKETINT: + num_var_args = get_int(s) - 2; + r[0] = get_int(s); + r[1] = get_int(s); + get_vector(num_var_args, start, s); + break; + case INITSECURESOCKET: + case RESPSECURESOCKET: + num_var_args = get_int(s) - 1; + r[0] = get_int(s); + get_vector(num_var_args, start, s); break; // raw input case STOPINPUT: case GSTOPINPUT: // subtract player number argument - m = get_int(s) - 1; + num_var_args = get_int(s) - 1; n = get_int(s); - get_vector(m, start, s); + get_vector(num_var_args, start, s); break; case GBITDEC: case GBITCOM: - m = get_int(s) - 2; + num_var_args = get_int(s) - 2; r[0] = get_int(s); n = get_int(s); - get_vector(m, start, s); + get_vector(num_var_args, start, s); break; case PREP: case GPREP: // subtract extra argument - m = get_int(s) - 1; + num_var_args = get_int(s) - 1; s.read((char*)r, sizeof(r)); - start.resize(m); - for (int i = 0; i < m; i++) + start.resize(num_var_args); + for (int i = 0; i < num_var_args; i++) { start[i] = get_int(s); } break; case USE_PREP: @@ -341,8 +354,8 @@ void Instruction::parse(istream& s) break; default: ostringstream os; - os << "Invalid instruction " << hex << showbase << opcode << " at " << pos; - throw Processor_Error(os.str()); + os << "Invalid instruction " << hex << showbase << opcode << " at " << dec << pos; + throw Invalid_Instruction(os.str()); } } @@ -376,7 +389,7 @@ bool Instruction::get_offline_data_usage(DataPositions& usage) } } -RegType Instruction::get_reg_type() const +int BaseInstruction::get_reg_type() const { switch (opcode) { case LDMINT: @@ -386,6 +399,16 @@ RegType Instruction::get_reg_type() const case PUSHINT: case POPINT: case MOVINT: + case READSOCKETINT: + case WRITESOCKETINT: + case READCLIENTPUBLICKEY: + case INITSECURESOCKET: + case RESPSECURESOCKET: + case LDARG: + case LDINT: + case CONVMODP: + case GCONVGF2N: + case RAND: return INT; case PREP: case USE_PREP: @@ -402,7 +425,7 @@ RegType Instruction::get_reg_type() const } } -int Instruction::get_max_reg(RegType reg_type) const +int BaseInstruction::get_max_reg(int reg_type) const { if (get_reg_type() != reg_type) { return 0; } @@ -420,7 +443,7 @@ int Instruction::get_mem(RegType reg_type, SecrecyType sec_type) const return 0; } -bool Instruction::is_direct_memory_access(SecrecyType sec_type) const +bool BaseInstruction::is_direct_memory_access(SecrecyType sec_type) const { if (sec_type == SECRET) { @@ -825,10 +848,21 @@ void Instruction::execute(Processor& Proc) const case LEGENDREC: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); Proc.temp.aa = mpz_legendre(Proc.temp.aa.get_mpz_t(), gfp::pr().get_mpz_t()); - //Proc.temp.aa = legendre; to_gfp(Proc.temp.ansp, Proc.temp.aa); Proc.write_Cp(r[0], Proc.temp.ansp); break; + case DIGESTC: + { + octetStream o; + to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); + + to_gfp(Proc.temp.ansp, Proc.temp.aa); + Proc.temp.ansp.pack(o); + // keep first n bytes + to_gfp(Proc.temp.ansp, o.check_sum(n)); + Proc.write_Cp(r[0], Proc.temp.ansp); + } + break; case DIVCI: if (n == 0) throw Processor_Error("Division by immediate zero"); @@ -1455,6 +1489,12 @@ void Instruction::execute(Processor& Proc) const cout << res << flush; } break; + case PRINTINT: + if (Proc.P.my_num() == 0) + { + cout << Proc.read_Ci(r[0]) << flush; + } + break; case PRINTSTR: if (Proc.P.my_num() == 0) { @@ -1490,16 +1530,13 @@ void Instruction::execute(Processor& Proc) const case GUSE_PREP: break; case TIME: - cout << "Elapsed time: " << Proc.machine.timer[0].elapsed() << endl; + Proc.machine.time(); break; case START: - cout << "Starting timer " << n << " at " << Proc.machine.timer[n].elapsed() - << " after " << Proc.machine.timer[n].idle() << endl; - Proc.machine.timer[n].start(); + Proc.machine.start(n); break; case STOP: - Proc.machine.timer[n].stop(); - cout << "Stopped timer " << n << " at " << Proc.machine.timer[n].elapsed() << endl; + Proc.machine.stop(n); break; case RUN_TAPE: Proc.DataF.skip(Proc.machine.run_tape(r[0], n, r[1], -1)); @@ -1513,39 +1550,81 @@ void Instruction::execute(Processor& Proc) const // *** // TODO: read/write shared GF(2^n) data instructions // *** - case OPENSOCKET: - Proc.open_socket(n); + case LISTEN: + // listen for connections at port number n + Proc.external_clients.start_listening(n); break; - case CLOSESOCKET: - Proc.close_socket(); + case ACCEPTCLIENTCONNECTION: + { + // get client connection at port number n + my_num()) + int client_handle = Proc.external_clients.get_client_connection(n); + if (client_handle == -1) + { + stringstream ss; + ss << "No connection on port " << r[0] << endl; + throw Processor_Error(ss.str()); + } + Proc.write_Ci(r[0], client_handle); break; - case READSOCKETC: // n is *unused atm*, r[0] is register to write to - int dest; - Proc.read_socket(dest); - Proc.write_Ci(r[0], (long)dest); + } + case CONNECTIPV4: + { + // connect to server at port n + my_num() + int ipv4 = Proc.read_Ci(r[1]); + int server_handle = Proc.external_clients.connect_to_server(n, ipv4); + Proc.write_Ci(r[0], server_handle); + break; + } + case READCLIENTPUBLICKEY: + Proc.read_client_public_key(Proc.read_Ci(r[0]), start); + break; + case INITSECURESOCKET: + Proc.init_secure_socket(Proc.read_Ci(r[i]), start); + break; + case RESPSECURESOCKET: + Proc.resp_secure_socket(Proc.read_Ci(r[i]), start); + break; + case READSOCKETINT: + Proc.read_socket_ints(Proc.read_Ci(r[0]), start); + break; + case READSOCKETC: + Proc.read_socket_vector(Proc.read_Ci(r[0]), start); break; case READSOCKETS: - // read share then MAC share - Proc.read_socket(Proc.temp.ansp); - Proc.get_Sp_ref(r[0]).set_share(Proc.temp.ansp); - Proc.read_socket(Proc.temp.ansp); - Proc.get_Sp_ref(r[0]).set_mac(Proc.temp.ansp); + // read shares and MAC shares + Proc.read_socket_private(Proc.read_Ci(r[0]), start, true); break; case GREADSOCKETS: //Proc.get_S2_ref(r[0]).get_share().pack(socket_octetstream); //Proc.get_S2_ref(r[0]).get_mac().pack(socket_octetstream); break; - case WRITESOCKETC: // n is *unused atm*, r[0] is register to write to; - Proc.write_socket((int&)Proc.get_Ci_ref(r[0])); + case WRITESOCKETINT: + Proc.write_socket(INT, CLEAR, false, Proc.read_Ci(r[0]), r[1], start); + break; + case WRITESOCKETC: + Proc.write_socket(MODP, CLEAR, false, Proc.read_Ci(r[0]), r[1], start); break; case WRITESOCKETS: - Proc.write_socket(Proc.get_Sp_ref(r[0]).get_share()); - Proc.write_socket(Proc.get_Sp_ref(r[0]).get_mac()); + // Send shares + MACs + Proc.write_socket(MODP, SECRET, true, Proc.read_Ci(r[0]), r[1], start); + break; + case WRITESOCKETSHARE: + // Send only shares, no MACs + // N.B. doesn't make sense to have a corresponding read instruction for this + Proc.write_socket(MODP, SECRET, false, Proc.read_Ci(r[0]), r[1], start); break; /*case GWRITESOCKETS: Proc.get_S2_ref(r[0]).get_share().pack(socket_octetstream); Proc.get_S2_ref(r[0]).get_mac().pack(socket_octetstream); break;*/ + case WRITEFILESHARE: + // Write shares to file system + Proc.write_shares_to_file(start); + break; + case READFILESHARE: + // Read shares from file system + Proc.read_shares_from_file(Proc.read_Ci(r[0]), r[1], start); + break; case PUBINPUT: Proc.public_input >> Proc.get_Ci_ref(r[0]); break; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 9975021f..7ab07734 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Instruction #define _Instruction @@ -11,7 +11,6 @@ #include using namespace std; -#include "Processor/Memory.h" #include "Processor/Data_Files.h" #include "Networking/Player.h" #include "Math/Integer.h" @@ -23,7 +22,7 @@ class Processor; /* * Opcode constants * - * Whenever these are changed the corresponding dict in Compiler/instructions.py + * Whenever these are changed the corresponding dict in Compiler/instructions_base.py * MUST also be changed. (+ the documentation) */ enum @@ -89,6 +88,7 @@ enum MODC = 0x36, MODCI = 0x37, LEGENDREC = 0x38, + DIGESTC = 0x39, // Open STARTOPEN = 0xA0, STOPOPEN = 0xA1, @@ -107,8 +107,13 @@ enum READSOCKETS = 0x64, WRITESOCKETC = 0x65, WRITESOCKETS = 0x66, - OPENSOCKET = 0x67, - CLOSESOCKET = 0x68, + READSOCKETINT = 0x69, + WRITESOCKETINT = 0x6a, + WRITESOCKETSHARE = 0x6b, + LISTEN = 0x6c, + ACCEPTCLIENTCONNECTION = 0x6d, + CONNECTIPV4 = 0x6e, + READCLIENTPUBLICKEY = 0x6f, // Bitwise logic ANDC = 0x70, XORC = 0x71, @@ -138,6 +143,7 @@ enum SUBINT = 0x9C, MULINT = 0x9D, DIVINT = 0x9E, + PRINTINT = 0x9F, // Conversion CONVINT = 0xC0, CONVMODP = 0xC1, @@ -156,6 +162,8 @@ enum PRINTCHRINT = 0xBA, PRINTSTRINT = 0xBB, PRINTFLOATPLAIN = 0xBC, + WRITEFILESHARE = 0xBD, + READFILESHARE = 0xBE, // GF(2^n) versions @@ -241,6 +249,9 @@ enum GRAWOUTPUT = 0x1B7, GSTARTPRIVATEOUTPUT = 0x1B8, GSTOPPRIVATEOUTPUT = 0x1B9, + // Commsec ops + INITSECURESOCKET = 0x1BA, + RESPSECURESOCKET = 0x1BB }; @@ -259,7 +270,6 @@ enum SecrecyType { MAX_SECRECY_TYPE }; - struct TempVars { gf2n ans2; Share Sans2; gfp ansp; Share Sansp; @@ -273,29 +283,38 @@ struct TempVars { }; -class Instruction +class BaseInstruction { +protected: int opcode; // The code int size; // Vector size - int r[3]; // Three possible registers + int r[4]; // Fixed parameter registers unsigned int n; // Possible immediate value vector start; // Values for a start/stop open - public: +public: + virtual ~BaseInstruction() {}; - // Reads a single instruction from the istream - void parse(istream& s); - - // Return whether usage is known - bool get_offline_data_usage(DataPositions& usage); + void parse_operands(istream& s, int pos); bool is_gf2n_instruction() const { return ((opcode&0x100)!=0); } - RegType get_reg_type() const; + virtual int get_reg_type() const; bool is_direct_memory_access(SecrecyType sec_type) const; // Returns the maximal register used - int get_max_reg(RegType reg_type) const; + int get_max_reg(int reg_type) const; +}; + + +class Instruction : public BaseInstruction +{ +public: + // Reads a single instruction from the istream + void parse(istream& s); + + // Return whether usage is known + bool get_offline_data_usage(DataPositions& usage); // Returns the memory size used if applicable and known int get_mem(RegType reg_type, SecrecyType sec_type) const; diff --git a/Processor/Machine.cpp b/Processor/Machine.cpp index 061a4f41..549e06ee 100644 --- a/Processor/Machine.cpp +++ b/Processor/Machine.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Machine.h" @@ -17,12 +17,14 @@ using namespace std; Machine::Machine(int my_number, int PortnumBase, string hostname, string progname_str, string memtype, int lgp, int lg2, bool direct, - int opening_sum, bool parallel, bool receive_threads, int max_broadcast) + int opening_sum, bool parallel, bool receive_threads, int max_broadcast, + CommsecKeysPackage *commsec_keys) : my_number(my_number), nthreads(0), tn(0), numt(0), usage_unknown(false), progname(progname_str), direct(direct), opening_sum(opening_sum), parallel(parallel), receive_threads(receive_threads), max_broadcast(max_broadcast) { N.init(my_number,PortnumBase,hostname.c_str()); + N.set_keys(commsec_keys); if (opening_sum < 2) this->opening_sum = N.num_players(); @@ -106,36 +108,9 @@ Machine::Machine(int my_number, int PortnumBase, string hostname, if (pinp.fail()) { throw file_error(filename); } progs[i].parse(pinp); pinp.close(); - if (progs[i].direct_mem2_s() > M2.size_s()) - { - cerr << threadname << " needs more secret mod2 memory, resizing to " - << progs[i].direct_mem2_s() << endl; - M2.resize_s(progs[i].direct_mem2_s()); - } - if (progs[i].direct_memp_s() > Mp.size_s()) - { - cerr << threadname << " needs more secret modp memory, resizing to " - << progs[i].direct_memp_s() << endl; - Mp.resize_s(progs[i].direct_memp_s()); - } - if (progs[i].direct_mem2_c() > M2.size_c()) - { - cerr << threadname << " needs more clear mod2 memory, resizing to " - << progs[i].direct_mem2_c() << endl; - M2.resize_c(progs[i].direct_mem2_c()); - } - if (progs[i].direct_memp_c() > Mp.size_c()) - { - cerr << threadname << " needs more clear modp memory, resizing to " - << progs[i].direct_memp_c() << endl; - Mp.resize_c(progs[i].direct_memp_c()); - } - if (progs[i].direct_memi_c() > Mi.size_c()) - { - cerr << threadname << " needs more clear integer memory, resizing to " - << progs[i].direct_memi_c() << endl; - Mi.resize_c(progs[i].direct_memi_c()); - } + M2.minimum_size(GF2N, progs[i], threadname); + Mp.minimum_size(MODP, progs[i], threadname); + Mi.minimum_size(INT, progs[i], threadname); } progs[0].print_offline_cost(); @@ -179,6 +154,10 @@ Machine::Machine(int my_number, int PortnumBase, string hostname, DataPositions Machine::run_tape(int thread_number, int tape_number, int arg, int line_number) { + if (thread_number >= (int)tinfo.size()) + throw Processor_Error("invalid thread number: " + to_string(thread_number) + "/" + to_string(tinfo.size())); + if (tape_number >= (int)progs.size()) + throw Processor_Error("invalid tape number: " + to_string(tape_number) + "/" + to_string(progs.size())); pthread_mutex_lock(&t_mutex[thread_number]); tinfo[thread_number].prognum=tape_number; tinfo[thread_number].arg=arg; @@ -303,10 +282,7 @@ void Machine::run() cerr << "Join timer: " << i << " " << join_timer[i].elapsed() << endl; cerr << "Finish timer: " << finish_timer.elapsed() << endl; cerr << "Process timer: " << proc_timer.elapsed() << endl; - cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; - timer.erase(0); - for (map::iterator it = timer.begin(); it != timer.end(); it++) - cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; + print_timers(); if (opening_sum < N.num_players() && !direct) cerr << "Summed at most " << opening_sum << " shares at once with indirect communication" << endl; @@ -359,4 +335,28 @@ void Machine::run() cerr << "End of prog" << endl; } +void BaseMachine::time() +{ + cout << "Elapsed time: " << timer[0].elapsed() << endl; +} +void BaseMachine::start(int n) +{ + cout << "Starting timer " << n << " at " << timer[n].elapsed() + << " after " << timer[n].idle() << endl; + timer[n].start(); +} + +void BaseMachine::stop(int n) +{ + timer[n].stop(); + cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl; +} + +void BaseMachine::print_timers() +{ + cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; + timer.erase(0); + for (map::iterator it = timer.begin(); it != timer.end(); it++) + cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; +} diff --git a/Processor/Machine.h b/Processor/Machine.h index 703d000b..2fabbe47 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Machine.h @@ -21,7 +21,19 @@ #include using namespace std; -class Machine +class BaseMachine +{ +protected: + std::map timer; + void print_timers(); + +public: + void time(); + void start(int n); + void stop(int n); +}; + +class Machine : public BaseMachine { /* The mutex's lock the C-threads and then only release * then we an MPC thread is ready to run on the C-thread. @@ -58,7 +70,6 @@ class Machine Memory Mp; Memory Mi; - std::map timer; vector join_timer; Timer finish_timer; @@ -73,7 +84,7 @@ class Machine Machine(int my_number, int PortnumBase, string hostname, string progname, string memtype, int lgp, int lg2, bool direct, int opening_sum, bool parallel, - bool receive_threads, int max_broadcast); + bool receive_threads, int max_broadcast, CommsecKeysPackage *keys); DataPositions run_tape(int thread_number, int tape_number, int arg, int line_number); void join_tape(int thread_number); diff --git a/Processor/Memory.cpp b/Processor/Memory.cpp index 5ce3c564..a1d1acb8 100644 --- a/Processor/Memory.cpp +++ b/Processor/Memory.cpp @@ -1,12 +1,31 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Memory.h" +#include "Processor/Instruction.h" #include "Math/gf2n.h" #include "Math/gfp.h" #include "Math/Integer.h" #include +template +void Memory::minimum_size(RegType reg_type, const Program& program, string threadname) +{ + const int* sizes = program.direct_mem(reg_type); + if (sizes[SECRET] > size_s()) + { + cerr << threadname << " needs more secret " << T::type_string() << " memory, resizing to " + << sizes[SECRET] << endl; + resize_s(sizes[SECRET]); + } + if (sizes[CLEAR] > size_c()) + { + cerr << threadname << " needs more clear " << T::type_string() << " memory, resizing to " + << sizes[CLEAR] << endl; + resize_c(sizes[CLEAR]); + } +} + #ifdef MEMPROTECT template void Memory::protect_s(unsigned int start, unsigned int end) diff --git a/Processor/Memory.h b/Processor/Memory.h index 21c1b1f8..3ad57345 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Memory #define _Memory @@ -14,6 +14,7 @@ template class Memory; template ostream& operator<<(ostream& s,const Memory& M); template istream& operator>>(istream& s,Memory& M); +#include "Processor/Program.h" #include "Math/Share.h" template class Memory @@ -72,6 +73,8 @@ class Memory { (void)start, (void)end; cerr << "Memory protection not activated" << endl; } #endif + void minimum_size(RegType reg_type, const Program& program, string threadname); + friend ostream& operator<< <>(ostream& s,const Memory& M); friend istream& operator>> <>(istream& s,Memory& M); diff --git a/Processor/Online-Thread.cpp b/Processor/Online-Thread.cpp index f3cfa51c..e969c88a 100644 --- a/Processor/Online-Thread.cpp +++ b/Processor/Online-Thread.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Program.h" @@ -48,14 +48,14 @@ void* Main_Func(void* ptr) if (machine.direct) { cerr << "Using direct communication. If computation stalls, use -m when compiling." << endl; - MC2 = new Direct_MAC_Check(*(tinfo->alpha2i), *(tinfo->Nms), num); - MCp = new Direct_MAC_Check(*(tinfo->alphapi), *(tinfo->Nms), num); + MC2 = new Direct_MAC_Check(*(tinfo->alpha2i),*(tinfo->Nms), num); + MCp = new Direct_MAC_Check(*(tinfo->alphapi),*(tinfo->Nms), num); } else if (machine.parallel) { cerr << "Using indirect communication with background threads." << endl; - MC2 = new Parallel_MAC_Check(*(tinfo->alpha2i), *(tinfo->Nms), num, machine.opening_sum); - MCp = new Parallel_MAC_Check(*(tinfo->alphapi), *(tinfo->Nms), num, machine.opening_sum); + MC2 = new Parallel_MAC_Check(*(tinfo->alpha2i),*(tinfo->Nms), num, machine.opening_sum); + MCp = new Parallel_MAC_Check(*(tinfo->alphapi),*(tinfo->Nms), num, machine.opening_sum); } else { @@ -64,16 +64,14 @@ void* Main_Func(void* ptr) MCp = new MAC_Check(*(tinfo->alphapi), machine.opening_sum); } - Processor Proc(tinfo->thread_num,DataF,P,*MC2,*MCp,machine); + // Allocate memory for first program before starting the clock + Processor Proc(tinfo->thread_num,DataF,P,*MC2,*MCp,machine,progs[0]); Share a,b,c; bool flag=true; int program=-3; // int exec=0; - // Allocate memory for first program before starting the clock - Proc.reset(progs[0].num_regs2(),progs[0].num_regsp(),progs[0].num_regi(),tinfo->arg); - // synchronize cerr << "Locking for sync of thread " << num << endl; pthread_mutex_lock(&t_mutex[num]); @@ -103,7 +101,7 @@ void* Main_Func(void* ptr) else { // RUN PROGRAM //printf("\tClient %d about to run %d in execution %d\n",num,program,exec); - Proc.reset(progs[program].num_regs2(),progs[program].num_regsp(),progs[program].num_regi(),tinfo->arg); + Proc.reset(progs[program],tinfo->arg); // Bits, Triples, Squares, and Inverses skipping DataF.seekg(tinfo->pos); diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h index 485a073e..7714258c 100644 --- a/Processor/Online-Thread.h +++ b/Processor/Online-Thread.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Online_Thread #define _Online_Thread diff --git a/Processor/PrivateOutput.cpp b/Processor/PrivateOutput.cpp index 909871d4..ec9c9c25 100644 --- a/Processor/PrivateOutput.cpp +++ b/Processor/PrivateOutput.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * PrivateOutput.cpp diff --git a/Processor/PrivateOutput.h b/Processor/PrivateOutput.h index 52c7522f..2957727f 100644 --- a/Processor/PrivateOutput.h +++ b/Processor/PrivateOutput.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * PrivateOutput.h diff --git a/Processor/Processor.cpp b/Processor/Processor.cpp index a62c78e2..4245cb58 100644 --- a/Processor/Processor.cpp +++ b/Processor/Processor.cpp @@ -1,19 +1,23 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Processor.h" +#include "Networking/STS.h" #include "Auth/MAC_Check.h" #include "Auth/fake-stuff.h" +#include +#include Processor::Processor(int thread_num,Data_Files& DataF,Player& P, MAC_Check& MC2,MAC_Check& MCp,Machine& machine, - int num_regs2,int num_regsp,int num_regi) -: thread_num(thread_num),socket_is_open(false),DataF(DataF),P(P),MC2(MC2),MCp(MCp),machine(machine), - input2(*this,MC2),inputp(*this,MCp),privateOutput2(*this),privateOutputp(*this),sent(0),rounds(0) + const Program& program) +: thread_num(thread_num),DataF(DataF),P(P),MC2(MC2),MCp(MCp),machine(machine), + input2(*this,MC2),inputp(*this,MCp),privateOutput2(*this),privateOutputp(*this),sent(0),rounds(0), + external_clients(ExternalClients(P.my_num(), DataF.prep_data_dir)),binary_file_io(Binary_File_IO()) { - reset(num_regs2,num_regsp,num_regi,0); + reset(program,0); public_input.open(get_filename("Programs/Public-Input/",false).c_str()); private_input.open(get_filename("Player-Data/Private-Input-",true).c_str()); @@ -27,7 +31,6 @@ Processor::~Processor() cerr << "Sent " << sent << " elements in " << rounds << " rounds" << endl; } - string Processor::get_filename(const char* prefix, bool use_number) { stringstream filename; @@ -43,16 +46,15 @@ string Processor::get_filename(const char* prefix, bool use_number) } -void Processor::reset(int num_regs2,int num_regsp,int num_regi,int arg) +void Processor::reset(const Program& program,int arg) { - reg_max2 = num_regs2; - reg_maxp = num_regsp; - reg_maxi = num_regi; + reg_max2 = program.num_reg(GF2N); + reg_maxp = program.num_reg(MODP); + reg_maxi = program.num_reg(INT); C2.resize(reg_max2); Cp.resize(reg_maxp); S2.resize(reg_max2); Sp.resize(reg_maxp); Ci.resize(reg_maxi); this->arg = arg; - close_socket(); #ifdef DEBUG rw2.resize(2*reg_max2); @@ -65,59 +67,323 @@ void Processor::reset(int num_regs2,int num_regsp,int num_regi,int arg) } #include "Networking/sockets.h" +#include "Math/Setup.h" -// Set up a server socket for some client -void Processor::open_socket(int portnum_base) +// Write socket (typically SPDZ engine -> external client), for different register types. +// RegType and SecrecyType determines how registers are read and the socket stream is packed. +// If message_type is > 0, send message_type in bytes 0 - 3, to allow an external client to +// determine the data structure being sent in a message. +// Encryption is enabled if key material (for DH Auth Encryption and/or STS protocol) has been already setup. +void Processor::write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, + int socket_id, int message_type, const vector& registers) { - if (!socket_is_open) + if (socket_id >= (int)external_clients.external_client_sockets.size()) { - socket_is_open = true; - sockaddr_in dest; - set_up_server_socket(dest, final_socket_fd, socket_fd, portnum_base + P.my_num()); + cerr << "No socket connection exists for client id " << socket_id << endl; + return; + } + int m = registers.size(); + socket_stream.reset_write_head(); + + //First 4 bytes is message_type (unless indicate not needed) + if (message_type != 0) { + socket_stream.store(message_type); + } + + for (int i = 0; i < m; i++) + { + if (reg_type == MODP && secrecy_type == SECRET) { + // Send vector of secret shares and optionally macs + get_S_ref(registers[i]).get_share().pack(socket_stream); + if (send_macs) + get_S_ref(registers[i]).get_mac().pack(socket_stream); + } + else if (reg_type == MODP && secrecy_type == CLEAR) { + // Send vector of clear public field elements + get_C_ref(registers[i]).pack(socket_stream); + } + else if (reg_type == INT && secrecy_type == CLEAR) { + // Send vector of 32-bit clear ints + socket_stream.store((int&)get_Ci_ref(registers[i])); + } + else { + stringstream ss; + ss << "Write socket instruction with unknown reg type " << reg_type << + " and secrecy type " << secrecy_type << "." << endl; + throw Processor_Error(ss.str()); + } + } + + // Apply DH Auth encryption if session keys have been created. + map::iterator it = external_clients.symmetric_client_keys.find(socket_id); + if (it != external_clients.symmetric_client_keys.end()) { + socket_stream.encrypt(it->second); + } + + // Apply STS commsec encryption if session keys have been created. + try { + maybe_encrypt_sequence(socket_id); + socket_stream.Send(external_clients.external_client_sockets[socket_id]); + } + catch (bad_value& e) { + cerr << "Send error thrown when writing " << m << " values of type " << reg_type << " to socket id " + << socket_id << "." << endl; } } -void Processor::close_socket() + +// Receive vector of 32-bit clear ints +void Processor::read_socket_ints(int client_id, const vector& registers) { - if (socket_is_open) + if (client_id >= (int)external_clients.external_client_sockets.size()) { - socket_is_open = false; - close_server_socket(final_socket_fd, socket_fd); + cerr << "No socket connection exists for client id " << client_id << endl; + return; + } + + int m = registers.size(); + socket_stream.reset_write_head(); + socket_stream.Receive(external_clients.external_client_sockets[client_id]); + maybe_decrypt_sequence(client_id); + for (int i = 0; i < m; i++) + { + int val; + socket_stream.get(val); + write_Ci(registers[i], (long)val); } } -// Receive 32-bit int -void Processor::read_socket(int& x) -{ - octet bytes[4]; - receive(final_socket_fd, bytes, 4); - x = BYTES_TO_INT(bytes); -} - -// Send 32-bit int -void Processor::write_socket(int x) -{ - octet bytes[4]; - INT_TO_BYTES(bytes, x); - send(final_socket_fd, bytes, 4); -} - -// Receive field element +// Receive vector of public field elements template -void Processor::read_socket(T& x) +void Processor::read_socket_vector(int client_id, const vector& registers) { + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + return; + } + + int m = registers.size(); socket_stream.reset_write_head(); - socket_stream.Receive(final_socket_fd); - x.unpack(socket_stream); + socket_stream.Receive(external_clients.external_client_sockets[client_id]); + maybe_decrypt_sequence(client_id); + for (int i = 0; i < m; i++) + { + get_C_ref(registers[i]).unpack(socket_stream); + } } -// Send field element +// Receive vector of field element shares over private channel template -void Processor::write_socket(const T& x) +void Processor::read_socket_private(int client_id, const vector& registers, bool read_macs) { + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + return; + } + int m = registers.size(); socket_stream.reset_write_head(); - x.pack(socket_stream); - socket_stream.Send(final_socket_fd); + socket_stream.Receive(external_clients.external_client_sockets[client_id]); + maybe_decrypt_sequence(client_id); + + map::iterator it = external_clients.symmetric_client_keys.find(client_id); + if (it != external_clients.symmetric_client_keys.end()) + { + socket_stream.decrypt(it->second); + } + for (int i = 0; i < m; i++) + { + temp.ansp.unpack(socket_stream); + get_Sp_ref(registers[i]).set_share(temp.ansp); + if (read_macs) + { + temp.ansp.unpack(socket_stream); + get_Sp_ref(registers[i]).set_mac(temp.ansp); + } + } +} + +// Read socket for client public key as 8 ints, calculate session key for client. +void Processor::read_client_public_key(int client_id, const vector& registers) { + + read_socket_ints(client_id, registers); + + // After read into registers, need to extract values + vector client_public_key (registers.size(), 0); + for(unsigned int i = 0; i < registers.size(); i++) { + client_public_key[i] = (int&)get_Ci_ref(registers[i]); + } + + external_clients.generate_session_key_for_client(client_id, client_public_key); +} + +void Processor::init_secure_socket_internal(int client_id, const vector& registers) { + external_clients.symmetric_client_commsec_send_keys.erase(client_id); + external_clients.symmetric_client_commsec_recv_keys.erase(client_id); + unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + + external_clients.load_server_keys_once(); + external_clients.require_ed25519_keys(); + + // Validate inputs and state + if(registers.size() != 8) { + throw "Invalid call to init_secure_socket."; + } + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + throw "No socket connection exists for client"; + } + + // Extract client long term public key into bytes + vector client_public_key (registers.size(), 0); + for(unsigned int i = 0; i < registers.size(); i++) { + client_public_key[i] = (int&)get_Ci_ref(registers[i]); + } + external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key); + + // Start Station to Station Protocol + STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519); + m1 = ke.send_msg1(); + socket_stream.reset_write_head(); + socket_stream.append(m1.bytes, sizeof m1.bytes); + socket_stream.Send(external_clients.external_client_sockets[client_id]); + socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id], + 96); + socket_stream.consume(m2.pubkey, sizeof m2.pubkey); + socket_stream.consume(m2.sig, sizeof m2.sig); + m3 = ke.recv_msg2(m2); + socket_stream.reset_write_head(); + socket_stream.append(m3.bytes, sizeof m3.bytes); + socket_stream.Send(external_clients.external_client_sockets[client_id]); + + // Use results of STS to generate send and receive keys. + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0); + external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0); +} + +void Processor::init_secure_socket(int client_id, const vector& registers) { + + try { + init_secure_socket_internal(client_id, registers); + } catch (char const *e) { + cerr << "STS initiator role failed with: " << e << endl; + throw Processor_Error("STS initiator failed"); + } +} + +void Processor::resp_secure_socket(int client_id, const vector& registers) { + try { + resp_secure_socket_internal(client_id, registers); + } catch (char const *e) { + cerr << "STS responder role failed with: " << e << endl; + throw Processor_Error("STS responder failed"); + } +} + +void Processor::resp_secure_socket_internal(int client_id, const vector& registers) { + external_clients.symmetric_client_commsec_send_keys.erase(client_id); + external_clients.symmetric_client_commsec_recv_keys.erase(client_id); + unsigned char client_public_bytes[crypto_sign_PUBLICKEYBYTES]; + sts_msg1_t m1; + sts_msg2_t m2; + sts_msg3_t m3; + + external_clients.load_server_keys_once(); + external_clients.require_ed25519_keys(); + + // Validate inputs and state + if(registers.size() != 8) { + throw "Invalid call to init_secure_socket."; + } + if (client_id >= (int)external_clients.external_client_sockets.size()) + { + cerr << "No socket connection exists for client id " << client_id << endl; + throw "No socket connection exists for client"; + } + vector client_public_key (registers.size(), 0); + for(unsigned int i = 0; i < registers.size(); i++) { + client_public_key[i] = (int&)get_Ci_ref(registers[i]); + } + external_clients.curve25519_ints_to_bytes(client_public_bytes, client_public_key); + + // Start Station to Station Protocol for the responder + STS ke(client_public_bytes, external_clients.server_publickey_ed25519, external_clients.server_secretkey_ed25519); + socket_stream.reset_read_head(); + socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id], + 32); + socket_stream.consume(m1.bytes, sizeof m1.bytes); + m2 = ke.recv_msg1(m1); + socket_stream.reset_write_head(); + socket_stream.append(m2.pubkey, sizeof m2.pubkey); + socket_stream.append(m2.sig, sizeof m2.sig); + socket_stream.Send(external_clients.external_client_sockets[client_id]); + + socket_stream.ReceiveExpected(external_clients.external_client_sockets[client_id], + 64); + socket_stream.consume(m3.bytes, sizeof m3.bytes); + ke.recv_msg3(m3); + + // Use results of STS to generate send and receive keys. + vector recvKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + vector sendKey = ke.derive_secret(crypto_secretbox_KEYBYTES); + external_clients.symmetric_client_commsec_recv_keys[client_id] = make_pair(recvKey,0); + external_clients.symmetric_client_commsec_send_keys[client_id] = make_pair(sendKey,0); +} + +// Read share data from a file starting at file_pos until registers filled. +// file_pos_register is written with new file position (-1 is eof). +// Tolerent to no file if no shares yet persisted. +template +void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { + string filename; + filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; + + unsigned int size = data_registers.size(); + + vector< Share > outbuf(size); + + int end_file_posn = start_file_posn; + + try { + binary_file_io.read_from_file(filename, outbuf, start_file_posn, end_file_posn); + + for (unsigned int i = 0; i < size; i++) + { + get_Sp_ref(data_registers[i]).set_share(outbuf[i].get_share()); + get_Sp_ref(data_registers[i]).set_mac(outbuf[i].get_mac()); + } + + write_Ci(end_file_pos_register, (long)end_file_posn); + } + catch (file_missing& e) { + cerr << "Got file missing error, will return -2. " << e.what() << endl; + write_Ci(end_file_pos_register, (long)-2); + } +} + +// Append share data in data_registers to end of file. Expects Persistence directory to exist. +template +void Processor::write_shares_to_file(const vector& data_registers) { + string filename; + filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; + + unsigned int size = data_registers.size(); + + vector< Share > inpbuf (size); + + for (unsigned int i = 0; i < size; i++) + { + inpbuf[i] = get_S_ref(data_registers[i]); + } + + binary_file_io.write_to_file(filename, inpbuf); } template @@ -180,12 +446,6 @@ void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& rounds++; } - - - - - - ostream& operator<<(ostream& s,const Processor& P) { s << "Processor State" << endl; @@ -196,7 +456,7 @@ ostream& operator<<(ostream& s,const Processor& P) P.read_C2(i).output(s,true); s << "\t"; P.read_S2(i).output(s,true); - s << endl; + s << endl; } s << "Char p Registers" << endl; s << "Val\tClearReg\tSharedReg" << endl; @@ -205,18 +465,37 @@ ostream& operator<<(ostream& s,const Processor& P) P.read_Cp(i).output(s,true); s << "\t"; P.read_Sp(i).output(s,true); - s << endl; + s << endl; } return s; } +void Processor::maybe_decrypt_sequence(int client_id) +{ + map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_recv_keys.find(client_id); + if (it_cs != external_clients.symmetric_client_commsec_recv_keys.end()) + { + socket_stream.decrypt_sequence(&it_cs->second.first[0], it_cs->second.second); + it_cs->second.second++; + } +} + +void Processor::maybe_encrypt_sequence(int client_id) +{ + map,uint64_t> >::iterator it_cs = external_clients.symmetric_client_commsec_send_keys.find(client_id); + if (it_cs != external_clients.symmetric_client_commsec_send_keys.end()) + { + socket_stream.encrypt_sequence(&it_cs->second.first[0], it_cs->second.second); + it_cs->second.second++; + } +} template void Processor::POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); template void Processor::POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); template void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size); template void Processor::POpen_Stop(const vector& reg,const Player& P,MAC_Check& MC,int size); -template void Processor::read_socket(gfp& x); -template void Processor::read_socket(gf2n& x); -template void Processor::write_socket(const gfp& x); -template void Processor::write_socket(const gf2n& x); +template void Processor::read_socket_private(int client_id, const vector& registers, bool send_macs); +template void Processor::read_socket_vector(int client_id, const vector& registers); +template void Processor::read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector& data_registers); +template void Processor::write_shares_to_file(const vector& data_registers); diff --git a/Processor/Processor.h b/Processor/Processor.h index a4a0d37e..b1482a37 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _Processor @@ -19,19 +19,43 @@ #include "Input.h" #include "PrivateOutput.h" #include "Machine.h" +#include "ExternalClients.h" +#include "Binary_File_IO.h" +#include "Instruction.h" #include -class Processor +class ProcessorBase +{ + // Stack + stack stacki; + +protected: + // Optional argument to tape + int arg; + +public: + void pushi(long x) { stacki.push(x); } + void popi(long& x) { x = stacki.top(); stacki.pop(); } + + int get_arg() const + { + return arg; + } + + void set_arg(int new_arg) + { + arg=new_arg; + } +}; + +class Processor : public ProcessorBase { vector C2; vector Cp; vector > S2; vector > Sp; vector Ci; - - // Stack - stack stacki; // This is the vector of partially opened values and shares we need to store // as the Open commands are split in two @@ -43,13 +67,8 @@ class Processor int reg_max2,reg_maxp,reg_maxi; int thread_num; - // Optional argument to tape - int arg; - - // For reading/reading data from a socket (i.e. external party to SPDZ) + // Data structure used for reading/writing data to/from a socket (i.e. an external party to SPDZ) octetStream socket_stream; - int socket_fd, final_socket_fd; - bool socket_is_open; #ifdef DEBUG vector rw2; @@ -91,14 +110,17 @@ class Processor int sent, rounds; + ExternalClients external_clients; + Binary_File_IO binary_file_io; + static const int reg_bytes = 4; - void reset(int num_regs2,int num_regsp,int num_regi,int arg); // Reset the state of the processor + void reset(const Program& program,int arg); // Reset the state of the processor string get_filename(const char* basename, bool use_number); Processor(int thread_num,Data_Files& DataF,Player& P, MAC_Check& MC2,MAC_Check& MCp,Machine& machine, - int num_regs2 = 256,int num_regsp = 256,int num_regi = 256); + const Program& program); ~Processor(); int get_thread_num() @@ -106,19 +128,6 @@ class Processor return thread_num; } - int get_arg() const - { - return arg; - } - - void set_arg(int new_arg) - { - arg=new_arg; - } - - void pushi(long x) { stacki.push(x); } - void popi(long& x) { x = stacki.top(); stacki.pop(); } - #ifdef DEBUG const gf2n& read_C2(int i) const { if (rw2[i]==0) @@ -226,16 +235,29 @@ class Processor template Share& get_S_ref(int i); template T& get_C_ref(int i); - // Access to sockets for reading clear/shared data - void open_socket(int portnum_base); - void close_socket(); - void read_socket(int& x); - void write_socket(int x); - template - void read_socket(T& x); - template - void write_socket(const T& x); + // Access to external client sockets for reading clear/shared data + void read_socket_ints(int client_id, const vector& registers); + // Setup client public key + void read_client_public_key(int client_id, const vector& registers); + void init_secure_socket(int client_id, const vector& registers); + void init_secure_socket_internal(int client_id, const vector& registers); + void resp_secure_socket(int client_id, const vector& registers); + void resp_secure_socket_internal(int client_id, const vector& registers); + + void write_socket(const RegType reg_type, const SecrecyType secrecy_type, const bool send_macs, + int socket_id, int message_type, const vector& registers); + template + void read_socket_vector(int client_id, const vector& registers); + template + void read_socket_private(int client_id, const vector& registers, bool send_macs); + + // Read and write secret numeric data to file (name hardcoded at present) + template + void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector& data_registers); + template + void write_shares_to_file(const vector& data_registers); + // Access to PO (via calls to POpen start/stop) template void POpen_Start(const vector& reg,const Player& P,MAC_Check& MC,int size); @@ -245,6 +267,10 @@ class Processor // Print the processor state friend ostream& operator<<(ostream& s,const Processor& P); + + private: + void maybe_decrypt_sequence(int client_id); + void maybe_encrypt_sequence(int client_id); }; template<> inline Share& Processor::get_S_ref(int i) { return get_S2_ref(i); } diff --git a/Processor/Program.cpp b/Processor/Program.cpp index 9b132305..1d3a47f9 100644 --- a/Processor/Program.cpp +++ b/Processor/Program.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Processor/Program.h" @@ -7,23 +7,24 @@ void Program::compute_constants() { - max_reg2 = 0; - max_regp = 0; - max_regi = 0; for (int reg_type = 0; reg_type < MAX_REG_TYPE; reg_type++) - for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++) - max_mem[reg_type][sec_type] = 0; + { + max_reg[reg_type] = 0; + for (int sec_type = 0; sec_type < MAX_SECRECY_TYPE; sec_type++) + max_mem[reg_type][sec_type] = 0; + } for (unsigned int i=0; i= MAX_NUM_CLIENTS) + finish.reveal() == 0 + + winning_client_id = determine_winner(number_clients, client_values, client_ids) + + # print_ln('Found winner, index: %s.', winning_client_id.reveal()) + + write_winner_to_clients(client_sockets, number_clients, winning_client_id) + + return True + +main() diff --git a/Programs/Source/bankers_bonus_commsec.mpc b/Programs/Source/bankers_bonus_commsec.mpc new file mode 100644 index 00000000..a60a2e61 --- /dev/null +++ b/Programs/Source/bankers_bonus_commsec.mpc @@ -0,0 +1,117 @@ +# (C) 2017 University of Bristol. See License.txt +# coding: latin-1 +""" + Solve Bankers bonus, aka Millionaires problem. + to deduce the maximum value from a range of integer input. + + Demonstrate clients external to computing parties supplying input and receiving + an authenticated result. See bankers-bonus-commsec-client.cpp for client (and setup instructions). + + For an implementation without communications security see bankers_bonus.mpc. + + Wait for MAX_NUM_CLIENTS to join the game or client finish flag to be sent + before calculating the maximum. + + Note each client connects in a single thread and so is potentially blocked. + + Each round / game will reset and so this runs indefinitiely. +""" + +from Compiler.types import sint, regint, Array, Matrix, MemValue +from Compiler.instructions import listen, acceptclientconnection +from Compiler.library import print_ln, do_while, if_e, else_, for_range +from Compiler.util import if_else + +PORTNUM = 14000 +MAX_NUM_CLIENTS = 8 + +def accept_client_input(): + """ + Wait for socket connection and read for client public key. + send share of random value, receive input and deduce share. + Expect 3 inputs: unique id, bonus value and flag to indicate end of this round. + """ + client_socket_id = regint() + acceptclientconnection(client_socket_id, PORTNUM) + + # Crypto setup + public_signing_key = regint.read_from_socket(client_socket_id, 8) + public_key = regint.read_client_public_key(client_socket_id) + regint.resp_secure_socket(client_socket_id,*public_signing_key) + + client_inputs = sint.receive_from_client(3, client_socket_id) + + return client_socket_id, client_inputs[0], client_inputs[1], client_inputs[2] + + +def determine_winner(number_clients, client_values, client_ids): + """Work out and return client_id which corresponds to max client_value""" + max_value = Array(1, sint) + max_value[0] = client_values[0] + win_client_id = Array(1, sint) + win_client_id[0] = client_ids[0] + + @for_range(number_clients-1) + def loop_body(i): + # Is this client input a new maximum, will be sint(1) if true, else sint(0) + is_new_max = max_value[0] < client_values[i+1] + # Keep latest max_value + max_value[0] = if_else(is_new_max, client_values[i+1], max_value[0]) + # Keep current winning client id + win_client_id[0] = if_else(is_new_max, client_ids[i+1], win_client_id[0]) + + return win_client_id[0] + + +def write_winner_to_clients(sockets, number_clients, winning_client_id): + """Send share of winning client id to all clients who joined game.""" + + # Setup authenticate result using share of random. + # client can validate ∑ winning_client_id * ∑ rnd_from_triple = ∑ auth_result + rnd_from_triple = sint.get_random_triple()[0] + auth_result = winning_client_id * rnd_from_triple + + @for_range(number_clients) + def loop_body(i): + sint.write_shares_to_socket(sockets[i], [winning_client_id, rnd_from_triple, auth_result]) + + +def main(): + """Listen in while loop for players to join a game. + Once maxiumum reached or have notified that round finished, run comparison and return result.""" + # Start listening for client socket connections + listen(PORTNUM) + print_ln('Listening for client connections on base port %s', PORTNUM) + + @do_while + def game_loop(): + print_ln('Starting a new round of the game.') + + # Clients socket id (integer). + client_sockets = Array(MAX_NUM_CLIENTS, regint) + # Number of clients + number_clients = MemValue(regint(0)) + # Clients secret input. + client_values = Array(MAX_NUM_CLIENTS, sint) + # Client ids to identity client + client_ids = Array(MAX_NUM_CLIENTS, sint) + + # Loop round waiting for each client to connect + @do_while + def client_connections(): + + client_sockets[number_clients], client_ids[number_clients], client_values[number_clients], finish = accept_client_input() + number_clients.write(number_clients+1) + + # continue while both expressions are false + return (number_clients >= MAX_NUM_CLIENTS) + finish.reveal() == 0 + + winning_client_id = determine_winner(number_clients, client_values, client_ids) + + print_ln('Found winner, index: %s.', winning_client_id.reveal()) + + write_winner_to_clients(client_sockets, number_clients, winning_client_id) + + return True + +main() diff --git a/Programs/Source/dijkstra_tutorial.mpc b/Programs/Source/dijkstra_tutorial.mpc index c595087b..dbbef223 100644 --- a/Programs/Source/dijkstra_tutorial.mpc +++ b/Programs/Source/dijkstra_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import dijkstra from path_oram import OptimalORAM diff --git a/Programs/Source/fixed_point_tutorial.mpc b/Programs/Source/fixed_point_tutorial.mpc index 5994b626..640dbc91 100644 --- a/Programs/Source/fixed_point_tutorial.mpc +++ b/Programs/Source/fixed_point_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt program.bit_length = 80 print "program.bit_length: ", program.bit_length diff --git a/Programs/Source/gale-shapley_tutorial.mpc b/Programs/Source/gale-shapley_tutorial.mpc index 2b296d41..52c345ec 100644 --- a/Programs/Source/gale-shapley_tutorial.mpc +++ b/Programs/Source/gale-shapley_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from Compiler import gs from Compiler.path_oram import OptimalORAM diff --git a/Programs/Source/oram_tutorial.mpc b/Programs/Source/oram_tutorial.mpc index a3b9a62a..c974578e 100644 --- a/Programs/Source/oram_tutorial.mpc +++ b/Programs/Source/oram_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt from path_oram import OptimalORAM diff --git a/Programs/Source/tpmpc_tutorial.mpc b/Programs/Source/tpmpc_tutorial.mpc index befedf44..b75afb0c 100644 --- a/Programs/Source/tpmpc_tutorial.mpc +++ b/Programs/Source/tpmpc_tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt """ Example programs used in the SPDZ tutorial at the TPMPC 2017 workshop in Bristol. diff --git a/Programs/Source/tutorial.mpc b/Programs/Source/tutorial.mpc index 05f394d9..f4b45558 100644 --- a/Programs/Source/tutorial.mpc +++ b/Programs/Source/tutorial.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt def test(actual, expected): if isinstance(actual, (sint, sgf2n)): diff --git a/Programs/Source/vickrey.mpc b/Programs/Source/vickrey.mpc index 6494ee4c..586c27a6 100644 --- a/Programs/Source/vickrey.mpc +++ b/Programs/Source/vickrey.mpc @@ -1,4 +1,4 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt import util from Compiler import types diff --git a/README.md b/README.md index f593785a..e15f636b 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,15 @@ -(C) 2016 University of Bristol. See License.txt +(C) 2017 University of Bristol. See License.txt Software for the SPDZ and MASCOT secure multi-party computation protocols. See `Programs/Source/` for some example MPC programs, and `tutorial.md` for -a basic tutorial. More examples and documentation will be available in the -coming weeks. +a basic tutorial. See also https://www.cs.bris.ac.uk/Research/CryptographySecurity/SPDZ #### Requirements: - GCC - MPIR library, compiled with C++ support (use flag --enable-cxx when running configure) + - libsodium library, tested against 1.0.11 - CPU supporting AES-NI and PCLMUL - Python 2.x, ideally with `gmpy` package (for testing) diff --git a/Scripts/gen_input_f2n.cpp b/Scripts/gen_input_f2n.cpp index 4544d514..78e8f1e3 100644 --- a/Scripts/gen_input_f2n.cpp +++ b/Scripts/gen_input_f2n.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include #include diff --git a/Scripts/gen_input_fp.cpp b/Scripts/gen_input_fp.cpp index cc47f8a4..ab11fbed 100644 --- a/Scripts/gen_input_fp.cpp +++ b/Scripts/gen_input_fp.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include #include diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index f946f353..453bffcd 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -1,11 +1,13 @@ -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt run_player() { port=$((RANDOM%10000+10000)) - >&2 echo Port $port bin=$1 shift + if ! test -e $SPDZROOT/logs; then + mkdir $SPDZROOT/logs + fi if test $bin = Player-Online.x; then params="$* -pn $port -h localhost" else @@ -14,13 +16,16 @@ run_player() { if test $bin = Player-KeyGen.x -a ! -e Player-Data/Params-Data; then ./Setup.x $players $size 40 fi - >&2 echo Parameters $params + >&2 echo Running $SPDZROOT/Server.x $players $port $SPDZROOT/Server.x $players $port & rem=$(($players - 2)) for i in $(seq 0 $rem); do + echo "trying with player $i" + >&2 echo Running $prefix $SPDZROOT/$bin $i $params $prefix $SPDZROOT/$bin $i $params 2>&1 | tee $SPDZROOT/logs/$i & done last_player=$(($players - 1)) + >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$last_player 2>&1 || return 1 } diff --git a/Scripts/run-online.sh b/Scripts/run-online.sh index eee9f6e5..7865713d 100755 --- a/Scripts/run-online.sh +++ b/Scripts/run-online.sh @@ -1,6 +1,6 @@ #!/bin/bash -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt HERE=$(cd `dirname $0`; pwd) SPDZROOT=$HERE/.. diff --git a/Scripts/setup-online.sh b/Scripts/setup-online.sh index 9f5328df..85082e24 100755 --- a/Scripts/setup-online.sh +++ b/Scripts/setup-online.sh @@ -1,6 +1,6 @@ #!/bin/bash -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt HERE=$(cd `dirname $0`; pwd) SPDZROOT=$HERE/.. diff --git a/Server.cpp b/Server.cpp index 103887cf..6189b926 100644 --- a/Server.cpp +++ b/Server.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Networking/sockets.h" @@ -18,16 +18,46 @@ int nmachines; +/* + * Get the client ip number on the socket connection for client i. + */ +void get_ip(int num) +{ + struct sockaddr_storage addr; + socklen_t len = sizeof addr; + + getpeername(socket_num[num], (struct sockaddr*)&addr, &len); + + // supports both IPv4 and IPv6: + char ipstr[INET6_ADDRSTRLEN]; + if (addr.ss_family == AF_INET) { + struct sockaddr_in *s = (struct sockaddr_in *)&addr; + inet_ntop(AF_INET, &s->sin_addr, ipstr, sizeof ipstr); + } else { // AF_INET6 + struct sockaddr_in6 *s = (struct sockaddr_in6 *)&addr; + inet_ntop(AF_INET6, &s->sin6_addr, ipstr, sizeof ipstr); + } + + names[num]=new octet[512]; + strncpy((char*)names[num], ipstr, INET6_ADDRSTRLEN); + + cerr << "Client IP address: " << names[num] << endl; +} + + void get_name(int num) { // Now all machines are set up, send GO to start them. send(socket_num[num], GO); cerr << "Player " << num << " started." << endl; - // Receive Name - names[num]=new octet[512]; - receive(socket_num[num],names[num],512); - cerr << "Player " << num << " is on machine " << names[num] << endl; + // Receive name sent by client (legacy) - not used here + octet my_name[512]; + receive(socket_num[num],my_name,512); + cerr << "Player " << num << " sent name (info only) " << my_name << endl; + + // Get client IP + get_ip(num); } @@ -42,9 +72,6 @@ void send_names(int num) } - - - /* Takes command line arguments of - Number of machines connecting - Base PORTNUM address @@ -71,6 +98,7 @@ int main(int argc,char **argv) // port number one lower to avoid conflict with players ServerSocket server(PortnumBase - 1); + server.init(); // set up connections for (i=0; i +#include +#include + +namespace Config { + class ConfigError : public std::exception + { + std::string s; + + public: + ConfigError(std::string ss) : s(ss) {} + ~ConfigError() throw () {} + const char* what() const throw() { return s.c_str(); } + }; + + static void output(const vector &vec, ofstream &of) + { + copy(vec.begin(), vec.end(), ostreambuf_iterator(of)); + } + void print_vector(const vector &vec) + { + cerr << hex; + for(size_t i = 0; i < vec.size(); i ++ ) { + cerr << setfill('0') << setw(2) << (int)vec[i]; + } + cerr << dec << endl; + } + + uint64_t getW64le(ifstream &infile) + { + uint8_t buf[8]; + uint64_t res=0; + infile.read((char*)buf,sizeof buf); + + if (!infile.good()) + throw ConfigError("getW64le: could not read from config file"); + + for(size_t i = 0; i < sizeof buf ; i ++ ) { + res |= ((uint64_t)buf[i]) << i*8; + } + + return res; + } + + void putW64le(ofstream &outf, uint64_t nr) + { + char buf[8]; + for(int i=0;i<8;i++) { + char byte = (uint8_t)(nr >> (i*8)); + buf[i] = (char)byte; + } + outf.write(buf,sizeof buf); + } + + const string default_player_config_file_prefix = "Player-SPDZ-Keys-P"; + string player_config_file(int player_number) + { + stringstream filename; + filename << default_player_config_file_prefix << player_number; + return filename.str(); + } + + void read_player_config(string cfgdir,int my_number,vector pubkeys,secret_signing_key mykey, public_signing_key mypubkey) + { + string filename; + filename = cfgdir + player_config_file(my_number); + ifstream infile(filename.c_str(), ios::in | ios::binary); + + infile.seekg(crypto_box_PUBLICKEYBYTES + crypto_box_SECRETKEYBYTES); + mypubkey.resize(crypto_sign_PUBLICKEYBYTES); + infile.read((char*)&mypubkey[0], crypto_sign_PUBLICKEYBYTES); + mykey.resize(crypto_sign_SECRETKEYBYTES); + infile.read((char*)&mykey[0], crypto_sign_SECRETKEYBYTES); + + // If we've failed by this point, abort. After this point we'll + // just try to read optional content. + if (!infile.good()) { + throw ConfigError("Could not parse player config file."); + } + + // Deal gracefully with absence of additional key material + try { + uint64_t nrClients = getW64le(infile); + infile.ignore(nrClients * (crypto_sign_PUBLICKEYBYTES + crypto_box_PUBLICKEYBYTES)); + uint64_t nrPlayers = getW64le(infile); + pubkeys.resize(nrPlayers); + for(size_t i=0; i client_pubs, vector client_signing_pubs + , vector player_pubs, vector player_signing_pubs) + { + stringstream filename; + filename << config_dir << "Player-SPDZ-Keys-P" << player_number; + ofstream outf(filename.str().c_str(), ios::out | ios::binary); + if (outf.fail()) + throw file_error(filename.str().c_str()); + if(crypto_box_PUBLICKEYBYTES != my_pub.size() || + crypto_box_SECRETKEYBYTES != my_priv.size() || + crypto_sign_PUBLICKEYBYTES != my_signing_pub.size() || + crypto_sign_SECRETKEYBYTES != my_signing_priv.size()) { + throw "Invalid key sizes"; + } else if(client_pubs.size() != client_signing_pubs.size()) { + throw "Incorrect number of client keys"; + } else if(player_pubs.size() != player_signing_pubs.size()) { + throw "Incorrect number of player keys"; + } else { + for(size_t i = 0; i < client_pubs.size(); i++) { + if(crypto_box_PUBLICKEYBYTES != client_pubs[i].size() || + crypto_sign_PUBLICKEYBYTES != client_signing_pubs[i].size()) { + throw "Incorrect size of client key."; + } + } + for(size_t i = 0; i < player_pubs.size(); i++) { + if(crypto_box_PUBLICKEYBYTES != player_pubs[i].size() || + crypto_sign_PUBLICKEYBYTES != player_signing_pubs[i].size()) { + throw "Incorrect size of player key."; + } + } + } + // Write public and secret X25519 keys + output(my_pub, outf); + output(my_priv, outf); + output(my_signing_pub, outf); + output(my_signing_priv, outf); + + putW64le(outf, (uint64_t)client_pubs.size()); + // Write all client public keys + for (size_t j = 0; j < client_pubs.size(); j++) { + output(client_pubs[j], outf); + output(client_signing_pubs[j], outf); + } + putW64le(outf, (uint64_t)player_pubs.size()); + for (size_t j = 0; j < player_pubs.size(); j++) { + output(player_pubs[j], outf); + output(player_signing_pubs[j], outf); + } + outf.flush(); + outf.close(); + } +} diff --git a/Tools/Config.h b/Tools/Config.h new file mode 100644 index 00000000..1ad725b6 --- /dev/null +++ b/Tools/Config.h @@ -0,0 +1,20 @@ +#include "Tools/octetStream.h" +#include "Networking/Player.h" +#include +namespace Config { + typedef vector public_key; + typedef vector public_signing_key; + typedef vector secret_key; + typedef vector secret_signing_key; + void read_player_config(string cfgdir,int my_number,vector pubkeys,secret_signing_key mysecretkey, public_signing_key mypubkey); + void write_player_config_file(string config_dir + ,int player_number, public_key my_pub, secret_key my_priv + , public_signing_key my_signing_pub, secret_signing_key my_signing_priv + , vector client_pubs, vector client_signing_pubs + , vector player_pubs, vector player_signing_pubs); + uint64_t getW64le(ifstream &infile); + void putW64le(ofstream &outf, uint64_t nr); + extern const string default_player_config_file_prefix; + string player_config_file(int player_number); + void print_vector(const vector &vec); +} diff --git a/Tools/Lock.cpp b/Tools/Lock.cpp index 73b9152c..79221017 100644 --- a/Tools/Lock.cpp +++ b/Tools/Lock.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Lock.cpp diff --git a/Tools/Lock.h b/Tools/Lock.h index 06c3b4b8..59d533f4 100644 --- a/Tools/Lock.h +++ b/Tools/Lock.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Lock.h diff --git a/Tools/MMO.cpp b/Tools/MMO.cpp index a940305a..040febaa 100644 --- a/Tools/MMO.cpp +++ b/Tools/MMO.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * MMO.cpp diff --git a/Tools/MMO.h b/Tools/MMO.h index 3f6fe3e8..99938336 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * MMO.h diff --git a/Tools/Signal.cpp b/Tools/Signal.cpp index 420fdd15..6515190f 100644 --- a/Tools/Signal.cpp +++ b/Tools/Signal.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Signal.cpp diff --git a/Tools/Signal.h b/Tools/Signal.h index 27fbacc5..1507d5fa 100644 --- a/Tools/Signal.h +++ b/Tools/Signal.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * Signal.h diff --git a/Tools/WaitQueue.h b/Tools/WaitQueue.h index 20722fd2..e07b9ef3 100644 --- a/Tools/WaitQueue.h +++ b/Tools/WaitQueue.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * WaitQueue.h diff --git a/Tools/aes-ni.cpp b/Tools/aes-ni.cpp index e1125efd..6af4e596 100644 --- a/Tools/aes-ni.cpp +++ b/Tools/aes-ni.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "aes.h" diff --git a/Tools/aes.cpp b/Tools/aes.cpp index 23a98334..99b99e2e 100644 --- a/Tools/aes.cpp +++ b/Tools/aes.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "aes.h" diff --git a/Tools/aes.h b/Tools/aes.h index 2924abb3..5d25101e 100644 --- a/Tools/aes.h +++ b/Tools/aes.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef __AES_H #define __AES_H diff --git a/Tools/ezOptionParser.h b/Tools/ezOptionParser.h index d2b09c05..7012a36b 100644 --- a/Tools/ezOptionParser.h +++ b/Tools/ezOptionParser.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* This file is part of ezOptionParser. See MIT-LICENSE. @@ -52,7 +52,7 @@ static T fromString(const char* s) { return t; }; /* ################################################################### */ -static inline bool isdigit(const std::string & s, int i=0) { +static inline bool isdigit(const std::string & s, int i=0) { int n = s.length(); for(; i < n; ++i) switch(s[i]) { @@ -85,12 +85,12 @@ For example, -d < --dimension < --dmn, and also lower come before upper. The def static bool CmpOptStringPtr(std::string * s1, std::string * s2) { int c1,c2; const char *s=s1->c_str(); - for(c1=0; c1 < (long int)s1->size(); ++c1) + for(c1=0; c1 < (long int)s1->size(); ++c1) if (isalnum(s[c1])) // locale sensitive. break; s=s2->c_str(); - for(c2=0; c2 < (long int)s2->size(); ++c2) + for(c2=0; c2 < (long int)s2->size(); ++c2) if (isalnum(s[c2])) break; @@ -232,67 +232,67 @@ static void ToD(std::string ** strings, double * out, int n) { }; /* ################################################################### */ static void StringsToInts(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atoi(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToInts(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atoi(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToLongs(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atol(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToLongs(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atol(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToULongs(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(strtoul(strings[i].c_str(),0,0)); } }; /* ################################################################### */ static void StringsToULongs(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(strtoul(strings->at(i)->c_str(),0,0)); } }; /* ################################################################### */ static void StringsToFloats(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atof(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToFloats(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atof(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToDoubles(std::vector & strings, std::vector & out) { - for(int i=0; i < (long int)strings.size(); ++i) { + for(int i=0; i < (long int)strings.size(); ++i) { out.push_back(atof(strings[i].c_str())); } }; /* ################################################################### */ static void StringsToDoubles(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back(atof(strings->at(i)->c_str())); } }; /* ################################################################### */ static void StringsToStrings(std::vector * strings, std::vector * out) { - for(int i=0; i < (long int)strings->size(); ++i) { + for(int i=0; i < (long int)strings->size(); ++i) { out->push_back( *strings->at(i) ); } }; @@ -335,7 +335,7 @@ static char** CommandLineToArgvA(char* CmdLine, int* _argc) { i = 0; j = 0; - while( (a = CmdLine[i]) ) { + while( (a = CmdLine[i]) ) { if(in_QM) { if( (a == '\"') || (a == '\'')) // rsz. Added single quote. @@ -498,71 +498,71 @@ void ezOptionValidator::reset() { type = NOTYPE; }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type) : s1(0), op(0), quiet(0), type(_type), size(0), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type) : s1(0), op(0), quiet(0), type(_type), size(0), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const char* list, int _size) : s1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const char* list, int _size) : s1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s1 = new char[size]; memcpy(s1, list, size); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned char* list, int _size) : u1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned char* list, int _size) : u1(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u1 = new unsigned char[size]; memcpy(u1, list, size); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const short* list, int _size) : s2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const short* list, int _size) : s2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s2 = new short[size]; memcpy(s2, list, size*sizeof(short)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned short* list, int _size) : u2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned short* list, int _size) : u2(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u2 = new unsigned short[size]; memcpy(u2, list, size*sizeof(unsigned short)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const int* list, int _size) : s4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const int* list, int _size) : s4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s4 = new int[size]; memcpy(s4, list, size*sizeof(int)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned int* list, int _size) : u4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned int* list, int _size) : u4(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u4 = new unsigned int[size]; memcpy(u4, list, size*sizeof(unsigned int)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const long long* list, int _size) : s8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const long long* list, int _size) : s8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); s8 = new long long[size]; memcpy(s8, list, size*sizeof(long long)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned long long* list, int _size) : u8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const unsigned long long* list, int _size) : u8(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); u8 = new unsigned long long[size]; memcpy(u8, list, size*sizeof(unsigned long long)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const float* list, int _size) : f(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const float* list, int _size) : f(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); f = new float[size]; memcpy(f, list, size*sizeof(float)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const double* list, int _size) : d(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const double* list, int _size) : d(0), op(_op), quiet(0), type(_type), size(_size), insensitive(0) { id = ezOptionParserIDGenerator::instance().next(); d = new double[size]; memcpy(d, list, size*sizeof(double)); }; /* ------------------------------------------------------------------- */ -ezOptionValidator::ezOptionValidator(char _type, char _op, const char** list, int _size, bool _insensitive) : t(0), op(_op), quiet(0), type(_type), size(_size), insensitive(_insensitive) { +ezOptionValidator::ezOptionValidator(char _type, char _op, const char** list, int _size, bool _insensitive) : t(0), op(_op), quiet(0), type(_type), size(_size), insensitive(_insensitive) { id = ezOptionParserIDGenerator::instance().next(); t = new std::string*[size]; int i=0; @@ -577,7 +577,7 @@ _type: s1, u1, s2, u2, ..., f, d, t _op: lt, gt, ..., in _list: comma-delimited string */ -ezOptionValidator::ezOptionValidator(const char* _type, const char* _op, const char* _list, bool _insensitive) : t(0), quiet(0), type(0), size(0), insensitive(_insensitive) { +ezOptionValidator::ezOptionValidator(const char* _type, const char* _op, const char* _list, bool _insensitive) : t(0), quiet(0), type(0), size(0), insensitive(_insensitive) { id = ezOptionParserIDGenerator::instance().next(); switch(_type[0]) { @@ -932,11 +932,11 @@ bool ezOptionValidator::isValid(const std::string * valueAsString) { /* ################################################################### */ class OptionGroup { public: - OptionGroup() : delim(0), expectArgs(0), isRequired(false), isSet(false) { } + OptionGroup() : delim(0), expectArgs(0), isRequired(false), isSet(false) { } ~OptionGroup() { - int i; - for(i=0; i < (long int)flags.size(); ++i) + int i; + for(i=0; i < (long int)flags.size(); ++i) delete flags[i]; flags.clear(); @@ -988,8 +988,8 @@ public: /* ################################################################### */ void OptionGroup::clearArgs() { int i,j; - for(i=0; i < (long int)args.size(); ++i) { - for(j=0; j < (long int)args[i]->size(); ++j) + for(i=0; i < (long int)args.size(); ++i) { + for(j=0; j < (long int)args[i]->size(); ++j) delete args[i]->at(j); delete args[i]; @@ -1209,7 +1209,7 @@ void OptionGroup::getMultiInts(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToInts(args[i], &out[i]); } @@ -1228,7 +1228,7 @@ void OptionGroup::getMultiLongs(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToLongs(args[i], &out[i]); } @@ -1247,7 +1247,7 @@ void OptionGroup::getMultiULongs(std::vector< std::vector >& out) } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToULongs(args[i], &out[i]); } @@ -1266,7 +1266,7 @@ void OptionGroup::getMultiFloats(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToFloats(args[i], &out[i]); } @@ -1285,7 +1285,7 @@ void OptionGroup::getMultiDoubles(std::vector< std::vector >& out) { } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { StringsToDoubles(args[i], &out[i]); } @@ -1304,10 +1304,10 @@ void OptionGroup::getMultiStrings(std::vector< std::vector >& out) } else { if (!args.empty()) { int n = args.size(); - if ((long int)out.size() < n) out.resize(n); + if ((long int)out.size() < n) out.resize(n); for(int i=0; i < n; ++i) { - for(int j=0; j < (long int)args[i]->size(); ++j) + for(int j=0; j < (long int)args[i]->size(); ++j) out[i].push_back( *args[i]->at(j) ); } } @@ -1378,19 +1378,19 @@ void ezOptionParser::reset() { this->doublespace = 1; int i; - for(i=0; i < (long int)groups.size(); ++i) + for(i=0; i < (long int)groups.size(); ++i) delete groups[i]; groups.clear(); - for(i=0; i < (long int)unknownArgs.size(); ++i) + for(i=0; i < (long int)unknownArgs.size(); ++i) delete unknownArgs[i]; unknownArgs.clear(); - for(i=0; i < (long int)firstArgs.size(); ++i) + for(i=0; i < (long int)firstArgs.size(); ++i) delete firstArgs[i]; firstArgs.clear(); - for(i=0; i < (long int)lastArgs.size(); ++i) + for(i=0; i < (long int)lastArgs.size(); ++i) delete lastArgs[i]; lastArgs.clear(); @@ -1405,18 +1405,18 @@ void ezOptionParser::reset() { /* ################################################################### */ void ezOptionParser::resetArgs() { int i; - for(i=0; i < (long int)groups.size(); ++i) + for(i=0; i < (long int)groups.size(); ++i) groups[i]->clearArgs(); - for(i=0; i < (long int)unknownArgs.size(); ++i) + for(i=0; i < (long int)unknownArgs.size(); ++i) delete unknownArgs[i]; unknownArgs.clear(); - for(i=0; i < (long int)firstArgs.size(); ++i) + for(i=0; i < (long int)firstArgs.size(); ++i) delete firstArgs[i]; firstArgs.clear(); - for(i=0; i < (long int)lastArgs.size(); ++i) + for(i=0; i < (long int)lastArgs.size(); ++i) delete lastArgs[i]; lastArgs.clear(); }; @@ -1540,7 +1540,7 @@ bool ezOptionParser::exportFile(const char * filename, bool all) { bool quote; // Export the first args, except the program name, so start from 1. - for(i=1; i < (long int)firstArgs.size(); ++i) { + for(i=1; i < (long int)firstArgs.size(); ++i) { quote = ((firstArgs[i]->find_first_of(" \t") != std::string::npos) && (firstArgs[i]->find_first_of("\'\"") == std::string::npos)); if (quote) @@ -1557,7 +1557,7 @@ bool ezOptionParser::exportFile(const char * filename, bool all) { out.append("\n"); std::vector stringPtrs(groups.size()); - int m; + int m; int n = groups.size(); for(i=0; i < n; ++i) { stringPtrs[i] = groups[i]->flags[0]; @@ -1609,7 +1609,7 @@ bool ezOptionParser::exportFile(const char * filename, bool all) { } // Export the last args. - for(i=0; i < (long int)lastArgs.size(); ++i) { + for(i=0; i < (long int)lastArgs.size(); ++i) { quote = ( lastArgs[i]->find_first_of(" \t") != std::string::npos ); if (quote) out.append("\""); @@ -1804,18 +1804,18 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout std::map stringPtrToIndexMap; std::vector stringPtrs(groups.size()); - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { std::sort(groups[i]->flags.begin(), groups[i]->flags.end(), CmpOptStringPtr); stringPtrToIndexMap[groups[i]->flags[0]] = i; stringPtrs[i] = groups[i]->flags[0]; } - size_t j, k; + size_t j, k; std::string opts; std::vector sortedOpts; // Sort first flag of each group with other groups. std::sort(stringPtrs.begin(), stringPtrs.end(), CmpOptStringPtr); - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { //printf("DEBUG:%d: %d %d %s\n", __LINE__, i, stringPtrToIndexMap[stringPtrs[i]], stringPtrs[i]->c_str()); k = stringPtrToIndexMap[stringPtrs[i]]; opts.clear(); @@ -1823,7 +1823,7 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout opts.append(*groups[k]->flags[j]); opts.append(", "); - if ((long int)opts.size() > width) + if ((long int)opts.size() > width) opts.append("\n"); } // The last flag. No need to append comma anymore. @@ -1851,8 +1851,8 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout // Find longest opt flag string to set column start for help usage descriptions. int maxlen=0; if (layout == ALIGN) { - for(i=0; i < (long int)groups.size(); ++i) { - if (maxlen < (long int)sortedOpts[i].size()) + for(i=0; i < (long int)groups.size(); ++i) { + if (maxlen < (long int)sortedOpts[i].size()) maxlen = sortedOpts[i].size(); } } @@ -1861,7 +1861,7 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout int helpwidth; std::list::iterator cIter, insertionIter; size_t pos; - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { k = stringPtrToIndexMap[stringPtrs[i]]; if (layout == STAGGER) @@ -1876,13 +1876,13 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout for(insertionIter=desc.begin(), cIter=insertionIter++; cIter != desc.end(); cIter=insertionIter++) { - if ((long int)((*cIter)->size()) > helpwidth) { + if ((long int)((*cIter)->size()) > helpwidth) { // Get pointer to next string to insert new strings before it. std::string *rem = *cIter; // Remove this line and add back in pieces. desc.erase(cIter); // Loop until remaining string is short enough. - while ((long int)rem->size() > helpwidth) { + while ((long int)rem->size() > helpwidth) { // Find whitespace to split before helpwidth. if (rem->at(helpwidth) == ' ') { // If word ends exactly at helpwidth, then split after it. @@ -1940,7 +1940,7 @@ void ezOptionParser::getUsageDescriptions(std::string & usage, int width, Layout bool ezOptionParser::gotExpected(std::vector & badOptions) { int i,j; - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { OptionGroup *g = groups[i]; // If was set, ensure number of args is correct. if (g->isSet) { @@ -1949,8 +1949,8 @@ bool ezOptionParser::gotExpected(std::vector & badOptions) { continue; } - for(j=0; j < (long int)g->args.size(); ++j) { - if ((g->expectArgs != -1) && (g->expectArgs != (long int)g->args[j]->size())) + for(j=0; j < (long int)g->args.size(); ++j) { + if ((g->expectArgs != -1) && (g->expectArgs != (long int)g->args[j]->size())) badOptions.push_back(*g->flags[0]); } } @@ -1962,7 +1962,7 @@ bool ezOptionParser::gotExpected(std::vector & badOptions) { bool ezOptionParser::gotRequired(std::vector & badOptions) { int i; - for(i=0; i < (long int)groups.size(); ++i) { + for(i=0; i < (long int)groups.size(); ++i) { OptionGroup *g = groups[i]; // Simple case when required but user never set it. if (g->isRequired && (!g->isSet)) { @@ -1987,10 +1987,10 @@ bool ezOptionParser::gotValid(std::vector & badOptions, std::vector ezOptionValidator *v = validators[validatorid]; bool nextgroup = false; - for (int i = 0; i < (long int)g->args.size(); ++i) { + for (int i = 0; i < (long int)g->args.size(); ++i) { if (nextgroup) break; std::vector< std::string* > * args = g->args[i]; - for (int j = 0; j < (long int)args->size(); ++j) { + for (int j = 0; j < (long int)args->size(); ++j) { if (!v->isValid(args->at(j))) { badOptions.push_back(*g->flags[0]); badArgs.push_back(*args->at(j)); @@ -2013,7 +2013,7 @@ void ezOptionParser::parse(int argc, const char * argv[]) { std::cout << (*it).first << " => " << (*it).second << std::endl; */ - int i, k, firstOptIndex=0, lastOptIndex=0; + int i, k, firstOptIndex=0, lastOptIndex=0; std::string s; OptionGroup *g; @@ -2090,7 +2090,7 @@ void ezOptionParser::prettyPrint(std::string & out) { int i,j,k; out += "First Args:\n"; - for(i=0; i < (long int)firstArgs.size(); ++i) { + for(i=0; i < (long int)firstArgs.size(); ++i) { sprintf(tmp, "%d: %s\n", i+1, firstArgs[i]->c_str()); out += tmp; } @@ -2111,7 +2111,7 @@ void ezOptionParser::prettyPrint(std::string & out) { g = get(stringPtrs[i]->c_str()); out += "\n"; // The flag names: - for(j=0; j < (long int)g->flags.size()-1; ++j) { + for(j=0; j < (long int)g->flags.size()-1; ++j) { sprintf(tmp, "%s, ", g->flags[j]->c_str()); out += tmp; } @@ -2124,12 +2124,12 @@ void ezOptionParser::prettyPrint(std::string & out) { sprintf(tmp, "%s (default)\n", g->defaults.c_str()); out += tmp; } else { - for(k=0; k < (long int)g->args.size(); ++k) { - for(j=0; j < (long int)g->args[k]->size()-1; ++j) { + for(k=0; k < (long int)g->args.size(); ++k) { + for(j=0; j < (long int)g->args[k]->size()-1; ++j) { sprintf(tmp, "%s%c", g->args[k]->at(j)->c_str(), g->delim); out += tmp; } - sprintf(tmp, "%s\n", g->args[k]->back()->c_str()); + sprintf(tmp, "%s\n", g->args[k]->back()->c_str()); out += tmp; } } @@ -2144,13 +2144,13 @@ void ezOptionParser::prettyPrint(std::string & out) { } out += "\nLast Args:\n"; - for(i=0; i < (long int)lastArgs.size(); ++i) { + for(i=0; i < (long int)lastArgs.size(); ++i) { sprintf(tmp, "%d: %s\n", i+1, lastArgs[i]->c_str()); out += tmp; } out += "\nUnknown Args:\n"; - for(i=0; i < (long int)unknownArgs.size(); ++i) { + for(i=0; i < (long int)unknownArgs.size(); ++i) { sprintf(tmp, "%d: %s\n", i+1, unknownArgs[i]->c_str()); out += tmp; } diff --git a/Tools/int.h b/Tools/int.h index 78e253bc..69e7a494 100644 --- a/Tools/int.h +++ b/Tools/int.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * int.h diff --git a/Tools/mkpath.cpp b/Tools/mkpath.cpp index a09525c8..5dc13ec0 100644 --- a/Tools/mkpath.cpp +++ b/Tools/mkpath.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools/mkpath.h" #include diff --git a/Tools/mkpath.h b/Tools/mkpath.h index 4eca401e..d10be4ae 100644 --- a/Tools/mkpath.h +++ b/Tools/mkpath.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef TOOLS_MKPATH_H_ #define TOOLS_MKPATH_H_ diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index d3d83dbb..55337410 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -1,7 +1,8 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include +#include #include "octetStream.h" #include @@ -10,7 +11,6 @@ #include "Exceptions/Exceptions.h" #include "Networking/data.h" - void octetStream::assign(const octetStream& os) { if (os.len>=mxlen) @@ -45,38 +45,31 @@ octetStream::octetStream(const octetStream& os) void octetStream::hash(octetStream& output) const { - blk_SHA_CTX ctx; - blk_SHA1_Init(&ctx); - blk_SHA1_Update(&ctx,data,len); - blk_SHA1_Final(output.data,&ctx); - output.len=HASH_SIZE; + crypto_generichash(output.data, crypto_generichash_BYTES_MIN, data, len, NULL, 0); + output.len=crypto_generichash_BYTES_MIN; } octetStream octetStream::hash() const { - octetStream h(HASH_SIZE); + octetStream h(crypto_generichash_BYTES_MIN); hash(h); return h; } -bigint octetStream::check_sum() const +bigint octetStream::check_sum(int req_bytes) const { - unsigned char hash[HASH_SIZE]; - - blk_SHA_CTX ctx; - blk_SHA1_Init(&ctx); - blk_SHA1_Update(&ctx,data,len); - blk_SHA1_Final(hash,&ctx); + unsigned char hash[req_bytes]; + crypto_generichash(hash, req_bytes, data, len, NULL, 0); bigint ans; - bigintFromBytes(ans,hash,HASH_SIZE); + bigintFromBytes(ans,hash,req_bytes); + // cout << ans << "\n"; return ans; } - bool octetStream::equals(const octetStream& a) const { if (len!=a.len) { return false; } @@ -89,9 +82,7 @@ bool octetStream::equals(const octetStream& a) const void octetStream::append_random(int num) { resize(len+num); - int randomData = open("/dev/urandom", O_RDONLY); - read(randomData, data+len, num*sizeof(unsigned char)); - close(randomData); + randombytes_buf(data+len, num); len+=num; } @@ -126,6 +117,13 @@ void octetStream::store(unsigned int l) len+=4; } +void octetStream::store(int l) +{ + resize(len+4); + INT_TO_BYTES(data+len,l); + len+=4; +} + void octetStream::get(unsigned int& l) { @@ -133,6 +131,12 @@ void octetStream::get(unsigned int& l) ptr+=4; } +void octetStream::get(int& l) +{ + l=BYTES_TO_INT(data+ptr); + ptr+=4; +} + void octetStream::store(const bigint& x) { @@ -165,8 +169,84 @@ void octetStream::get(bigint& ans) } } +// Construct the ciphertext as `crypto_secretbox(pt, counter||random)` +void octetStream::encrypt_sequence(const octet* key, uint64_t counter) +{ + octet nonce[crypto_secretbox_NONCEBYTES]; + int i; + int message_len_bytes = len; + randombytes_buf(nonce, sizeof nonce); + if(counter == UINT64_MAX) { + throw Processor_Error("Encryption would overflow counter. Too many messages."); + } else { + counter++; + } + for(i=0; i<8; i++) { + nonce[i] = uint8_t ((counter >> (8*i)) & 0xFF); + } + resize(len + crypto_secretbox_MACBYTES + crypto_secretbox_NONCEBYTES); + // Encrypt data in-place + crypto_secretbox_easy(data, data, message_len_bytes, nonce, key); + // Adjust length to account for MAC, then append nonce + len += crypto_secretbox_MACBYTES; + append(nonce, sizeof nonce); +} + +void octetStream::decrypt_sequence(const octet* key, uint64_t counter) +{ + int ciphertext_len = len - crypto_box_NONCEBYTES; + const octet *nonce = data + ciphertext_len; + int i; + uint64_t recvCounter=0; + // Numbers are typically 24U + 16U so cast to int is safe. + if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES)) + { + throw Processor_Error("Cannot decrypt octetStream: ciphertext too short"); + } + for(i=7; i>=0; i--) { + recvCounter |= (uint64_t) *(nonce + i); + recvCounter = recvCounter << (i*8); + } + if(recvCounter != counter + 1) { + throw Processor_Error("Incorrect counter on stream. Possible MITM."); + } + if (crypto_secretbox_open_easy(data, data, ciphertext_len, nonce, key) != 0) + { + throw Processor_Error("octetStream decryption failed!"); + } + rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES); +} + +void octetStream::encrypt(const octet* key) +{ + octet nonce[crypto_secretbox_NONCEBYTES]; + randombytes_buf(nonce, sizeof nonce); + int message_len_bytes = len; + resize(len + crypto_secretbox_MACBYTES + crypto_secretbox_NONCEBYTES); + + // Encrypt data in-place + crypto_secretbox_easy(data, data, message_len_bytes, nonce, key); + // Adjust length to account for MAC, then append nonce + len += crypto_secretbox_MACBYTES; + append(nonce, sizeof nonce); +} + +void octetStream::decrypt(const octet* key) +{ + int ciphertext_len = len - crypto_box_NONCEBYTES; + // Numbers are typically 24U + 16U so cast to int is safe. + if (len < (int)(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES)) + { + throw Processor_Error("Cannot decrypt octetStream: ciphertext too short"); + } + if (crypto_secretbox_open_easy(data, data, ciphertext_len, data + ciphertext_len, key) != 0) + { + throw Processor_Error("octetStream decryption failed!"); + } + rewind_write_head(crypto_box_NONCEBYTES + crypto_secretbox_MACBYTES); +} ostream& operator<<(ostream& s,const octetStream& o) { diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 13487dec..18ff3934 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _octetStream #define _octetStream @@ -23,6 +23,9 @@ #include #include #include + +#include + using namespace std; @@ -50,11 +53,15 @@ class octetStream int get_length() const { return len; } octet* get_data() const { return data; } + bool done() const { return ptr == len; } + bool empty() const { return len == 0; } + int left() const { return len - ptr; } + octetStream hash() const; // output must have length at least HASH_SIZE void hash(octetStream& output) const; // The following produces a check sum for debugging purposes - bigint check_sum() const; + bigint check_sum(int req_bytes=crypto_hash_BYTES) const; void concat(const octetStream& os); @@ -83,13 +90,19 @@ class octetStream void get_bytes(octet* ans, int& l); //Assumes enough space in ans void store(unsigned int a); - void store(int a) { store((unsigned int) a); } + void store(int a); void get(unsigned int& a); - void get(int& a) { get((unsigned int&) a); } + void get(int& a); void store(const bigint& x); void get(bigint& ans); + // works for all statically allocated types + template + void serialize(const T& x) { append((octet*)&x, sizeof(x)); } + template + void unserialize(T& x) { consume((octet*)&x, sizeof(x)); } + void consume(octetStream& s,int l) { s.resize(l); consume(s.data,l); @@ -98,6 +111,20 @@ class octetStream void Send(int socket_num) const; void Receive(int socket_num); + void ReceiveExpected(int socket_num, int expected); + + // In-place authenticated encryption using sodium; key of length crypto_generichash_BYTES + // ciphertext = Enc(message) | MAC | counter + // + // This is much like 'encrypt' but uses a deterministic counter for the nonce, + // allowing enforcement of message order. + void encrypt_sequence(const octet* key, uint64_t counter); + void decrypt_sequence(const octet* key, uint64_t counter); + + // In-place authenticated encryption using sodium; key of length crypto_secretbox_KEYBYTES + // ciphertext = Enc(message) | MAC | nonce + void encrypt(const octet* key); + void decrypt(const octet* key); friend ostream& operator<<(ostream& s,const octetStream& o); friend class PRNG; @@ -157,5 +184,24 @@ inline void octetStream::Receive(int socket_num) receive(socket_num,data,len); } +inline void octetStream::ReceiveExpected(int socket_num, int expected) +{ + octet blen[4]; + receive(socket_num,blen,4); + + int nlen=decode_length(blen); + if (nlen != expected) { + cerr << "octetStream::ReceiveExpected: got " << nlen << + " length, expected " << expected << endl; + throw bad_value(); + } + + len=0; + resize(nlen); + len=nlen; + + receive(socket_num,data,len); +} + #endif diff --git a/Tools/parse.h b/Tools/parse.h new file mode 100644 index 00000000..0ced2e3f --- /dev/null +++ b/Tools/parse.h @@ -0,0 +1,49 @@ +/* + * parse.h + * + */ + +#ifndef TOOLS_PARSE_H_ +#define TOOLS_PARSE_H_ + +#include +#include +using namespace std; + +// Read a byte +inline int get_val(istream& s) +{ + char cc; + s.get(cc); + int a=cc; + if (a<0) { a+=256; } + return a; +} + +// Read a 4-byte integer +inline int get_int(istream& s) +{ + int n = 0; + for (int i=0; i<4; i++) + { n<<=8; + int t=get_val(s); + n+=t; + } + return n; +} + +// Read several integers +inline void get_ints(int* res, istream& s, int count) +{ + for (int i = 0; i < count; i++) + res[i] = get_int(s); +} + +inline void get_vector(int m, vector& start, istream& s) +{ + start.resize(m); + for (int i = 0; i < m; i++) + start[i] = get_int(s); +} + +#endif /* TOOLS_PARSE_H_ */ diff --git a/Tools/pprint.h b/Tools/pprint.h new file mode 100644 index 00000000..3df479f1 --- /dev/null +++ b/Tools/pprint.h @@ -0,0 +1,13 @@ + +#include +#include + +using namespace std; + +inline void pprint_bytes(const char *label, unsigned char *bytes, int len) +{ + cout << label << ": "; + for (int j = 0; j < len; j++) + cout << setfill('0') << setw(2) << hex << (int) bytes[j]; + cout << dec << endl; +} diff --git a/Tools/random.cpp b/Tools/random.cpp index 6d1f236a..c8953220 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -1,8 +1,9 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools/random.h" #include +#include #include using namespace std; @@ -18,9 +19,7 @@ PRNG::PRNG() : cnt(0) void PRNG::ReSeed() { - FILE* rD=fopen("/dev/urandom", "r"); - fread(seed,sizeof(octet),SEED_SIZE,rD); - fclose(rD); + randombytes_buf(seed, SEED_SIZE); InitSeed(); } diff --git a/Tools/random.h b/Tools/random.h index 0c868904..43773478 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _random #define _random diff --git a/Tools/sha1.cpp b/Tools/sha1.cpp index 2c555c33..fe272332 100644 --- a/Tools/sha1.cpp +++ b/Tools/sha1.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * SHA1 routine optimized to do word accesses rather than byte accesses, diff --git a/Tools/sha1.h b/Tools/sha1.h index 6d2cd1bd..fad2af86 100644 --- a/Tools/sha1.h +++ b/Tools/sha1.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _SHA1 #define _SHA1 diff --git a/Tools/time-func.cpp b/Tools/time-func.cpp index 6ab28446..f8dda0fe 100644 --- a/Tools/time-func.cpp +++ b/Tools/time-func.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Tools/time-func.h" diff --git a/Tools/time-func.h b/Tools/time-func.h index f3357c23..4a2667ef 100644 --- a/Tools/time-func.h +++ b/Tools/time-func.h @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #ifndef _timer #define _timer diff --git a/check-passive.cpp b/check-passive.cpp index 53b9c4e8..7f120953 100644 --- a/check-passive.cpp +++ b/check-passive.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt #include "Math/gf2n.h" #include "Math/gfp.h" diff --git a/client-setup.cpp b/client-setup.cpp new file mode 100644 index 00000000..2c36658d --- /dev/null +++ b/client-setup.cpp @@ -0,0 +1,180 @@ +// (C) 2017 University of Bristol. See License.txt + +// Preprocessing stage to: +// Create the public/private key pairs for each client +// Create the public/private key pairs for each spdz engine +// For each client store the client keys + all spdz engine public keys +// in a file named Client-Keys-C +// For each spdz engine store the spdz engine keys + all client public keys +// in a file named Player-SPDZ-Keys-P +// + +#include + +#include "Math/gf2n.h" +#include "Math/gfp.h" +#include "Math/Share.h" +#include "Math/Setup.h" +#include "Auth/fake-stuff.h" +#include "Exceptions/Exceptions.h" + +#include "Math/Setup.h" +#include "Processor/Data_Files.h" +#include "Tools/mkpath.h" +#include "Tools/ezOptionParser.h" +#include "Tools/Config.h" + +#include +#include +using namespace std; + +static void output(const vector &vec, ofstream &of) +{ + copy(vec.begin(), vec.end(), ostreambuf_iterator(of)); +} + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + + opt.syntax = "./client-setup.x [OPTIONS]\n"; + + opt.add( + "0", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of external clients (default: nplayers)", // Help description. + "-nc", // Flag token. + "--numclients" // Flag token. + ); + opt.add( + "128", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(p) field (default: 128)", // Help description. + "-lgp", // Flag token. + "--lgp" // Flag token. + ); + opt.add( + "40", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Bit length of GF(2^n) field (default: 40)", // Help description. + "-lg2", // Flag token. + "--lg2" // Flag token. + ); + opt.parse(argc, argv); + + string prep_data_prefix; + + string usage; + + int nplayers; + if (opt.firstArgs.size() == 2) + { + nplayers = atoi(opt.firstArgs[1]->c_str()); + } + else if (opt.lastArgs.size() == 1) + { + nplayers = atoi(opt.lastArgs[0]->c_str()); + } + else + { + cerr << "ERROR: invalid number of arguments\n"; + opt.getUsage(usage); + cout << usage; + return 1; + } + + int lg2, lgp, nclients; + opt.get("--numclients")->getInt(nclients); + if (nclients <= 0) + nclients = nplayers; + opt.get("--lgp")->getInt(lgp); + opt.get("--lg2")->getInt(lg2); + + cout << "nplayers = " << nplayers << endl; + cout << "nclients = " << nclients << endl; + cout << "lgp = " << lgp << endl; + cout << "lgp2 = " << lg2 << endl; + + prep_data_prefix = get_prep_dir(nplayers, lgp, lg2); + cout << "prep dir = " << prep_data_prefix << endl; + + vector client_publickeys; + vector client_secretkeys; + client_publickeys.resize(nclients); + client_secretkeys.resize(nclients); + for (int i = 0; i < nclients; i++) { + client_secretkeys[i].resize(crypto_box_SECRETKEYBYTES); + client_publickeys[i].resize(crypto_box_PUBLICKEYBYTES); + randombytes_buf(&client_secretkeys[i][0], client_secretkeys[i].size()); + crypto_scalarmult_base(&client_publickeys[i][0], &client_secretkeys[i][0]); + } + + vector client_signing_publickeys; + vector client_signing_secretkeys; + client_signing_publickeys.resize(nclients); + client_signing_secretkeys.resize(nclients); + for (int i = 0; i < nclients; i++) { + client_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES); + client_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES); + crypto_sign_keypair(&client_signing_publickeys[i][0], &client_signing_secretkeys[i][0]); + } + + vector server_publickeys; + vector server_secretkeys; + server_publickeys.resize(nplayers); + server_secretkeys.resize(nplayers); + for (int i = 0; i < nplayers; i++) { + server_publickeys[i].resize(crypto_box_PUBLICKEYBYTES); + server_secretkeys[i].resize(crypto_box_SECRETKEYBYTES); + randombytes_buf(&server_secretkeys[i][0], server_secretkeys[i].size()); + crypto_scalarmult_base(&server_publickeys[i][0], &server_secretkeys[i][0]); + } + vector server_signing_publickeys; + vector server_signing_secretkeys; + server_signing_publickeys.resize(nplayers); + server_signing_secretkeys.resize(nplayers); + for (int i = 0; i < nplayers; i++) { + server_signing_publickeys[i].resize(crypto_sign_PUBLICKEYBYTES); + server_signing_secretkeys[i].resize(crypto_sign_SECRETKEYBYTES); + crypto_sign_keypair(&server_signing_publickeys[i][0], &server_signing_secretkeys[i][0]); + } + + /* Write client files */ + for (int i = 0; i < nclients; i++) { + stringstream filename; + filename << prep_data_prefix << "Client-Keys-C" << i; + ofstream outf(filename.str().c_str()); + if (outf.fail()) + throw file_error(filename.str().c_str()); + // Write public key and secret key + output(client_publickeys[i],outf); + output(client_secretkeys[i],outf); + output(client_signing_publickeys[i],outf); + output(client_signing_secretkeys[i],outf); + int keycount = 2; + + // Write all spdz engine public keys + for (int j = 0; j < nplayers; j++) { + output(server_publickeys[j], outf); + output(server_signing_publickeys[j], outf); + keycount++; + } + outf.close(); + cout << "Wrote " << keycount << " keys to " << filename.str() << endl; + } + + /* Write spdz engine files */ + for (int i = 0; i < nplayers; i++) { + Config::write_player_config_file( prep_data_prefix, i + , server_publickeys[i], server_secretkeys[i] + , server_signing_publickeys[i], server_signing_secretkeys[i] + , client_publickeys, client_signing_publickeys + , server_publickeys, server_signing_publickeys); + } +} diff --git a/compile.py b/compile.py index 2b125872..711417af 100755 --- a/compile.py +++ b/compile.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# (C) 2016 University of Bristol. See License.txt +# (C) 2017 University of Bristol. See License.txt # ===== Compiler usage instructions ===== @@ -60,6 +60,8 @@ def main(): help="profile compilation") parser.add_option("-C", "--continous", action="store_true", dest="continuous", help="continuous computation") + parser.add_option("-s", "--stop", action="store_true", dest="stop", + help="stop on register errors") options,args = parser.parse_args() if len(args) < 1: parser.print_help() diff --git a/ot-offline.cpp b/ot-offline.cpp index 13c35633..23c4afae 100644 --- a/ot-offline.cpp +++ b/ot-offline.cpp @@ -1,4 +1,4 @@ -// (C) 2016 University of Bristol. See License.txt +// (C) 2017 University of Bristol. See License.txt /* * OT-Offline.cpp diff --git a/tutorial.md b/tutorial.md index ef87f090..5b8a1285 100644 --- a/tutorial.md +++ b/tutorial.md @@ -1,4 +1,4 @@ -(C) 2016 University of Bristol. See License.txt +(C) 2017 University of Bristol. See License.txt Suppose we want to add 2 integers mod p in clear, where p has 128 bits and compute over 2 parties inputs: P0, P1. @@ -130,6 +130,7 @@ inputs. The executables can be found after compiling SPDZ. Customizing those should be straightforward. Make sure you copy the output files to Player-Data /Private-Input-{i} files. +There is a sockets interface to provide input and output from external client processes. See the [ExternalIO directory](./ExternalIO/README.md). Other examples ==============