From cd25c2e9f192a14be2430f48e8d4e3855cb68dc0 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Wed, 9 Nov 2022 11:21:34 +1100 Subject: [PATCH] Decision tree training. --- BMR/Register.h | 3 + CHANGELOG.md | 11 + CONFIG | 4 +- Compiler/GC/instructions.py | 49 ++ Compiler/GC/types.py | 59 +- Compiler/allocator.py | 6 +- Compiler/circuit.py | 2 - Compiler/circuit_oram.py | 3 +- Compiler/compilerLib.py | 7 + Compiler/decision_tree.py | 504 ++++++++++++++++++ Compiler/floatingpoint.py | 1 + Compiler/instructions.py | 13 + Compiler/instructions_base.py | 15 +- Compiler/library.py | 6 +- Compiler/ml.py | 17 +- Compiler/mpc_math.py | 32 ++ Compiler/non_linear.py | 2 + Compiler/oram.py | 11 +- Compiler/program.py | 27 +- Compiler/sorting.py | 17 +- Compiler/sqrt_oram.py | 21 +- Compiler/types.py | 81 ++- ECDSA/Fake-ECDSA.cpp | 1 + ECDSA/P256Element.cpp | 19 +- ECDSA/P256Element.h | 6 +- ECDSA/fake-spdz-ecdsa-party.cpp | 1 + ECDSA/hm-ecdsa-party.hpp | 3 + ECDSA/ot-ecdsa-party.hpp | 1 + ExternalIO/README.md | 2 +- FHEOffline/PairwiseSetup.cpp | 8 + FHEOffline/Prover.cpp | 1 + GC/BitAdder.hpp | 3 +- GC/BitPrepFiles.h | 6 +- GC/FakeSecret.h | 7 + GC/Instruction.h | 1 + GC/Machine.h | 2 +- GC/Machine.hpp | 4 +- GC/Memory.h | 2 + GC/NoShare.h | 1 + GC/PersonalPrep.hpp | 5 +- GC/PostSacriBin.cpp | 1 + GC/Processor.h | 1 + GC/Processor.hpp | 41 +- GC/Program.h | 2 + GC/Secret.h | 3 + GC/Semi.cpp | 36 ++ GC/Semi.h | 31 ++ GC/SemiPrep.cpp | 27 +- GC/SemiPrep.h | 6 +- GC/SemiSecret.h | 5 + GC/SemiSecret.hpp | 55 ++ GC/ShareParty.h | 2 - GC/ShareParty.hpp | 2 - GC/ShareSecret.h | 1 + GC/ShareSecret.hpp | 10 +- GC/ShareThread.h | 5 +- GC/ShareThread.hpp | 49 +- GC/ThreadMaster.hpp | 1 + GC/TinierSharePrep.hpp | 2 +- GC/TinyMC.h | 2 +- GC/TinyPrep.hpp | 2 + GC/instructions.h | 1 + License.txt | 30 +- Machines/MalRep.hpp | 2 + Machines/Rep.hpp | 5 +- Machines/dealer-ring-party.cpp | 8 +- Machines/emulate.cpp | 1 + Machines/malicious-rep-bin-party.cpp | 2 + Machines/mascot-offline.cpp | 
1 + Machines/no-party.cpp | 2 + Machines/ps-rep-bin-party.cpp | 3 + Machines/real-bmr-party.cpp | 1 + Machines/replicated-bin-party.cpp | 2 + Machines/replicated-ring-party.cpp | 1 - Machines/sy-rep-field-party.cpp | 3 +- Machines/sy-rep-ring-party.cpp | 3 +- Machines/sy-shamir-party.cpp | 1 + Machines/tinier-party.cpp | 1 + Makefile | 69 ++- Math/BitVec.h | 4 + Math/Square.hpp | 11 + Math/Zp_Data.h | 4 +- Math/field_types.h | 3 +- Math/mpn_fixed.h | 14 - Networking/data.h | 2 +- OT/BaseOT.h | 2 +- OT/BitMatrix.h | 3 + OT/BitMatrix.hpp | 4 +- OT/MamaRectangle.h | 2 + OT/NPartyTripleGenerator.h | 5 +- OT/NPartyTripleGenerator.hpp | 43 +- OT/OTCorrelator.hpp | 2 +- OT/OTExtensionWithMatrix.cpp | 7 +- OT/OTExtensionWithMatrix.h | 6 +- OT/OTMultiplier.h | 2 + OT/OTMultiplier.hpp | 58 ++ Processor/BaseMachine.cpp | 16 +- Processor/BaseMachine.h | 3 +- Processor/Data_Files.hpp | 17 +- Processor/Instruction.h | 1 + Processor/Instruction.hpp | 46 +- Processor/Machine.h | 2 +- Processor/Machine.hpp | 3 +- Processor/Online-Thread.hpp | 10 +- Processor/Processor.h | 2 +- Processor/Processor.hpp | 5 +- Processor/Program.h | 2 + Processor/ThreadQueues.cpp | 26 +- Processor/instructions.h | 7 + Programs/Source/adult.mpc | 54 ++ Programs/Source/bench-dt.mpc | 32 ++ Programs/Source/benchmark_secureNN.mpc | 7 +- Programs/Source/gc_oram.mpc | 3 - Programs/Source/mnist_full_A.mpc | 1 + Programs/Source/spect.mpc | 49 ++ Programs/Source/test_gc.mpc | 2 +- Protocols/Beaver.h | 1 + Protocols/Beaver.hpp | 10 +- Protocols/DabitSacrifice.hpp | 3 +- Protocols/DealerMC.h | 2 +- Protocols/DealerMC.hpp | 4 +- Protocols/DealerPrep.hpp | 1 + Protocols/FakeProtocol.h | 47 ++ Protocols/HemiMatrixPrep.hpp | 7 +- Protocols/HemiPrep.h | 9 +- Protocols/HemiPrep.hpp | 54 ++ Protocols/HighGearKeyGen.hpp | 2 +- Protocols/LowGearKeyGen.hpp | 1 + Protocols/MAC_Check.h | 6 +- Protocols/MAC_Check.hpp | 6 +- Protocols/MAC_Check_Base.h | 2 +- Protocols/MAC_Check_Base.hpp | 2 +- Protocols/MalRepRingPrep.hpp | 3 
+- Protocols/MaliciousRepPrep.hpp | 5 + Protocols/MascotPrep.hpp | 14 - Protocols/PostSacriRepRingShare.h | 1 + Protocols/ProtocolSetup.h | 8 + Protocols/Rep3Share.h | 5 +- Protocols/Rep3Share2k.h | 4 +- Protocols/Rep3Shuffler.h | 33 ++ Protocols/Rep3Shuffler.hpp | 131 +++++ Protocols/Replicated.h | 7 + Protocols/Replicated.hpp | 9 +- Protocols/ReplicatedInput.h | 6 +- Protocols/ReplicatedInput.hpp | 3 +- Protocols/ReplicatedPrep.h | 2 + Protocols/ReplicatedPrep.hpp | 96 +++- Protocols/SecureShuffle.hpp | 2 +- Protocols/Semi.h | 15 +- Protocols/SemiInput.h | 18 +- Protocols/SemiInput.hpp | 22 +- Protocols/SemiMC.h | 6 +- Protocols/SemiMC.hpp | 29 +- Protocols/SemiPrep.h | 10 +- Protocols/SemiPrep.hpp | 31 +- Protocols/SemiPrep2k.h | 6 + .../{ReplicatedPrep2k.h => SemiRep3Prep.h} | 21 +- Protocols/Shamir.h | 3 +- Protocols/Shamir.hpp | 22 +- Protocols/ShamirInput.h | 14 +- Protocols/ShamirInput.hpp | 46 +- Protocols/ShamirMC.h | 2 +- Protocols/ShamirMC.hpp | 2 +- Protocols/ShamirShare.h | 1 + Protocols/ShuffleSacrifice.hpp | 6 +- Protocols/SpdzWiseMC.h | 2 +- README.md | 12 +- Scripts/build.sh | 3 +- Scripts/compile-for-emulation.sh | 3 + Scripts/emulate-append.sh | 7 + Scripts/run-common.sh | 11 +- Scripts/test_tutorial.sh | 1 + Scripts/tldr.sh | 5 + Tools/ExecutionStats.cpp | 5 +- Tools/names.cpp | 2 +- Utils/Check-Offline.cpp | 1 + Utils/binary-example.cpp | 1 + Utils/l2h-example.cpp | 1 + azure-pipelines.yml | 2 +- doc/Compiler.rst | 17 + doc/Doxyfile | 2 +- doc/compilation.rst | 5 + doc/index.rst | 1 + doc/io.rst | 2 + doc/machine-learning.rst | 3 + doc/non-linear.rst | 6 +- doc/troubleshooting.rst | 10 + 187 files changed, 2357 insertions(+), 329 deletions(-) create mode 100644 Compiler/decision_tree.py create mode 100644 GC/Semi.cpp create mode 100644 GC/Semi.h create mode 100644 Programs/Source/adult.mpc create mode 100644 Programs/Source/bench-dt.mpc create mode 100644 Programs/Source/spect.mpc create mode 100644 Protocols/Rep3Shuffler.h create mode 100644 
Protocols/Rep3Shuffler.hpp rename Protocols/{ReplicatedPrep2k.h => SemiRep3Prep.h} (51%) create mode 100755 Scripts/compile-for-emulation.sh create mode 100755 Scripts/emulate-append.sh diff --git a/BMR/Register.h b/BMR/Register.h index 6a15a720..4def6590 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -235,6 +235,9 @@ public: template static void ands(T& processor, const vector& args) { processor.ands(args); } template + static void andrsvec(T& processor, const vector& args) + { processor.andrsvec(args); } + template static void xors(T& processor, const vector& args) { processor.xors(args); } template static void inputb(T& processor, const vector& args) { processor.input(args); } diff --git a/CHANGELOG.md b/CHANGELOG.md index e8e01534..f201d464 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.4 (Nov 9, 2022) + +- Decision tree learning +- Optimized oblivious shuffle in Rep3 +- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC +- Optimized element-vector AND in SemiBin +- Optimized input protocol in Shamir-based protocols +- Square-root ORAM (@Quitlox) +- Improved ORAM in binary circuits +- UTF-8 outputs + ## 0.3.3 (Aug 25, 2022) - Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate diff --git a/CONFIG b/CONFIG index fb9db200..0d41c9ef 100644 --- a/CONFIG +++ b/CONFIG @@ -67,8 +67,11 @@ endif # MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5 LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) +LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib LDLIBS += -lboost_system -lssl -lcrypto +CFLAGS += -I./local/include + ifeq ($(USE_NTL),1) CFLAGS += -DUSE_NTL LDLIBS := -lntl $(LDLIBS) @@ -100,5 +103,4 @@ ifeq ($(USE_KOS),1) CFLAGS += -DUSE_KOS else CFLAGS += -std=c++17 -LDLIBS += -llibOTe -lcryptoTools endif diff --git 
a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 2b5ec46a..73a8af21 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -13,6 +13,7 @@ import Compiler.instructions as spdz import Compiler.tools as tools import collections import itertools +import math class SecretBitsAF(base.RegisterArgFormat): reg_type = 'sb' @@ -50,6 +51,7 @@ opcodes = dict( INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, + ANDRSVEC = 0x24a, XORCBI = 0x210, BITDECC = 0x211, NOTCB = 0x212, @@ -155,6 +157,52 @@ class andrs(BinaryVectorInstruction): def add_usage(self, req_node): req_node.increment(('bit', 'triple'), sum(self.args[::4])) + req_node.increment(('bit', 'mixed'), + sum(int(math.ceil(x / 64)) for x in self.args[::4])) + +class andrsvec(base.VarArgsInstruction, base.Mergeable, + base.DynFormatInstruction): + """ Constant-vector AND of secret bit registers (vectorized version). + + :param: total number of arguments to follow (int) + :param: number of arguments to follow for one operation / + operation vector size plus three (int) + :param: vector size (int) + :param: result vector (sbit) + :param: (repeat)... + :param: constant operand (sbits) + :param: vector operand + :param: (repeat)... + :param: (repeat from number of arguments to follow for one operation)... 
+ + """ + code = opcodes['ANDRSVEC'] + + def __init__(self, *args, **kwargs): + super(andrsvec, self).__init__(*args, **kwargs) + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + for x in self.args[i + 2:i + n]: + assert x.n == size + + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + n_args = (n - 3) // 2 + assert n_args > 0 + for j in range(n_args): + yield 'sbw' + for j in range(n_args + 1): + yield 'sb' + yield 'int' + + def add_usage(self, req_node): + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + req_node.increment(('bit', 'triple'), size * (n - 3) // 2) + req_node.increment(('bit', 'mixed'), size) class ands(BinaryVectorInstruction): """ Bitwise AND of secret bit register vector. @@ -605,6 +653,7 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, for i, n in cls.bases(args): yield 'int' yield 'p' + assert n > 3 for j in range(n - 3): yield 'sbw' yield 'int' diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index d092a474..e895061a 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -652,7 +652,7 @@ class sbitvec(_vec, _bit): You can access the rows by member :py:obj:`v` and the columns by calling :py:obj:`elements`. - There are three ways to create an instance: + There are four ways to create an instance: 1. By transposition:: @@ -685,6 +685,11 @@ class sbitvec(_vec, _bit): This should output:: [1, 0, 1] + + 4. 
Private input:: + + x = sbitvec.get_type(32).get_input_from(player) + """ bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) is_clear = False @@ -904,6 +909,34 @@ class sbitvec(_vec, _bit): def __mul__(self, other): if isinstance(other, int): return self.from_vec(x * other for x in self.v) + if isinstance(other, sbitvec): + if len(other.v) == 1: + other = other.v[0] + elif len(self.v) == 1: + self, other = other, self.v[0] + else: + raise CompilerError('no operand of lenght 1: %d/%d', + (len(self.v), len(other.v))) + if not isinstance(other, sbits): + return NotImplemented + ops = [] + for x in self.v: + if not util.is_zero(x): + assert x.n == other.n + ops.append(x) + if ops: + prods = [sbits.get_type(other.n)() for i in ops] + inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops) + res = [] + i = 0 + for x in self.v: + if util.is_zero(x): + res.append(0) + else: + res.append(prods[i]) + i += 1 + return sbitvec.from_vec(res) + __rmul__ = __mul__ def __add__(self, other): return self.from_vec(x + y for x, y in zip(self.v, other)) def bit_and(self, other): @@ -945,6 +978,13 @@ class sbitvec(_vec, _bit): else: res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v]) return res + def demux(self): + if len(self) == 1: + return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]]) + a = sbitvec.from_vec(self.v[:len(self) // 2]).demux() + b = sbitvec.from_vec(self.v[len(self) // 2:]).demux() + prod = [a * bb for bb in b.v] + return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod))) class bit(object): n = 1 @@ -1243,20 +1283,19 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): return other * self.v[0] elif isinstance(other, sbitfixvec): return NotImplemented - _, other_bits = self.expand(other, False) + my_bits, other_bits = self.expand(other, False) + matrix = [] m = float('inf') - for x in itertools.chain(self.v, other_bits): + for x in itertools.chain(my_bits, other_bits): try: m = min(m, x.n) except: 
pass - if m == 1: - op = operator.mul - else: - op = operator.and_ - matrix = [] for i, b in enumerate(other_bits): - matrix.append([op(x, b) for x in self.v[:len(self.v)-i]]) + if m == 1: + matrix.append([x * b for x in my_bits[:len(self.v)-i]]) + else: + matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ @@ -1366,7 +1405,7 @@ class sbitfix(_fix): cls.set_precision(f, k) return cls._new(cls.int_type(other), k, f) -class sbitfixvec(_fix): +class sbitfixvec(_fix, _vec): """ Vector of fixed-point numbers for parallel binary computation. Use :py:obj:`set_precision()` to change the precision. diff --git a/Compiler/allocator.py b/Compiler/allocator.py index bf431ca3..e5c99a7b 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -261,6 +261,7 @@ class Merger: instructions = self.instructions merge_nodes = self.open_nodes depths = self.depths + self.req_num = defaultdict(lambda: 0) if not merge_nodes: return 0 @@ -281,6 +282,7 @@ class Merger: print('Merging %d %s in round %d/%d' % \ (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) + self.req_num[t.__name__, 'round'] += 1 preorder = None @@ -530,7 +532,9 @@ class Merger: can_eliminate_defs = True for reg in inst.get_def(): for dup in reg.duplicates: - if not dup.can_eliminate: + if not (dup.can_eliminate and reduce( + operator.and_, + (x.can_eliminate for x in dup.vector), True)): can_eliminate_defs = False break # remove if instruction has result that isn't used diff --git a/Compiler/circuit.py b/Compiler/circuit.py index 9c4187f7..41e4df9e 100644 --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -137,8 +137,6 @@ def sha3_256(x): 0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7 0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067 - Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only - implemented for computation modulo a power of 
two. """ global Keccak_f diff --git a/Compiler/circuit_oram.py b/Compiler/circuit_oram.py index f5ddebfd..a2cada54 100644 --- a/Compiler/circuit_oram.py +++ b/Compiler/circuit_oram.py @@ -1,5 +1,6 @@ -from Compiler.path_oram import * +from Compiler.oram import * +from Compiler.path_oram import PathORAM, XOR from Compiler.util import bit_compose def first_diff(a_bits, b_bits): diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 4a4706ff..c304ebc4 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -125,6 +125,13 @@ class Compiler: default=defaults.binary, help="bit length of sint in binary circuit (default: 0 for arithmetic)", ) + parser.add_option( + "-G", + "--garbled-circuit", + dest="garbled", + action="store_true", + help="compile for binary circuits only (default: false)", + ) parser.add_option( "-F", "--field", diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py new file mode 100644 index 00000000..89e3fe5c --- /dev/null +++ b/Compiler/decision_tree.py @@ -0,0 +1,504 @@ +from Compiler.types import * +from Compiler.sorting import * +from Compiler.library import * +from Compiler import util, oram + +from itertools import accumulate +import math + +debug = False +debug_split = False +debug_layers = False +max_leaves = None + +def get_type(x): + if isinstance(x, (Array, SubMultiArray)): + return x.value_type + elif isinstance(x, (tuple, list)): + x = x[0] + x[-1] + if util.is_constant(x): + return cint + else: + return type(x) + else: + return type(x) + +def PrefixSum(x): + return x.get_vector().prefix_sum() + +def PrefixSumR(x): + tmp = get_type(x).Array(len(x)) + tmp.assign_vector(x) + break_point() + tmp[:] = tmp.get_reverse_vector().prefix_sum() + break_point() + return tmp.get_reverse_vector() + +def PrefixSum_inv(x): + tmp = get_type(x).Array(len(x) + 1) + tmp.assign_vector(x, base=1) + tmp[0] = 0 + return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x)) + +def PrefixSumR_inv(x): + tmp = 
get_type(x).Array(len(x) + 1) + tmp.assign_vector(x) + tmp[-1] = 0 + return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x)) + +class SortPerm: + def __init__(self, x): + B = sint.Matrix(len(x), 2) + B.set_column(0, 1 - x.get_vector()) + B.set_column(1, x.get_vector()) + self.perm = Array.create_from(dest_comp(B)) + def apply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, False) + return res + def unapply(self, x): + res = Array.create_from(x) + reveal_sort(self.perm, res, True) + return res + +def Sort(keys, *to_sort, n_bits=None, time=False): + if time: + start_timer(1) + for k in keys: + assert len(k) == len(keys[0]) + n_bits = n_bits or [None] * len(keys) + bs = Matrix.create_from( + sum([k.get_vector().bit_decompose(nb) + for k, nb in reversed(list(zip(keys, n_bits)))], [])) + res = Matrix.create_from(to_sort) + res = res.transpose() + if time: + start_timer(11) + print_ln('sort') + radix_sort_from_matrix(bs, res) + if time: + stop_timer(11) + stop_timer(1) + return res.transpose() + +def VectMax(key, *data): + def reducer(x, y): + b = x[0] > y[0] + return [b.if_else(xx, yy) for xx, yy in zip(x, y)] + if debug: + key = list(key) + data = [list(x) for x in data] + print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data)) + return util.tree_reduce(reducer, zip(key, *data))[1:] + +def GroupSum(g, x): + assert len(g) == len(x) + p = PrefixSumR(x) * g + pi = SortPerm(g.get_vector().bit_not()) + p1 = pi.apply(p) + s1 = PrefixSumR_inv(p1) + d1 = PrefixSum_inv(s1) + d = pi.unapply(d1) * g + return PrefixSum(d) + +def GroupPrefixSum(g, x): + assert len(g) == len(x) + s = get_type(x).Array(len(x) + 1) + s[0] = 0 + s.assign_vector(PrefixSum(x), base=1) + q = get_type(s).Array(len(x)) + q.assign_vector(s.get_vector(size=len(x)) * g) + return s.get_vector(size=len(x), base=1) - GroupSum(g, q) + +def GroupMax(g, keys, *x): + if debug: + print_ln('group max input g=%s keys=%s x=%s', util.reveal(g), + util.reveal(keys), 
util.reveal(x)) + assert len(keys) == len(g) + for xx in x: + assert len(xx) == len(g) + n = len(g) + m = int(math.ceil(math.log(n, 2))) + keys = Array.create_from(keys) + x = [Array.create_from(xx) for xx in x] + g_new = Array.create_from(g) + g_old = g_new.same_shape() + for d in range(m): + w = 2 ** d + g_old[:] = g_new[:] + break_point() + vsize = n - w + g_new.assign_vector(g_old.get_vector(size=vsize).bit_or( + g_old.get_vector(size=vsize, base=w)), base=w) + b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w) + for xx in [keys] + x: + a = b.if_else(xx.get_vector(size=vsize), + xx.get_vector(size=vsize, base=w)) + xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else( + xx.get_vector(size=vsize, base=w), a), base=w) + break_point() + if debug: + print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(), + util.reveal(a), util.reveal(keys), + util.reveal(x), g_new.reveal()) + t = sint.Array(len(g)) + t[-1] = 1 + t.assign_vector(g.get_vector(size=n - 1, base=1)) + if debug: + print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g), + util.reveal(t), util.reveal(keys), util.reveal(x)) + return [GroupSum(g, t[:] * xx) for xx in [keys] + x] + +def ModifiedGini(g, y, debug=False): + assert len(g) == len(y) + y = [y.get_vector().bit_not(), y] + u = [GroupPrefixSum(g, yy) for yy in y] + s = [GroupSum(g, yy) for yy in y] + w = [ss - uu for ss, uu in zip(s, u)] + us = sum(u) + ws = sum(w) + uqs = u[0] ** 2 + u[1] ** 2 + wqs = w[0] ** 2 + w[1] ** 2 + res = sfix(uqs) / us + sfix(wqs) / ws + if debug: + print_ln('u0=%s', util.reveal(u[0])) + print_ln('u0=%s', util.reveal(u[1])) + print_ln('us=%s', util.reveal(us)) + print_ln('w0=%s', util.reveal(w[0])) + print_ln('w1=%s', util.reveal(w[1])) + print_ln('ws=%s', util.reveal(ws)) + print_ln('p=%s', util.reveal(p)) + print_ln('q=%s', util.reveal(q)) + print_ln('g=%s y=%s s=%s', + util.reveal(g), util.reveal(y), + util.reveal(s)) + if debug: + print_ln('gini %s %s', str(res), 
util.reveal(res)) + return res + +MIN_VALUE = -10000 + +def FormatLayer(h, g, *a): + return CropLayer(h, *FormatLayer_without_crop(g, *a)) + +def FormatLayer_without_crop(g, *a): + for x in a: + assert len(x) == len(g) + v = [g.if_else(aa, 0) for aa in a] + v = Sort([g.bit_not()], *v, n_bits=[1]) + return v + +def CropLayer(k, *v): + if max_leaves: + n = min(2 ** k, max_leaves) + else: + n = 2 ** k + return [vv[:min(n, len(vv))] for vv in v] + +def TrainLeafNodes(h, g, y, NID): + assert len(g) == len(y) + assert len(g) == len(NID) + Label = GroupSum(g, y.bit_not()) < GroupSum(g, y) + return FormatLayer(h, g, NID, Label) + +def GroupSame(g, y): + assert len(g) == len(y) + s = GroupSum(g, [sint(1)] * len(g)) + s0 = GroupSum(g, y.bit_not()) + s1 = GroupSum(g, y) + if debug_split: + print_ln('group same g=%s', util.reveal(g)) + print_ln('group same y=%s', util.reveal(y)) + return (s == s0).bit_or(s == s1) + +def GroupFirstOne(g, b): + assert len(g) == len(b) + s = GroupPrefixSum(g, b) + return s * b == 1 + +class TreeTrainer: + """ Decision tree training by `Hamada et al.`_ + + :param x: sample data (by attribute, list or + :py:obj:`~Compiler.types.Matrix`) + :param y: binary labels (list or sint vector) + :param h: height (int) + :param binary: binary attributes instead of continuous + :param attr_lengths: attribute description for mixed data + (list of 0/1 for continuous/binary) + :param n_threads: number of threads (default: single thread) + + .. 
_`Hamada et al.`: https://arxiv.org/abs/2112.12906 + + """ + def ApplyTests(self, x, AID, Threshold): + m = len(x) + n = len(AID) + assert len(AID) == len(Threshold) + for xx in x: + assert len(xx) == len(AID) + e = sint.Matrix(m, n) + AID = Array.create_from(AID) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + e[j][:] = AID[:] == j + xx = sum(x[j] * e[j] for j in range(m)) + if debug: + print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx)) + return 2 * xx < Threshold + + def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False): + assert len(g) == len(x) + assert len(g) == len(y) + if time: + start_timer(2) + s = ModifiedGini(g, y, debug=debug) + if time: + stop_timer(2) + if debug: + print_ln('gini %s', s.reveal()) + xx = x + t = get_type(x).Array(len(x)) + t[-1] = MIN_VALUE + t.assign_vector(xx.get_vector(size=len(x) - 1) + \ + xx.get_vector(size=len(x) - 1, base=1)) + gg = g + p = sint.Array(len(x)) + p[-1] = 1 + p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or( + xx.get_vector(size=len(x) - 1) == \ + xx.get_vector(size=len(x) - 1, base=1))) + break_point() + if debug: + print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p)) + s = p[:].if_else(MIN_VALUE, s) + t = p[:].if_else(MIN_VALUE, t[:]) + if debug: + print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) + if time: + start_timer(3) + s, t = GroupMax(gg, s, t) + if time: + stop_timer(3) + if debug: + print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t)) + return t, s + + def GlobalTestSelection(self, x, y, g): + assert len(y) == len(g) + for xx in x: + assert(len(xx) == len(g)) + m = len(x) + n = len(y) + u, t = [get_type(x).Matrix(m, n) for i in range(2)] + v = get_type(y).Matrix(m, n) + s = sfix.Matrix(m, n) + @for_range_multithread(self.n_threads, 1, m) + def _(j): + single = not self.n_threads or self.n_threads == 1 + print_ln('run %s', j) + @if_e(self.attr_lengths[j]) + def _(): + u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), 
x[j], y, + n_bits=[util.log2(n), 1], time=single) + @else_ + def _(): + u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y, + n_bits=[util.log2(n), None], + time=single) + if self.debug_threading: + print_ln('global sort %s %s %s', j, util.reveal(u[j]), + util.reveal(v[j])) + t[j][:], s[j][:] = self.AttributeWiseTestSelection( + g, u[j], v[j], time=single, debug=self.debug_selection) + if self.debug_threading: + print_ln('global attribute %s %s %s', j, util.reveal(t[j]), + util.reveal(s[j])) + n = len(g) + a, tt = [sint.Array(n) for i in range(2)] + if self.debug_threading: + print_ln('global s=%s', util.reveal(s)) + if self.debug_gini: + print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)), + *(ss[0].reveal() for ss in s)) + start_timer(4) + a[:], tt[:] = VectMax((s[j][:] for j in range(m)), range(m), + (t[j][:] for j in range(m))) + stop_timer(4) + return a[:], tt[:] + + def TrainInternalNodes(self, k, x, y, g, NID): + assert len(g) == len(y) + for xx in x: + assert len(xx) == len(g) + AID, Threshold = self.GlobalTestSelection(x, y, g) + s = GroupSame(g[:], y[:]) + if debug or debug_split: + print_ln('AID=%s', util.reveal(AID)) + print_ln('Threshold=%s', util.reveal(Threshold)) + print_ln('GroupSame=%s', util.reveal(s)) + AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold) + b = self.ApplyTests(x, AID, Threshold) + return FormatLayer_without_crop(g[:], NID, AID, Threshold), b + + @method_block + def train_layer(self, k): + x = self.x + y = self.y + g = self.g + NID = self.NID + layer_matrix = self.layer_matrix + self.layer_matrix[k], b = \ + self.TrainInternalNodes(k, x, y, g, NID) + if debug: + print_ln('internal %s %s', + util.reveal(layer_matrix[k]), util.reveal(b)) + if debug_layers: + print_ln('layer %s:', k) + for name, data in zip(('NID', 'AID', 'Thr'), layer_matrix[k]): + print_ln(' %s: %s', name, data.reveal()) + NID[:] = 2 ** k * b + NID + b_not = b.bit_not() + if debug: + print_ln('b_not=%s', b_not.reveal()) + g[:] = 
GroupFirstOne(g, b_not) + GroupFirstOne(g, b) + y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1]) + for i, xxx in enumerate(xx): + x[i] = xxx + + def __init__(self, x, y, h, binary=False, attr_lengths=None, + n_threads=None): + assert not (binary and attr_lengths) + if binary: + attr_lengths = [1] * len(x) + else: + attr_lengths = attr_lengths or ([0] * len(x)) + for l in attr_lengths: + assert l in (0, 1) + self.attr_lengths = Array.create_from(regint(attr_lengths)) + Array.check_indices = False + Matrix.disable_index_checks() + for xx in x: + assert len(xx) == len(y) + n = len(y) + self.g = sint.Array(n) + self.g.assign_all(0) + self.g[0] = 1 + self.NID = sint.Array(n) + self.NID.assign_all(1) + self.y = Array.create_from(y) + self.x = Matrix.create_from(x) + self.layer_matrix = sint.Tensor([h, 3, n]) + self.n_threads = n_threads + self.debug_selection = False + self.debug_threading = False + self.debug_gini = True + + def train(self): + """ Train and return decision tree. """ + h = len(self.layer_matrix) + @for_range(h) + def _(k): + self.train_layer(k) + return self.get_tree(h) + + def train_with_testing(self, *test_set): + """ Train decision tree and test against test data. 
+ + :param y: binary labels (list or sint vector) + :param x: sample data (by attribute, list or + :py:obj:`~Compiler.types.Matrix`) + :returns: tree + + """ + for k in range(len(self.layer_matrix)): + self.train_layer(k) + tree = self.get_tree(k + 1) + output_decision_tree(tree) + test_decision_tree('train', tree, self.y, self.x, + n_threads=self.n_threads) + if test_set: + test_decision_tree('test', tree, *test_set, + n_threads=self.n_threads) + return tree + + def get_tree(self, h): + Layer = [None] * (h + 1) + for k in range(h): + Layer[k] = CropLayer(k, *self.layer_matrix[k]) + Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID) + return Layer + +def DecisionTreeTraining(x, y, h, binary=False): + return TreeTrainer(x, y, h, binary=binary).train() + +def output_decision_tree(layers): + """ Print decision tree output by :py:class:`TreeTrainer`. """ + print_ln('full model %s', util.reveal(layers)) + for i, layer in enumerate(layers[:-1]): + print_ln('level %s:', i) + for j, x in enumerate(('NID', 'AID', 'Thr')): + print_ln(' %s: %s', x, util.reveal(layer[j])) + print_ln('leaves:') + for j, x in enumerate(('NID', 'result')): + print_ln(' %s: %s', x, util.reveal(layers[-1][j])) + +def pick(bits, x): + if len(bits) == 1: + return bits[0] * x[0] + else: + try: + return x[0].dot_product(bits, x) + except: + return sum(aa * bb for aa, bb in zip(bits, x)) + +def run_decision_tree(layers, data): + """ Run decision tree against sample data. 
+ + :param layers: tree output by :py:class:`TreeTrainer` + :param data: sample data (:py:class:`~Compiler.types.Array`) + :returns: binary label + + """ + h = len(layers) - 1 + index = 1 + for k, layer in enumerate(layers[:-1]): + assert len(layer) == 3 + for x in layer: + assert len(x) <= 2 ** k + bits = layer[0].equal(index, k) + threshold = pick(bits, layer[2]) + key_index = pick(bits, layer[1]) + if key_index.is_clear: + key = data[key_index] + else: + key = pick( + oram.demux(key_index.bit_decompose(util.log2(len(data)))), data) + child = 2 * key < threshold + index += child * 2 ** k + bits = layers[h][0].equal(index, h) + return pick(bits, layers[h][1]) + +def test_decision_tree(name, layers, y, x, n_threads=None): + start_timer(100) + n = len(y) + x = x.transpose().reveal() + y = y.reveal() + guess = regint.Array(n) + truth = regint.Array(n) + correct = regint.Array(2) + parts = regint.Array(2) + layers = [Matrix.create_from(util.reveal(layer)) for layer in layers] + @for_range_multithread(n_threads, 1, n) + def _(i): + guess[i] = run_decision_tree([[part[:] for part in layer] + for layer in layers], x[i]).reveal() + truth[i] = y[i].reveal() + @for_range(n) + def _(i): + parts[truth[i]] += 1 + c = (guess[i].bit_xor(truth[i]).bit_not()) + correct[truth[i]] += c + print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1, + sum(correct), n, correct[0], parts[0], correct[1], parts[1]) + stop_timer(100) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 94a47f1b..7786f73c 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -311,6 +311,7 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): @instructions_base.ret_cisc def Pow2(a, l, kappa): + comparison.program.curr_tape.require_bit_length(l - 1) m = int(ceil(log(l, 2))) t = BitDec(a, m, m, kappa) return Pow2_from_bits(t) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index a8894a0d..c5131832 100644 --- a/Compiler/instructions.py 
+++ b/Compiler/instructions.py @@ -614,6 +614,18 @@ class submr(base.SubBase): code = base.opcodes['SUBMR'] arg_format = ['sw','c','s'] +@base.vectorize +class prefixsums(base.Instruction): + """ Prefix sum. + + :param: result (sint) + :param: input (sint) + + """ + __slots__ = [] + code = base.opcodes['PREFIXSUMS'] + arg_format = ['sw','s'] + @base.gf2n @base.vectorize class mulc(base.MulBase): @@ -2301,6 +2313,7 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction, yield 'int' for i, n in self.bases(args): yield 's' + field + 'w' + assert n > 2 for j in range(n - 2): yield 's' + field yield 'int' diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 7a47c46c..f811e47c 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -80,6 +80,7 @@ opcodes = dict( SUBSI = 0x2A, SUBCFI = 0x2B, SUBSFI = 0x2C, + PREFIXSUMS = 0x2D, # Multiplication/division MULC = 0x30, MULM = 0x31, @@ -702,10 +703,16 @@ class ClearIntAF(RegisterArgFormat): reg_type = RegType.ClearInt class IntArgFormat(ArgFormat): + n_bits = 32 + @classmethod def check(cls, arg): - if not isinstance(arg, int) and not arg is None: - raise ArgumentError(arg, 'Expected an integer-valued argument') + if not arg is None: + if not isinstance(arg, int): + raise ArgumentError(arg, 'Expected an integer-valued argument') + if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits: + raise ArgumentError( + arg, 'Immediate value outside of %d-bit range' % cls.n_bits) @classmethod def encode(cls, arg): @@ -718,6 +725,8 @@ class IntArgFormat(ArgFormat): return str(self.i) class LongArgFormat(IntArgFormat): + n_bits = 64 + @classmethod def encode(cls, arg): return list(struct.pack('>Q', arg)) @@ -729,8 +738,6 @@ class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): super(ImmediateModpAF, cls).check(arg) - if arg >= 2**32 or arg < -2**32: - raise ArgumentError(arg, 'Immediate value outside of 32-bit range') class ImmediateGF2NAF(IntArgFormat): 
@classmethod diff --git a/Compiler/library.py b/Compiler/library.py index 42a5826d..1f1fd88c 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -139,7 +139,7 @@ def print_str_if(cond, ss, *args): """ Print string conditionally. See :py:func:`print_ln_if` for details. """ if util.is_constant(cond): if cond: - print_ln(ss, *args) + print_str(ss, *args) else: subs = ss.split('%s') assert len(subs) == len(args) + 1 @@ -1021,9 +1021,11 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], def f(i): state = tuplify(initializer()) start_block = get_block() + j = i * n_parallel + one = regint(1) for k in range(n_parallel): - j = i * n_parallel + k state = reducer(tuplify(loop_body(j)), state) + j += one if n_parallel > 1 and start_block != get_block(): print('WARNING: parallelization broken ' 'by control flow instruction') diff --git a/Compiler/ml.py b/Compiler/ml.py index 173c2eac..bc93933d 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -73,8 +73,13 @@ from functools import reduce def log_e(x): return mpc_math.log_fx(x, math.e) +use_mux = False + def exp(x): - return mpc_math.pow_fx(math.e, x) + if use_mux: + return mpc_math.mux_exp(math.e, x) + else: + return mpc_math.pow_fx(math.e, x) def get_limit(x): exp_limit = 2 ** (x.k - x.f - 1) @@ -164,13 +169,16 @@ def softmax(x): return softmax_from_exp(exp_for_softmax(x)[0]) def exp_for_softmax(x): - m = util.max(x) - get_limit(x[0]) + 1 + math.log(len(x), 2) + m = util.max(x) - get_limit(x[0]) + math.log(len(x)) mv = m.expand_to_vector(len(x)) try: x = x.get_vector() except AttributeError: x = sfix(x) - return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m + if use_mux: + return exp(x - mv), m + else: + return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m def softmax_from_exp(x): return x / sum(x) @@ -2002,6 +2010,9 @@ class Optimizer: return res def __init__(self, report_loss=None): + if get_program().options.binary: + raise CompilerError( + 'machine learning code not compatible 
with binary circuits') self.tol = 0.000 self.report_loss = report_loss self.X_by_label = None diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index a16214a8..8b5836bc 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -8,6 +8,8 @@ This has to imported explicitly. import math +import operator +from functools import reduce from Compiler import floatingpoint from Compiler import types from Compiler import comparison @@ -398,6 +400,36 @@ def exp2_fx(a, zero_output=False, as19=False): return s.if_else(1 / g, g) +def mux_exp(x, y, block_size=8): + assert util.is_constant_float(x) + from Compiler.GC.types import sbitvec, sbits + bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v + sign = bits[-1] + m = math.log(2 ** (y.k - y.f - 1), x) + del bits[int(math.ceil(math.log(m, 2))) + y.f:] + parts = [] + for i in range(0, len(bits), block_size): + one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v + exp = [] + try: + for j in range(len(one_hot)): + exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f)) + except OverflowError: + pass + exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp)) + bin_part = [0] * max(x.bit_length() for x in exp) + for j in range(len(bin_part)): + for k, (a, b) in enumerate(zip(one_hot, exp)): + bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \ + else 0 + if util.is_zero(bin_part[j]): + bin_part[j] = sbits.get_type(y.size)(0) + if i == 0: + bin_part[j] = sign.if_else(0, bin_part[j]) + parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part)))) + return util.tree_reduce(operator.mul, parts) + + @types.vectorize @instructions_base.sfix_cisc def log2_fx(x, use_division=True): diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 01cb4db5..66e82908 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -32,6 +32,8 @@ class NonLinear: return shift_two(a, m) prog = program.Program.prog if prog.use_trunc_pr: + if not prog.options.ring: + 
prog.curr_tape.require_bit_length(k + prog.security) if signed and prog.use_trunc_pr != -1: a += (1 << (k - 1)) res = sint() diff --git a/Compiler/oram.py b/Compiler/oram.py index ebc5b8a0..bbaa3938 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -1034,8 +1034,9 @@ def get_n_threads_for_tree(size): class TreeORAM(AbstractORAM): """ Tree ORAM. """ - def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ + def __init__(self, size, value_type=None, value_length=1, entry_size=None, \ bucket_oram=TrivialORAM, init_rounds=-1): + value_type = value_type or sint print('create oram of size', size) self.bucket_oram = bucket_oram # heuristic bucket size @@ -1722,6 +1723,8 @@ def OptimalORAM(size,*args,**kwargs): :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / :py:class:`sfix` """ + if not util.is_constant(size): + raise CompilerError('ORAM size has be a compile-time constant') if get_program().options.binary: return BinaryORAM(size, *args, **kwargs) if optimal_threshold is None: @@ -1772,6 +1775,12 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty): def test_oram(oram_type, N, value_type=sint, iterations=100): stop_grind() oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0) + test_oram_initialized(oram, iterations) + return oram + +def test_oram_initialized(oram, iterations=100): + N = oram.size + value_type = oram.value_type value_type = value_type.get_type(32) index_type = value_type.get_type(log2(N)) start_grind() diff --git a/Compiler/program.py b/Compiler/program.py index f92ab497..7431d600 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -29,6 +29,7 @@ data_types = dict( bit=2, inverse=3, dabit=4, + mixed=5, ) field_types = dict( @@ -45,6 +46,7 @@ class defaults: ring = 0 field = 0 binary = 0 + garbled = False prime = None galois = 40 budget = 100000 @@ -150,10 +152,11 @@ class Program(object): gc.ldmsd, gc.stmsd, gc.stmsdci, - gc.xors, gc.andrs, gc.ands, gc.inputb, + gc.inputbvec, 
+ gc.reveal, ] self.use_trunc_pr = False """ Setting whether to use special probabilistic truncation. """ @@ -350,7 +353,8 @@ class Program(object): print("Writing to", sch_filename) sch_file.write(str(self.max_par_tapes()) + "\n") sch_file.write(str(len(nonempty_tapes)) + "\n") - sch_file.write(" ".join(tape.name for tape in nonempty_tapes) + "\n") + sch_file.write(" ".join("%s:%d" % (tape.name, len(tape)) + for tape in nonempty_tapes) + "\n") sch_file.write("1 0\n") sch_file.write("0\n") sch_file.write(" ".join(sys.argv) + "\n") @@ -506,7 +510,8 @@ class Program(object): self.set_security(security) def optimize_for_gc(self): - pass + import Compiler.GC.instructions as gc + self.to_merge += [gc.xors] def get_tape_counter(self): res = self.tape_counter @@ -686,6 +691,7 @@ class Tape: self.purged = False self.n_rounds = 0 self.n_to_merge = 0 + self.rounds = Tape.ReqNum() self.warn_about_mem = parent.program.warn_about_mem[-1] def __len__(self): @@ -750,6 +756,7 @@ class Tape: inst.add_usage(req_node) req_node.num["all", "round"] += self.n_rounds req_node.num["all", "inv"] += self.n_to_merge + req_node.num += self.rounds def expand_cisc(self): new_instructions = [] @@ -796,7 +803,14 @@ class Tape: self.name = name self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc" + def __len__(self): + if self.purged: + return self.size + else: + return sum(len(block) for block in self.basicblocks) + def purge(self): + self.size = len(self) for block in self.basicblocks: block.purge() self._is_empty = len(self.basicblocks) == 0 @@ -865,6 +879,8 @@ class Tape: numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) + if options.verbose: + block.rounds = merger.req_num if merger.counter and self.program.verbose: print( "Block requires", @@ -1113,7 +1129,8 @@ class Tape: __rmul__ = __mul__ def set_all(self, value): - if value == float("inf") and self["all", "inv"] > 0: + if Program.prog.options.verbose and \ 
+ value == float("inf") and self["all", "inv"] > 0: print("Going to unknown from %s" % self) res = Tape.ReqNum() for i in self: @@ -1142,6 +1159,8 @@ class Tape: res = [] for req, num in self.items(): domain = t(req[0]) + if num < 0: + num = float('inf') n = "%12.0f" % num if req[1] == "input": res += ["%s %s inputs from player %d" % (n, domain, req[2])] diff --git a/Compiler/sorting.py b/Compiler/sorting.py index 248b3ea0..fc619b73 100644 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -3,12 +3,7 @@ from Compiler import types, library, instructions def dest_comp(B): Bt = B.transpose() - Bt_flat = Bt.get_vector() - St_flat = Bt.value_type.Array(len(Bt_flat)) - St_flat.assign(Bt_flat) - @library.for_range(len(St_flat) - 1) - def _(i): - St_flat[i + 1] = St_flat[i + 1] + St_flat[i] + St_flat = Bt.get_vector().prefix_sum() Tt_flat = Bt.get_vector() * St_flat.get_vector() Tt = types.Matrix(*Bt.sizes, B.value_type) Tt.assign_vector(Tt_flat) @@ -37,8 +32,14 @@ def radix_sort(k, D, n_bits=None, signed=True): bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) if signed and len(bs) > 1: bs[-1][:] = bs[-1][:].bit_not() - B = types.sint.Matrix(len(k), 2) - h = types.Array.create_from(types.sint(types.regint.inc(len(k)))) + radix_sort_from_matrix(bs, D) + +def radix_sort_from_matrix(bs, D): + n = len(D) + for b in bs: + assert(len(b) == n) + B = types.sint.Matrix(n, 2) + h = types.Array.create_from(types.sint(types.regint.inc(n))) @library.for_range(len(bs)) def _(i): b = bs[i] diff --git a/Compiler/sqrt_oram.py b/Compiler/sqrt_oram.py index 1024ab88..741baaf7 100644 --- a/Compiler/sqrt_oram.py +++ b/Compiler/sqrt_oram.py @@ -10,9 +10,7 @@ from Compiler.GC.types import cbit, sbit, sbitint, sbits from Compiler.program import Program from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint, regint, sint, sintbit) -from oram import demux_array, get_n_threads - -program = Program.prog +from Compiler.oram import demux_array, 
get_n_threads # Adds messages on completion of heavy computation steps debug = False @@ -44,6 +42,13 @@ B = TypeVar("B", sintbit, sbit) class SqrtOram(Generic[T, B]): + """Oblivious RAM using the "Square-Root" algorithm. + + :param MultiArray data: The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array). + :param sint value_type: The secret type to use, defaults to sint. + :param int k: Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. + :param int period: Leave at None, this parameter is used to recursively pass down the top-level period. + """ # TODO: Preferably this is an Array of vectors, but this is currently not supported # One should regard these structures as Arrays where an entry may hold more # than one value (which is a nice property to have when using the ORAM in @@ -69,14 +74,6 @@ class SqrtOram(Generic[T, B]): t: cint def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True, empty_data=False) -> None: - """Initialize a new Oblivious RAM using the "Square-Root" algorithm. - - Args: - data (MultiArray): The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array). - value_type (sint): The secret type to use, defaults to sint. - k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM. - period (int): Leave at None, this parameter is used to recursively pass down the top-level period. 
- """ global debug, allow_memory_allocation # Correctly initialize the shuffle (memory) depending on the type of data @@ -103,6 +100,7 @@ class SqrtOram(Generic[T, B]): self.index_size = util.log2(self.n) + 1 # +1 because signed self.index_type = value_type.get_type(self.index_size) self.entry_length = entry_length + self.size = self.n if debug: lib.print_ln( @@ -632,6 +630,7 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]): # The item at logical_address # will be in block with index h (block.) # at position l in block.data (block.data) + program = Program.prog h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)( logical_address).right_shift(pack_log, program.bit_length))) l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1)) diff --git a/Compiler/types.py b/Compiler/types.py index 5e4893e3..3366e2f4 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -749,7 +749,14 @@ class _register(Tape.Register, _number, _structure): self.mov(res[i], self) return res -class _clear(_register): +class _arithmetic_register(_register): + """ Arithmetic circuit type. """ + def __init__(self, *args, **kwargs): + if program.options.garbled: + raise CompilerError('functionality only available in arithmetic circuits') + super(_arithmetic_register, self).__init__(*args, **kwargs) + +class _clear(_arithmetic_register): """ Clear domain-dependent type. """ __slots__ = [] mov = staticmethod(movc) @@ -1085,6 +1092,8 @@ class cint(_clear, _int): def __ne__(self, other): return 1 - (self == other) + equal = lambda self, other, *args, **kwargs: self.__eq__(other) + def __lshift__(self, other): """ Clear left shift. 
@@ -1836,7 +1845,7 @@ class longint: res += x.bit_decompose(64) return res[:bit_length] -class _secret(_register, _secret_structure): +class _secret(_arithmetic_register, _secret_structure): __slots__ = [] mov = staticmethod(set_instruction_type(movs)) @@ -2682,6 +2691,15 @@ class sint(_secret, _int): comparison.Trunc(res, tmp, 2 * k, k, kappa, True) return res + @vectorize + def int_mod(self, other, bit_length=None): + """ Secret integer modulo. + + :param other: sint + :param bit_length: bit length of input (default: global bit length) + """ + return self - other * self.int_div(other, bit_length=bit_length) + def trunc_zeros(self, n_zeros, bit_length=None, signed=True): bit_length = bit_length or program.bit_length return comparison.TruncZeros(self, bit_length, n_zeros, signed) @@ -2808,6 +2826,13 @@ class sint(_secret, _int): res = res.get_vector() return res + @vectorize + def prefix_sum(self): + """ Prefix sum. """ + res = sint() + prefixsums(res, self) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). 
""" @@ -3940,6 +3965,8 @@ class _single(_number, _secret_structure): :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) + :returns: list of length ``n`` + """ sint_inputs = cls.int_type.receive_from_client(n, client_id, message_type) @@ -3977,6 +4004,8 @@ class _single(_number, _secret_structure): def conv(cls, other): if isinstance(other, cls): return other + elif isinstance(other, (list, tuple)): + return type(other)(cls.conv(x) for x in other) else: try: return cls.from_sint(other) @@ -4216,7 +4245,7 @@ class _fix(_single): if isinstance(other, _fix) and (cls.k, cls.f) == (other.k, other.f): return other else: - return cls(other) + return super(_fix, cls).conv(other) @classmethod def _new(cls, other, k=None, f=None): @@ -4524,6 +4553,9 @@ class sfix(_fix): return self._new(self.v.secure_permute(*args, **kwargs), k=self.k, f=self.f) + def prefix_sum(self): + return self._new(self.v.prefix_sum(), k=self.k, f=self.f) + class unreduced_sfix(_single): int_type = sint @@ -5271,6 +5303,8 @@ class Array(_vectorizable): a[:] += b[:] """ + check_indices = True + @classmethod def create_from(cls, l): """ Convert Python iterator or vector to array. 
Basic type will be taken @@ -5283,7 +5317,9 @@ class Array(_vectorizable): """ if isinstance(l, cls): - return l + res = l.same_shape() + res[:] = l[:] + return res if isinstance(l, _number): tmp = l t = type(l) @@ -5304,7 +5340,6 @@ class Array(_vectorizable): self.debug = debug self.creator_tape = program.curr_tape self.sink = None - self.check_indices = True if alloc: self.alloc() @@ -5435,7 +5470,10 @@ class Array(_vectorizable): return self.value_type.load_mem(address) def _store(self, value, address): - self.value_type.conv(value).store_in_mem(address) + tmp = self.value_type.conv(value) + if not isinstance(tmp, _vec) and tmp.size != self.value_type.mem_size(): + raise CompilerError('size mismatch in array assignment') + tmp.store_in_mem(address) def __len__(self): return self.length @@ -5506,6 +5544,12 @@ class Array(_vectorizable): get_part_vector = get_vector + def get_reverse_vector(self): + """ Return vector with content in reverse order. """ + size = self.length + address = regint.inc(size, size - 1, -1) + return self.value_type.load_mem(self.address + address, size=size) + def get_part(self, base, size): """ Part array. @@ -5605,7 +5649,6 @@ class Array(_vectorizable): """ Vector subtraction. :param other: vector or container of same length and type that supports operations with type of this array """ - assert len(self) == len(other) return self.get_vector() - other def __mul__(self, value): @@ -5668,7 +5711,7 @@ class Array(_vectorizable): """ Reveal the whole array. :returns: Array of relevant clear type. """ - return Array.create_from(x.reveal() for x in self) + return Array.create_from(self.get_vector().reveal()) def reveal_list(self): """ Reveal as list. 
""" @@ -6367,13 +6410,15 @@ class SubMultiArray(_vectorizable): res = Matrix(self.sizes[1], self.sizes[0], self.value_type) library.break_point() if self.value_type.n_elements() == 1: - @library.for_range_opt(self.sizes[0]) - def _(j): - res.set_column(j, self[j][:]) + nr = self.sizes[1] + nc = self.sizes[0] + a = regint.inc(nr * nc, 0, nr, 1, nc) + b = regint.inc(nr * nc, 0, 1, nc) + res[:] = self.value_type.load_mem(self.address + a + b) else: - @library.for_range_opt(self.sizes[1]) + @library.for_range_opt(self.sizes[1], budget=100) def _(i): - @library.for_range_opt(self.sizes[0]) + @library.for_range_opt(self.sizes[0], budget=100) def _(j): res[i][j] = self[j][i] library.break_point() @@ -6424,7 +6469,7 @@ class SubMultiArray(_vectorizable): def randomize(self, *args): """ Randomize according to data type. """ - if self.total_size() < program.options.budget: + if self.total_size() < program.budget: self.assign_vector( self.value_type.get_random(*args, size=self.total_size())) else: @@ -6432,6 +6477,12 @@ class SubMultiArray(_vectorizable): def _(i): self[i].randomize(*args) + def reveal(self): + """ Reveal to :py:obj:`MultiArray` of same shape. """ + res = MultiArray(self.sizes, self.value_type.clear_type) + res[:] = self.get_vector().reveal() + return res + def reveal_list(self): """ Reveal as list. 
""" return list(self.get_vector().reveal()) @@ -6542,7 +6593,7 @@ class Matrix(MultiArray): @staticmethod def create_from(rows): rows = list(rows) - if isinstance(rows[0], (list, tuple)): + if isinstance(rows[0], (list, tuple, Array)): t = type(rows[0][0]) else: t = type(rows[0]) diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index 23f81b9e..ecf7011b 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -22,4 +22,5 @@ int main() generate_mac_keys>(key, 2, prefix); make_mult_triples>(key, 2, 1000, false, prefix); make_inverse>(key, 2, 1000, false, prefix); + P256Element::finish(); } diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp index 2c8c776d..1ff3273f 100644 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -14,7 +14,14 @@ void P256Element::init() curve = EC_GROUP_new_by_curve_name(NID_secp256k1); assert(curve != 0); auto modulus = EC_GROUP_get0_order(curve); - Scalar::init_field(BN_bn2dec(modulus), false); + auto mod = BN_bn2dec(modulus); + Scalar::init_field(mod, false); + free(mod); +} + +void P256Element::finish() +{ + EC_GROUP_free(curve); } P256Element::P256Element() @@ -42,6 +49,11 @@ P256Element::P256Element(word other) : BN_free(exp); } +P256Element::~P256Element() +{ + EC_POINT_free(point); +} + P256Element& P256Element::operator =(const P256Element& other) { assert(EC_POINT_copy(point, other.point) != 0); @@ -99,7 +111,7 @@ bool P256Element::operator ==(const P256Element& other) const return not cmp; } -void P256Element::pack(octetStream& os) const +void P256Element::pack(octetStream& os, int) const { octet* buffer; size_t length = EC_POINT_point2buf(curve, point, @@ -107,9 +119,10 @@ void P256Element::pack(octetStream& os) const assert(length != 0); os.store_int(length, 8); os.append(buffer, length); + free(buffer); } -void P256Element::unpack(octetStream& os) +void P256Element::unpack(octetStream& os, int) { size_t length = os.get_int(8); assert( diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h index 
27ea7f75..bd005c84 100644 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -32,11 +32,13 @@ public: static string type_string() { return "P256"; } static void init(); + static void finish(); P256Element(); P256Element(const P256Element& other); P256Element(const Scalar& other); P256Element(word other); + ~P256Element(); P256Element& operator=(const P256Element& other); @@ -58,8 +60,8 @@ public: bool is_zero() { return *this == P256Element(); } void add(octetStream& os) { *this += os.get(); } - void pack(octetStream& os) const; - void unpack(octetStream& os); + void pack(octetStream& os, int = -1) const; + void unpack(octetStream& os, int = -1); octetStream hash(size_t n_bytes) const; diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index 5bef730d..ea19c8ee 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -64,4 +64,5 @@ int main(int argc, const char** argv) pShare::MAC_Check::teardown(); Share::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index fc19e989..07520f33 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -30,6 +30,8 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "Machines/ShamirMachine.hpp" +#include "Machines/MalRep.hpp" +#include "Machines/Rep.hpp" #include @@ -69,4 +71,5 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); // check(tuples, sk, {}, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 
0 : &proc); + P256Element::finish(); } diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index ebf0aea9..550c0ac8 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -140,4 +140,5 @@ void run(int argc, const char** argv) pShare::MAC_Check::teardown(); T::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ExternalIO/README.md b/ExternalIO/README.md index f5f418ed..89328440 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -15,7 +15,7 @@ make bankers-bonus-client.x ./compile.py bankers_bonus 1 Scripts/setup-ssl.sh Scripts/setup-clients.sh 3 -Scripts/.sh bankers_bonus-1 & +PLAYERS= Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 0 100 0 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 01971182..bc890ed2 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -116,6 +116,14 @@ void secure_init(T& setup, Player& P, U& machine, ofstream file(filename); os.output(file); } + + if (OnlineOptions::singleton.verbose) + { + cerr << "Ciphertext length: " << params.p0().numBits(); + for (size_t i = 1; i < params.FFTD().size(); i++) + cerr << "+" << params.FFTD()[i].get_prime().numBits(); + cerr << endl; + } } template diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp index d92f3080..7127b8c7 100644 --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -128,6 +128,7 @@ size_t Prover::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl bool ok=false; int cnt=0; + (void) cnt; while (!ok) { cnt++; Stage_1(P,ciphertexts,c,pk); diff --git a/GC/BitAdder.hpp b/GC/BitAdder.hpp index 437af179..a3f821a0 100644 --- a/GC/BitAdder.hpp +++ b/GC/BitAdder.hpp @@ -44,7 +44,8 @@ void BitAdder::add(vector>& res, const vector>>& summ &supplies); BitAdder().add(res, summands, start, summands[0][0].size(), proc, T::default_length); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); 
} else add(res, summands, 0, res.size(), proc, length); diff --git a/GC/BitPrepFiles.h b/GC/BitPrepFiles.h index 0a406a46..e8c4d0cf 100644 --- a/GC/BitPrepFiles.h +++ b/GC/BitPrepFiles.h @@ -6,12 +6,12 @@ #ifndef GC_BITPREPFILES_H_ #define GC_BITPREPFILES_H_ -namespace GC -{ - #include "ShiftableTripleBuffer.h" #include "Processor/Data_Files.h" +namespace GC +{ + template class BitPrepFiles : public ShiftableTripleBuffer, public Sub_Data_Files { diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index ee7a8446..cd43ae1d 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -11,11 +11,13 @@ #include "GC/Access.h" #include "GC/ArgTuples.h" #include "GC/NoShare.h" +#include "GC/Processor.h" #include "Math/gf2nlong.h" #include "Tools/SwitchableOutput.h" #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "Protocols/FakePrep.h" #include "Protocols/FakeMC.h" #include "Protocols/FakeProtocol.h" @@ -85,6 +87,11 @@ public: { processor.andrs(args); } static void ands(GC::Processor& processor, const vector& regs); template + static void andrsvec(T&, const vector&) + { throw runtime_error("andrsvec not implemented"); } + static void andm(GC::Processor& processor, const ::Instruction& instruction) + { processor.andm(instruction); } + template static void xors(GC::Processor& processor, const vector& regs) { processor.xors(regs); } template diff --git a/GC/Instruction.h b/GC/Instruction.h index e990f954..ab6f3f47 100644 --- a/GC/Instruction.h +++ b/GC/Instruction.h @@ -64,6 +64,7 @@ enum INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, + ANDRSVEC = 0x24a, // write to clear CLEAR_WRITE = 0x210, XORCBI = 0x210, diff --git a/GC/Machine.h b/GC/Machine.h index ecf352cc..991f0014 100644 --- a/GC/Machine.h +++ b/GC/Machine.h @@ -47,7 +47,7 @@ public: ~Machine(); void load_schedule(const string& progname); - void load_program(const string& threadname, const string& filename); + size_t load_program(const string& threadname, const string& filename); template 
void reset(const U& program); diff --git a/GC/Machine.hpp b/GC/Machine.hpp index 8cfe08f2..8b555f6c 100644 --- a/GC/Machine.hpp +++ b/GC/Machine.hpp @@ -35,12 +35,14 @@ Machine::~Machine() } template -void Machine::load_program(const string& threadname, const string& filename) +size_t Machine::load_program(const string& threadname, + const string& filename) { (void)threadname; progs.push_back({}); progs.back().parse_file(filename); reset(progs.back()); + return progs.back().size(); } template diff --git a/GC/Memory.h b/GC/Memory.h index 006a91d9..aa02d563 100644 --- a/GC/Memory.h +++ b/GC/Memory.h @@ -18,6 +18,8 @@ using namespace std; class NoMemory { +public: + void resize_min(size_t, const char*) {} }; namespace GC diff --git a/GC/NoShare.h b/GC/NoShare.h index 917e71c5..ec2c85ac 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -154,6 +154,7 @@ public: static void xors(Processor&, const vector&) { fail(); } static void ands(Processor&, const vector&) { fail(); } static void andrs(Processor&, const vector&) { fail(); } + static void andrsvec(Processor&, const vector&) { fail(); } static void trans(Processor&, Integer, const vector&) { fail(); } diff --git a/GC/PersonalPrep.hpp b/GC/PersonalPrep.hpp index df172585..44c4080e 100644 --- a/GC/PersonalPrep.hpp +++ b/GC/PersonalPrep.hpp @@ -8,6 +8,8 @@ #include "PersonalPrep.h" +#include "Protocols/ShuffleSacrifice.hpp" + namespace GC { @@ -36,7 +38,8 @@ void PersonalPrep::buffer_personal_triples(size_t batch_size, ThreadQueues* q PersonalTripleJob job(&triples, input_player); int start = queues->distribute(job, batch_size); buffer_personal_triples(triples, start, batch_size); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else buffer_personal_triples(triples, 0, batch_size); diff --git a/GC/PostSacriBin.cpp b/GC/PostSacriBin.cpp index 74248060..aff82818 100644 --- a/GC/PostSacriBin.cpp +++ b/GC/PostSacriBin.cpp @@ -10,6 +10,7 @@ #include "Protocols/Replicated.hpp" #include "Protocols/MaliciousRepMC.hpp" 
#include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/ReplicatedPrep.hpp" #include "ShareSecret.hpp" namespace GC diff --git a/GC/Processor.h b/GC/Processor.h index a5acb950..e21cf600 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -91,6 +91,7 @@ public: void and_(const vector& args, bool repeat); void andrs(const vector& args) { and_(args, true); } void ands(const vector& args) { and_(args, false); } + void andrsvec(const vector& args); void input(const vector& args); void inputb(typename T::Input& input, ProcessorBase& input_processor, diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 96b2d62d..87296edf 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -15,6 +15,7 @@ using namespace std; #include "GC/Program.h" #include "Access.h" #include "Processor/FixInput.h" +#include "Math/BitVec.h" #include "GC/Machine.hpp" #include "Processor/ProcessorBase.hpp" @@ -205,9 +206,13 @@ template void Processor::mem_op(int n, Memory& dest, const Memory& source, Integer dest_address, Integer source_address) { + dest.check_index(dest_address + n - 1); + source.check_index(source_address + n - 1); + auto d = &dest[dest_address]; + auto s = &source[source_address]; for (int i = 0; i < n; i++) { - dest[dest_address + i] = source[source_address + i]; + *d++ = *s++; } } @@ -302,6 +307,40 @@ void Processor::and_(const vector& args, bool repeat) } } +template +void Processor::andrsvec(const vector& args) +{ + int N_BITS = T::default_length; + auto it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + int base = *(it + n_args); + assert(n_args <= N_BITS); + for (int i = 0; i < size; i += 1) + { + if (i % N_BITS == 0) + for (int j = 0; j < n_args; j++) + S.at(*(it + j) + i / N_BITS).resize_regs( + min(N_BITS, size - i)); + + T y; + y.get_regs().push_back(S.at(base + i / N_BITS).get_reg(i % N_BITS)); + for (int j = 0; j < n_args; j++) + { + T x, tmp; + x.get_regs().push_back( + S.at(*(it + n_args + 1 + j) + i / 
N_BITS).get_reg( + i % N_BITS)); + tmp.and_(1, x, y, false); + S.at(*(it + j) + i / N_BITS).get_reg(i % N_BITS) = tmp.get_reg(0); + } + } + it += 2 * n_args + 1; + } +} + template void Processor::input(const vector& args) { diff --git a/GC/Program.h b/GC/Program.h index 8280c3f7..5d4b1643 100644 --- a/GC/Program.h +++ b/GC/Program.h @@ -40,6 +40,8 @@ class Program Program(); + size_t size() const { return p.size(); } + // Read in a program void parse_file(const string& filename); void parse(const string& programe); diff --git a/GC/Secret.h b/GC/Secret.h index c4b6e8eb..9fee3f2f 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -98,6 +98,9 @@ public: static void ands(Processor& processor, const vector& args) { T::ands(processor, args); } template + static void andrsvec(Processor& processor, const vector& args) + { T::andrsvec(processor, args); } + template static void xors(Processor& processor, const vector& args) { T::xors(processor, args); } template diff --git a/GC/Semi.cpp b/GC/Semi.cpp new file mode 100644 index 00000000..e00fed69 --- /dev/null +++ b/GC/Semi.cpp @@ -0,0 +1,36 @@ +/* + * Semi.cpp + * + */ + +#include "Semi.h" +#include "SemiPrep.h" + +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/Replicated.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/Beaver.hpp" + +namespace GC +{ + +void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, + bool repeat) +{ + if (repeat and OnlineOptions::singleton.live_prep) + { + this->triples.push_back({{}}); + auto& triple = this->triples.back(); + triple = dynamic_cast(prep)->get_mixed_triple(n); + for (int i = 0; i < 2; i++) + triple[1 + i] = triple[1 + i].mask(n); + triple[0] = triple[0].extend_bit().mask(n); + shares.push_back(y - triple[0]); + shares.push_back(x - triple[1]); + lengths.push_back(n); + } + else + prepare_mul(x, y, n); +} + +} /* namespace GC */ diff --git a/GC/Semi.h b/GC/Semi.h new file mode 100644 index 00000000..92f9139a --- /dev/null +++ b/GC/Semi.h @@ 
-0,0 +1,31 @@ +/* + * Semi.h + * + */ + +#ifndef GC_SEMI_H_ +#define GC_SEMI_H_ + +#include "Protocols/Beaver.h" +#include "SemiSecret.h" + +namespace GC +{ + +class Semi : public Beaver +{ + typedef Beaver super; + +public: + Semi(Player& P) : + super(P) + { + } + + void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n, + bool repeat); +}; + +} /* namespace GC */ + +#endif /* GC_SEMI_H_ */ diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 9eed3b31..3adc385d 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -4,6 +4,7 @@ */ #include "SemiPrep.h" +#include "Semi.h" #include "ThreadMaster.h" #include "OT/NPartyTripleGenerator.h" #include "OT/BitDiagonal.h" @@ -21,7 +22,7 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) : { } -void SemiPrep::set_protocol(Beaver& protocol) +void SemiPrep::set_protocol(SemiSecret::Protocol& protocol) { if (triple_generator) { @@ -53,6 +54,9 @@ SemiPrep::~SemiPrep() { if (triple_generator) delete triple_generator; + this->print_left("mixed triples", mixed_triples.size(), + SemiSecret::type_string(), + this->usage.files.at(DATA_GF2N).at(DATA_MIXED)); } void SemiPrep::buffer_bits() @@ -64,4 +68,25 @@ void SemiPrep::buffer_bits() } } +array SemiPrep::get_mixed_triple(int n) +{ + assert(n < 128); + + if (mixed_triples.empty()) + { + assert(this->triple_generator); + this->triple_generator->generateMixedTriples(); + for (auto& x : this->triple_generator->mixedTriples) + { + this->mixed_triples.push_back({{x[0], x[1], x[2]}}); + } + this->triple_generator->unlock(); + } + + this->count(DATA_MIXED); + auto res = mixed_triples.back(); + mixed_triples.pop_back(); + return res; +} + } /* namespace GC */ diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 737cfb98..ee4a7abe 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -25,11 +25,13 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer> mixed_triples; + public: SemiPrep(DataPositions& usage, bool = true); ~SemiPrep(); - void set_protocol(Beaver& protocol); + void 
set_protocol(SemiSecret::Protocol& protocol); void buffer_triples(); void buffer_bits(); @@ -37,6 +39,8 @@ public: void buffer_squares() { throw not_implemented(); } void buffer_inverses() { throw not_implemented(); } + array get_mixed_triple(int n); + void get(Dtype type, SemiSecret* data) { BufferPrep::get(type, data); diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index e95554bf..dc9e0a34 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -19,6 +19,7 @@ namespace GC class SemiPrep; class DealerPrep; +class Semi; template class SemiSecretBase : public V, public ShareSecret @@ -88,9 +89,13 @@ public: typedef MC MAC_Check; typedef SemiInput Input; typedef SemiPrep LivePrep; + typedef Semi Protocol; static MC* new_mc(typename SemiShare::mac_key_type); + static void andrsvec(Processor& processor, + const vector& args); + SemiSecret() { } diff --git a/GC/SemiSecret.hpp b/GC/SemiSecret.hpp index f6a4d398..b147cce3 100644 --- a/GC/SemiSecret.hpp +++ b/GC/SemiSecret.hpp @@ -8,6 +8,7 @@ #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/DealerMC.h" #include "SemiSecret.h" +#include "Semi.h" namespace GC { @@ -60,6 +61,60 @@ void SemiSecretBase::trans(Processor& processor, int n_outputs, } } +inline +void SemiSecret::andrsvec(Processor& processor, + const vector& args) +{ + int N_BITS = default_length; + auto protocol = ShareThread::s().protocol; + assert(protocol); + protocol->init_mul(); + auto it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + it += n_args; + int base = *it++; + assert(n_args <= N_BITS); + for (int i = 0; i < size; i += N_BITS) + { + square64 square; + for (int j = 0; j < n_args; j++) + square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get(); + int n_ops = min(N_BITS, size - i); + square.transpose(n_args, n_ops); + for (int j = 0; j < n_ops; j++) + { + long bit = processor.S.at(base + i / N_BITS).get_bit(j); + auto y_ext = SemiSecret(bit).extend_bit(); + 
protocol->prepare_mult(square.rows[j], y_ext, n_args, true); + } + } + it += n_args; + } + + protocol->exchange(); + + it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + for (int i = 0; i < size; i += N_BITS) + { + int n_ops = min(N_BITS, size - i); + square64 square; + for (int j = 0; j < n_ops; j++) + square.rows[j] = protocol->finalize_mul(n_args).get(); + square.transpose(n_ops, n_args); + for (int j = 0; j < n_args; j++) + processor.S.at(*(it + j) + i / N_BITS) = square.rows[j]; + } + it += 2 * n_args + 1; + } +} + template void SemiSecretBase::load_clear(int n, const Integer& x) { diff --git a/GC/ShareParty.h b/GC/ShareParty.h index 389efa33..ceda2f01 100644 --- a/GC/ShareParty.h +++ b/GC/ShareParty.h @@ -6,8 +6,6 @@ #ifndef GC_SHAREPARTY_H_ #define GC_SHAREPARTY_H_ -#include "Protocols/ReplicatedMC.h" -#include "Protocols/MaliciousRepMC.h" #include "ShareSecret.h" #include "Processor.h" #include "Program.h" diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index 28c28710..57beaec0 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -16,14 +16,12 @@ #include "Protocols/fake-stuff.h" #include "ShareThread.hpp" -#include "RepPrep.hpp" #include "ThreadMaster.hpp" #include "Thread.hpp" #include "ShareSecret.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/ReplicatedPrep.hpp" -#include "Protocols/MaliciousRepMC.hpp" #include "Protocols/fake-stuff.hpp" namespace GC diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index fb254486..d8c0c18c 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -63,6 +63,7 @@ public: static void ands(Processor& processor, const vector& args) { and_(processor, args, false); } static void and_(Processor& processor, const vector& args, bool repeat); + static void andrsvec(Processor& processor, const vector& args); static void xors(Processor& processor, const vector& args); static void inputb(Processor& processor, const vector& args) { inputb(processor, processor, 
args); } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 12568ef8..db57e3dd 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -8,16 +8,12 @@ #include "ShareSecret.h" -#include "MaliciousRepSecret.h" -#include "Protocols/MaliciousRepMC.h" #include "ShareThread.h" #include "Thread.h" #include "square64.h" #include "Protocols/Share.h" -#include "Protocols/ReplicatedMC.hpp" -#include "Protocols/Beaver.hpp" #include "ShareParty.h" #include "ShareThread.hpp" #include "Thread.hpp" @@ -288,6 +284,12 @@ void ShareSecret::and_( ShareThread::s().and_(processor, args, repeat); } +template +void ShareSecret::andrsvec(Processor& processor, const vector& args) +{ + ShareThread::s().andrsvec(processor, args); +} + template void ShareSecret::xors(Processor& processor, const vector& args) { diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 9c5f4ddb..70aae69b 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -7,11 +7,7 @@ #define GC_SHARETHREAD_H_ #include "Thread.h" -#include "MaliciousRepSecret.h" -#include "RepPrep.h" -#include "SemiHonestRepPrep.h" #include "Processor/Data_Files.h" -#include "Protocols/ReplicatedInput.h" #include @@ -45,6 +41,7 @@ public: void check(); void and_(Processor& processor, const vector& args, bool repeat); + void andrsvec(Processor& processor, const vector& args); void xors(Processor& processor, const vector& args); }; diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 27eefda0..b0eea1b0 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -107,7 +107,7 @@ void ShareThread::and_(Processor& processor, else processor.S[right + j].mask(y_ext, n); processor.S[left + j].mask(x_ext, n); - protocol->prepare_mul(x_ext, y_ext, n); + protocol->prepare_mult(x_ext, y_ext, n, repeat); } } @@ -127,6 +127,53 @@ void ShareThread::and_(Processor& processor, } } +template +void ShareThread::andrsvec(Processor& processor, const vector& args) +{ + int N_BITS = T::default_length; + auto& protocol = this->protocol; + 
assert(protocol); + protocol->init_mul(); + auto it = args.begin(); + T x_ext, y_ext; + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + it += n_args; + int base = *it++; + assert(n_args <= N_BITS); + for (int i = 0; i < size; i += N_BITS) + { + int n_ops = min(N_BITS, size - i); + for (int j = 0; j < n_args; j++) + { + processor.S.at(*(it + j) + i / N_BITS).mask(x_ext, n_ops); + processor.S.at(base + i / N_BITS).mask(y_ext, n_ops); + protocol->prepare_mul(x_ext, y_ext, n_ops); + } + } + it += n_args; + } + + protocol->exchange(); + + it = args.begin(); + while (it < args.end()) + { + int n_args = (*it++ - 3) / 2; + int size = *it++; + for (int i = 0; i < size; i += N_BITS) + { + int n_ops = min(N_BITS, size - i); + for (int j = 0; j < n_args; j++) + protocol->finalize_mul(n_ops).mask( + processor.S.at(*(it + j) + i / N_BITS), n_ops); + } + it += 2 * n_args + 1; + } +} + template void ShareThread::xors(Processor& processor, const vector& args) { diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index a754b2e7..03eea781 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -68,6 +68,7 @@ void ThreadMaster::run() P = new PlainPlayer(N, "main"); machine.load_schedule(progname); + machine.reset(machine.progs[0], memory); for (int i = 0; i < machine.nthreads; i++) threads.push_back(new_thread(i)); diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp index e136ec44..d288d826 100644 --- a/GC/TinierSharePrep.hpp +++ b/GC/TinierSharePrep.hpp @@ -8,7 +8,7 @@ #include "TinierSharePrep.h" -#include "PersonalPrep.h" +#include "PersonalPrep.hpp" namespace GC { diff --git a/GC/TinyMC.h b/GC/TinyMC.h index c94677ff..8ef5e10f 100644 --- a/GC/TinyMC.h +++ b/GC/TinyMC.h @@ -46,7 +46,7 @@ public: sizes.reserve(n); } - void prepare_open(const T& secret) + void prepare_open(const T& secret, int = -1) { for (auto& part : secret.get_regs()) part_MC.prepare_open(part); diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 
897b3b48..d3efbb83 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -6,6 +6,8 @@ #include "TinierSharePrep.h" #include "Protocols/MascotPrep.hpp" +#include "Protocols/ShuffleSacrifice.hpp" +#include "Protocols/MalRepRingPrep.hpp" namespace GC { diff --git a/GC/instructions.h b/GC/instructions.h index 49443cc2..62a71603 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -45,6 +45,7 @@ X(NOTS, processor.nots(INST)) \ X(NOTCB, processor.notcb(INST)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ + X(ANDRSVEC, T::andrsvec(PROC, EXTRA)) \ X(ANDS, T::ands(PROC, EXTRA)) \ X(ANDM, T::andm(PROC, instruction)) \ X(ADDCB, C0 = PC1 + PC2) \ diff --git a/License.txt b/License.txt index ccaafe01..ab7ae3bb 100644 --- a/License.txt +++ b/License.txt @@ -1,19 +1,17 @@ -CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) -Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. -All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material. -Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -* Neither the name of CSIRO nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission of CSIRO. -EXCEPT AS EXPRESSLY STATED IN THIS AGREEMENT AND TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE IS PROVIDED "AS-IS". 
CSIRO MAKES NO REPRESENTATIONS, WARRANTIES OR CONDITIONS OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY REPRESENTATIONS, WARRANTIES OR CONDITIONS REGARDING THE CONTENTS OR ACCURACY OF THE SOFTWARE, OR OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, THE ABSENCE OF LATENT OR OTHER DEFECTS, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. -TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL CSIRO BE LIABLE ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, IN AN ACTION FOR BREACH OF CONTRACT, NEGLIGENCE OR OTHERWISE) FOR ANY CLAIM, LOSS, DAMAGES OR OTHER LIABILITY HOWSOEVER INCURRED. WITHOUT LIMITING THE SCOPE OF THE PREVIOUS SENTENCE THE EXCLUSION OF LIABILITY SHALL INCLUDE: LOSS OF PRODUCTION OR OPERATION TIME, LOSS, DAMAGE OR CORRUPTION OF DATA OR RECORDS; OR LOSS OF ANTICIPATED SAVINGS, OPPORTUNITY, REVENUE, PROFIT OR GOODWILL, OR OTHER ECONOMIC LOSS; OR ANY SPECIAL, INCIDENTAL, INDIRECT, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES, ARISING OUT OF OR IN CONNECTION WITH THIS AGREEMENT, ACCESS OF THE SOFTWARE OR ANY OTHER DEALINGS WITH THE SOFTWARE, EVEN IF CSIRO HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH CLAIM, LOSS, DAMAGES OR OTHER LIABILITY. -APPLICABLE LEGISLATION SUCH AS THE AUSTRALIAN CONSUMER LAW MAY APPLY REPRESENTATIONS, WARRANTIES, OR CONDITIONS, OR IMPOSES OBLIGATIONS OR LIABILITY ON CSIRO THAT CANNOT BE EXCLUDED, RESTRICTED OR MODIFIED TO THE FULL EXTENT SET OUT IN THE EXPRESS TERMS OF THIS CLAUSE ABOVE "CONSUMER GUARANTEES". 
TO THE EXTENT THAT SUCH CONSUMER GUARANTEES CONTINUE TO APPLY, THEN TO THE FULL EXTENT PERMITTED BY THE APPLICABLE LEGISLATION, THE LIABILITY OF CSIRO UNDER THE RELEVANT CONSUMER GUARANTEE IS LIMITED (WHERE PERMITTED AT CSIRO'S OPTION) TO ONE OF FOLLOWING REMEDIES OR SUBSTANTIALLY EQUIVALENT REMEDIES: -(a) THE REPLACEMENT OF THE SOFTWARE, THE SUPPLY OF EQUIVALENT SOFTWARE, OR SUPPLYING RELEVANT SERVICES AGAIN; -(b) THE REPAIR OF THE SOFTWARE; -(c) THE PAYMENT OF THE COST OF REPLACING THE SOFTWARE, OF ACQUIRING EQUIVALENT SOFTWARE, HAVING THE RELEVANT SERVICES SUPPLIED AGAIN, OR HAVING THE SOFTWARE REPAIRED. -IN THIS CLAUSE, CSIRO INCLUDES ANY THIRD PARTY AUTHOR OR OWNER OF ANY PART OF THE SOFTWARE OR MATERIAL DISTRIBUTED WITH IT. CSIRO MAY ENFORCE ANY RIGHTS ON BEHALF OF THE RELEVANT THIRD PARTY. -Third Party Components -The following third party components are distributed with the Software. You agree to comply with the licence terms for these components as part of accessing the Software. Other third party software may also be identified in separate files distributed with the Software. +The Software is copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. + +CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +The following third party components are distributed with the Software. ___________________________________________________________________ SPDZ-2 [https://github.com/bristolcrypto/SPDZ-2] Copyright (c) 2018, The University of Bristol diff --git a/Machines/MalRep.hpp b/Machines/MalRep.hpp index 020477fa..b68da12c 100644 --- a/Machines/MalRep.hpp +++ b/Machines/MalRep.hpp @@ -9,5 +9,7 @@ #include "Protocols/MalRepRingPrep.hpp" #include "Protocols/MaliciousRepPrep.hpp" #include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/Beaver.hpp" +#include "Rep.hpp" #endif /* MACHINES_MALREP_HPP_ */ diff --git a/Machines/Rep.hpp b/Machines/Rep.hpp index a480860f..d684909b 100644 --- a/Machines/Rep.hpp +++ b/Machines/Rep.hpp @@ -4,7 +4,8 @@ */ #include "Protocols/MalRepRingPrep.h" -#include "Protocols/ReplicatedPrep2k.h" +#include "Protocols/SemiRep3Prep.h" +#include "GC/SemiHonestRepPrep.h" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" @@ -12,6 +13,8 @@ #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" +#include "Protocols/ReplicatedMC.hpp" +#include "Protocols/Rep3Shuffler.hpp" #include "Math/Z2k.hpp" #include "GC/ShareSecret.hpp" #include 
"GC/RepPrep.hpp" diff --git a/Machines/dealer-ring-party.cpp b/Machines/dealer-ring-party.cpp index 890a24ab..e738cd86 100644 --- a/Machines/dealer-ring-party.cpp +++ b/Machines/dealer-ring-party.cpp @@ -15,8 +15,14 @@ #include "Protocols/DealerMC.hpp" #include "Protocols/DealerMatrixPrep.hpp" #include "Protocols/Beaver.hpp" -#include "Semi.hpp" +#include "Protocols/SemiInput.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/SemiMC.hpp" #include "GC/DealerPrep.h" +#include "GC/SemiPrep.h" +#include "GC/SemiSecret.hpp" int main(int argc, const char** argv) { diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 5999050c..46911695 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -17,6 +17,7 @@ #include "Protocols/ReplicatedPrep.hpp" #include "Protocols/FakeShare.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/MAC_Check_Base.hpp" int main(int argc, const char** argv) { diff --git a/Machines/malicious-rep-bin-party.cpp b/Machines/malicious-rep-bin-party.cpp index 2ae79671..d2747a0e 100644 --- a/Machines/malicious-rep-bin-party.cpp +++ b/Machines/malicious-rep-bin-party.cpp @@ -7,12 +7,14 @@ #include "GC/ShareParty.hpp" #include "GC/ShareSecret.hpp" #include "GC/MaliciousRepSecret.h" +#include "GC/RepPrep.h" #include "GC/Machine.hpp" #include "GC/Processor.hpp" #include "GC/Program.hpp" #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/RepPrep.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MaliciousRepMC.hpp" diff --git a/Machines/mascot-offline.cpp b/Machines/mascot-offline.cpp index 975ae030..e24735b7 100644 --- a/Machines/mascot-offline.cpp +++ b/Machines/mascot-offline.cpp @@ -9,6 +9,7 @@ #include "Math/gfp.hpp" #include "Processor/FieldMachine.hpp" #include "Processor/OfflineMachine.hpp" +#include "Protocols/MascotPrep.hpp" int main(int argc, const char** argv) { diff --git 
a/Machines/no-party.cpp b/Machines/no-party.cpp index ce542de1..ceb35b08 100644 --- a/Machines/no-party.cpp +++ b/Machines/no-party.cpp @@ -9,6 +9,8 @@ #include "Processor/Machine.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/MalRepRingPrep.hpp" +#include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MAC_Check_Base.hpp" #include "Math/gfp.hpp" #include "Math/Z2k.hpp" diff --git a/Machines/ps-rep-bin-party.cpp b/Machines/ps-rep-bin-party.cpp index 98ffb298..4ab36139 100644 --- a/Machines/ps-rep-bin-party.cpp +++ b/Machines/ps-rep-bin-party.cpp @@ -5,8 +5,11 @@ #include "GC/PostSacriBin.h" #include "GC/PostSacriSecret.h" +#include "GC/RepPrep.h" #include "GC/ShareParty.hpp" +#include "GC/RepPrep.hpp" +#include "Protocols/MaliciousRepMC.hpp" int main(int argc, const char** argv) { diff --git a/Machines/real-bmr-party.cpp b/Machines/real-bmr-party.cpp index 42000ddf..8f329971 100644 --- a/Machines/real-bmr-party.cpp +++ b/Machines/real-bmr-party.cpp @@ -7,6 +7,7 @@ #include "BMR/RealProgramParty.hpp" #include "Machines/SPDZ.hpp" +#include "Protocols/MascotPrep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/replicated-bin-party.cpp b/Machines/replicated-bin-party.cpp index 763b1918..153d830e 100644 --- a/Machines/replicated-bin-party.cpp +++ b/Machines/replicated-bin-party.cpp @@ -4,6 +4,7 @@ */ #include "GC/ShareParty.h" +#include "GC/SemiHonestRepPrep.h" #include "GC/ShareParty.hpp" #include "GC/ShareSecret.hpp" @@ -12,6 +13,7 @@ #include "GC/Program.hpp" #include "GC/Thread.hpp" #include "GC/ThreadMaster.hpp" +#include "GC/RepPrep.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MaliciousRepMC.hpp" diff --git a/Machines/replicated-ring-party.cpp b/Machines/replicated-ring-party.cpp index 2b3646fe..a295eafe 100644 --- a/Machines/replicated-ring-party.cpp +++ b/Machines/replicated-ring-party.cpp @@ -4,7 +4,6 @@ */ #include "Protocols/Rep3Share2k.h" -#include "Protocols/ReplicatedPrep2k.h" #include 
"Processor/RingOptions.h" #include "Math/Integer.h" #include "Machines/RepRing.hpp" diff --git a/Machines/sy-rep-field-party.cpp b/Machines/sy-rep-field-party.cpp index 1da85676..a457e3b0 100644 --- a/Machines/sy-rep-field-party.cpp +++ b/Machines/sy-rep-field-party.cpp @@ -13,10 +13,10 @@ #include "Math/gf2n.h" #include "Tools/ezOptionParser.h" #include "GC/MaliciousCcdSecret.h" +#include "GC/SemiHonestRepPrep.h" #include "Processor/FieldMachine.hpp" #include "Protocols/Replicated.hpp" -#include "Protocols/MaliciousRepMC.hpp" #include "Protocols/Share.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/SpdzWise.hpp" @@ -30,6 +30,7 @@ #include "GC/RepPrep.hpp" #include "GC/ThreadMaster.hpp" #include "Math/gfp.hpp" +#include "MalRep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/sy-rep-ring-party.cpp b/Machines/sy-rep-ring-party.cpp index 728466f7..45faca6f 100644 --- a/Machines/sy-rep-ring-party.cpp +++ b/Machines/sy-rep-ring-party.cpp @@ -11,10 +11,10 @@ #include "Protocols/MalRepRingPrep.h" #include "Processor/RingOptions.h" #include "GC/MaliciousCcdSecret.h" +#include "GC/SemiHonestRepPrep.h" #include "Processor/RingMachine.hpp" #include "Protocols/Replicated.hpp" -#include "Protocols/MaliciousRepMC.hpp" #include "Protocols/Share.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/SpdzWise.hpp" @@ -32,6 +32,7 @@ #include "GC/ShareSecret.hpp" #include "GC/RepPrep.hpp" #include "GC/ThreadMaster.hpp" +#include "MalRep.hpp" int main(int argc, const char** argv) { diff --git a/Machines/sy-shamir-party.cpp b/Machines/sy-shamir-party.cpp index b009abb3..d251e7cd 100644 --- a/Machines/sy-shamir-party.cpp +++ b/Machines/sy-shamir-party.cpp @@ -12,6 +12,7 @@ #include "Math/gf2n.h" #include "GC/CcdSecret.h" #include "GC/MaliciousCcdSecret.h" +#include "GC/SemiHonestRepPrep.h" #include "Protocols/Share.hpp" #include "Protocols/SpdzWise.hpp" diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp index 35aae3aa..1ea00ffe 
100644 --- a/Machines/tinier-party.cpp +++ b/Machines/tinier-party.cpp @@ -25,6 +25,7 @@ #include "Protocols/MAC_Check_Base.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/MascotPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" int main(int argc, const char** argv) { diff --git a/Makefile b/Makefile index 12fda5bd..467e6d8f 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR) -GC_SEMI = GC/SemiPrep.o GC/square64.o +GC_SEMI = GC/SemiPrep.o GC/square64.o GC/Semi.o OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) $(LIBSIMPLEOT) OT_EXE = ot.x ot-offline.x @@ -40,6 +40,17 @@ LIBSIMPLEOT_ASM = deps/SimpleOT/libsimpleot.a LIBSIMPLEOT += $(LIBSIMPLEOT_ASM) endif +STATIC_OTE = local/lib/liblibOTe.a +SHARED_OTE = local/lib/liblibOTe.so + +ifeq ($(USE_KOS), 0) +ifeq ($(USE_SHARED_OTE), 1) +OT += $(SHARED_OTE) local/lib/libcryptoTools.so +else +OT += $(STATIC_OTE) local/lib/libcryptoTools.a +endif +endif + # used for dependency generation OBJS = $(BMR) $(FHEOBJS) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp Utils/*.cpp)) DEPS := $(wildcard */*.d */*/*.d) @@ -106,6 +117,7 @@ endif tldr: libote $(MAKE) mascot-party.x + mkdir Player-Data 2> /dev/null; true ifeq ($(ARM), 1) Tools/intrinsics.h: deps/simde/simde @@ -130,8 +142,8 @@ $(SHAREDLIB): $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o $(FHEOFFLINE): $(FHEOBJS) $(SHAREDLIB) $(CXX) $(CFLAGS) -shared -o $@ $^ $(LDLIBS) -static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) - $(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl +static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a local/lib/liblibOTe.a + $(CXX) -o $@ $(CFLAGS) 
$^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT) $(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl @@ -201,13 +213,13 @@ replicated-field-party.x: GC/square64.o brain-party.x: GC/square64.o malicious-rep-bin-party.x: GC/square64.o ps-rep-bin-party.x: GC/PostSacriBin.o -semi-bin-party.x: $(OT) GC/SemiPrep.o GC/square64.o +semi-bin-party.x: $(OT) $(GC_SEMI) tiny-party.x: $(OT) tinier-party.x: $(OT) spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) -semi-party.x: $(OT) GC/SemiPrep.o GC/square64.o -semi2k-party.x: $(OT) GC/SemiPrep.o GC/square64.o +semi-party.x: $(OT) $(GC_SEMI) +semi2k-party.x: $(OT) $(GC_SEMI) hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) @@ -232,15 +244,15 @@ malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o rep4-ring-party.x: GC/Rep4Secret.o no-party.x: Protocols/ShareInterface.o -semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o +semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) $(GC_SEMI) mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) emulate.x: GC/FakeSecret.o -semi-bmr-party.x: GC/SemiPrep.o $(OT) +semi-bmr-party.x: $(GC_SEMI) $(OT) real-bmr-party.x: $(OT) paper-example.x: $(VM) $(OT) $(FHEOFFLINE) -binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o -mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o +binary-example.x: $(VM) $(OT) GC/PostSacriBin.o $(GC_SEMI) GC/AtlasSecret.o +mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o $(GC_SEMI) 
GC/AtlasSecret.o Machines/Tinier.o l2h-example.x: $(VM) $(OT) Machines/Tinier.o he-example.x: $(FHEOFFLINE) mascot-offline.x: $(VM) $(TINIER) @@ -272,14 +284,15 @@ OT/BaseOT.o: deps/SimplestOT_C/ref10/Makefile deps/SimplestOT_C/ref10/Makefile: git submodule update --init deps/SimplestOT_C || git clone https://github.com/mkskeller/SimplestOT_C deps/SimplestOT_C - cd deps/SimplestOT_C/ref10; cmake . + cd deps/SimplestOT_C/ref10; PATH=$(CURDIR)/local/bin:$(PATH) cmake . .PHONY: Programs/Circuits Programs/Circuits: git submodule update --init Programs/Circuits -.PHONY: mpir-setup mpir-global mpir -mpir-setup: +.PHONY: mpir-setup mpir-global +mpir-setup: deps/mpir/Makefile +deps/mpir/Makefile: git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir cd deps/mpir; \ autoreconf -i; \ @@ -292,35 +305,45 @@ mpir-global: mpir-setup $(MAKE) -C deps/mpir sudo $(MAKE) -C deps/mpir install -mpir: mpir-setup +mpir: local/lib/libmpirxx.so +local/lib/libmpirxx.so: deps/mpir/Makefile cd deps/mpir; \ ./configure --enable-cxx --prefix=$(CURDIR)/local $(MAKE) -C deps/mpir install - -echo MY_CFLAGS += -I./local/include >> CONFIG.mine - -echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine deps/libOTe/libOTe: - git submodule update --init --recursive deps/libOTe - -echo MY_CFLAGS += -I./local/include >> CONFIG.mine - -echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine - + git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe boost: deps/libOTe/libOTe cd deps/libOTe; \ python3 build.py --setup --boost --install=$(CURDIR)/local OTE_OPTS = -DENABLE_SOFTSPOKEN_OT=ON -DCMAKE_CXX_COMPILER=$(CXX) -DCMAKE_INSTALL_LIBDIR=lib +ifeq ($(USE_SHARED_OTE), 1) +OTE = $(SHARED_OTE) +else +OTE = $(STATIC_OTE) +endif + +libote: + rm $(STATIC_OTE) $(SHARED_OTE)* 2>/dev/null; true + $(MAKE) $(OTE) + 
+local/lib/libcryptoTools.a: $(STATIC_OTE) +local/lib/libcryptoTools.so: $(SHARED_OTE) +OT/OTExtensionWithMatrix.o: $(OTE) + ifeq ($(ARM), 1) -libote: deps/libOTe/libOTe +local/lib/liblibOTe.a: deps/libOTe/libOTe cd deps/libOTe; \ PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 -DENABLE_AVX=OFF -DENABLE_SSE=OFF $(OTE_OPTS) else -libote: deps/libOTe/libOTe +local/lib/liblibOTe.a: deps/libOTe/libOTe cd deps/libOTe; \ PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 $(OTE_OPTS) endif -libote-shared: deps/libOTe/libOTe +$(SHARED_OTE): deps/libOTe/libOTe cd deps/libOTe; \ python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=1 $(OTE_OPTS) diff --git a/Math/BitVec.h b/Math/BitVec.h index f0d60a1b..f4b0a1e2 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -69,6 +69,8 @@ public: { if (n == -1) pack(os); + else if (n == 1) + os.store_int<1>(this->a & 1); else os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); } @@ -77,6 +79,8 @@ public: { if (n == -1) unpack(os); + else if (n == 1) + this->a = os.get_int<1>(); else this->a = os.get_int(DIV_CEIL(n, 8)); } diff --git a/Math/Square.hpp b/Math/Square.hpp index 98b646ee..7ca997eb 100644 --- a/Math/Square.hpp +++ b/Math/Square.hpp @@ -4,6 +4,7 @@ */ #include "Math/Square.h" +#include "Math/BitVec.h" template Square& Square::sub(const Square& other) @@ -40,6 +41,16 @@ void Square::bit_sub(const BitVector& bits, int start) } } +template<> +inline +void Square::bit_sub(const BitVector& bits, int start) +{ + for (int i = 0; i < BitVec::length(); i++) + { + rows[i] -= bits.get_portion(start + i); + } +} + template void Square::conditional_add(BitVector& conditions, Square& other, int offset) diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 60a132d6..3d3ecc20 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -20,10 +20,10 @@ using namespace std; #ifndef MAX_MOD_SZ - #if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 10 + 
#if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 11 #define MAX_MOD_SZ GFP_MOD_SZ #else - #define MAX_MOD_SZ 10 + #define MAX_MOD_SZ 11 #endif #endif diff --git a/Math/field_types.h b/Math/field_types.h index 9f54d3af..052cc40a 100644 --- a/Math/field_types.h +++ b/Math/field_types.h @@ -16,7 +16,8 @@ enum Dtype DATA_BIT, DATA_INVERSE, DATA_DABIT, - N_DTYPE + DATA_MIXED, + N_DTYPE, }; #endif /* MATH_FIELD_TYPES_H_ */ diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index b1c5642b..49bc8528 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -70,20 +70,6 @@ inline void mpn_add_fixed_n<2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb ); } -template <> -inline void mpn_add_fixed_n<3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) -{ - memcpy(res, y, 3 * sizeof(mp_limb_t)); - __asm__ ( - "add %3, %0 \n" - "adc %4, %1 \n" - "adc %5, %2 \n" - : "+&r"(res[0]), "+&r"(res[1]), "+r"(res[2]) - : "rm"(x[0]), "rm"(x[1]), "rm"(x[2]) - : "cc" - ); -} - template <> inline void mpn_add_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y) { diff --git a/Networking/data.h b/Networking/data.h index 6d7fb728..e2bda042 100644 --- a/Networking/data.h +++ b/Networking/data.h @@ -26,7 +26,7 @@ inline void short_memcpy(void* out, void* in, size_t n_bytes) X(1) X(2) X(3) X(4) X(5) X(6) X(7) X(8) #undef X default: - throw invalid_length("length outside range"); + throw invalid_length("length outside range: " + to_string(n_bytes)); } } diff --git a/OT/BaseOT.h b/OT/BaseOT.h index 3069bd89..4faf9283 100644 --- a/OT/BaseOT.h +++ b/OT/BaseOT.h @@ -68,7 +68,7 @@ public: void set_receiver_inputs(const BitVector& new_inputs) { if ((int)new_inputs.size() != nOT) - throw invalid_length(); + throw invalid_length("BaseOT"); receiver_inputs = new_inputs; } diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 445ebc3c..a797b979 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -127,6 +127,9 @@ public: vector< U, aligned_allocator > squares; + typename U::RowType& operator[](int i) + 
{ return squares[i / U::n_rows()].rows[i % U::n_rows()]; } + size_t vertical_size(); void resize_vertical(int length) diff --git a/OT/BitMatrix.hpp b/OT/BitMatrix.hpp index 00ede633..74a682a7 100644 --- a/OT/BitMatrix.hpp +++ b/OT/BitMatrix.hpp @@ -19,7 +19,7 @@ template bool Matrix::operator==(Matrix& other) { if (squares.size() != other.squares.size()) - throw invalid_length(); + throw invalid_length("Matrix"); for (size_t i = 0; i < squares.size(); i++) if (not(squares[i] == other.squares[i])) return false; @@ -109,7 +109,7 @@ template Slice& Slice::rsub(Slice& other) { if (bm.squares.size() < other.end) - throw invalid_length(); + throw invalid_length("rsub"); for (size_t i = other.start; i < other.end; i++) bm.squares[i].rsub(other.bm.squares[i]); return *this; diff --git a/OT/MamaRectangle.h b/OT/MamaRectangle.h index 98da4d5a..a17e3064 100644 --- a/OT/MamaRectangle.h +++ b/OT/MamaRectangle.h @@ -18,6 +18,8 @@ class MamaRectangle typename T::Square squares[N]; public: + typedef GC::NoValue RowType; + static int n_rows() { return T::Square::n_rows(); } static int n_columns() { return T::Square::n_columns(); } static int n_row_bytes() { return T::Square::n_row_bytes(); } diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index 3c58e690..b212a480 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -6,6 +6,7 @@ #include "Tools/random.h" #include "Tools/time-func.h" #include "Processor/InputTuple.h" +#include "Protocols/dabit.h" #include "OT/OTTripleSetup.h" #include "OT/MascotParams.h" @@ -98,7 +99,8 @@ public: vector> preampTriples; vector> plainTriples; - vector plainBits; + vector> plainBits; + vector> mixedTriples; typename T::MAC_Check* MC; @@ -114,6 +116,7 @@ public: void plainTripleRound(int k = 0); void generatePlainBits(); + void generateMixedTriples(); void run_multipliers(MultJob job); diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index f2b981c1..47df8f49 100644 --- 
a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -489,7 +489,8 @@ void OTTripleGenerator::generatePlainBits() machine.set_passive(); machine.output = false; - int n = multiple_minimum(nPreampTriplesPerLoop, T::open_type::size_in_bits()); + int n = multiple_minimum(100 * nPreampTriplesPerLoop, + T::open_type::size_in_bits()); valueBits.resize(1); valueBits[0].resize(n); @@ -500,16 +501,52 @@ void OTTripleGenerator::generatePlainBits() wait_for_multipliers(); plainBits.clear(); + typename T::open_type two = 2; + for (int j = 0; j < n; j++) { if (j % T::open_type::size_in_bits() < T::open_type::length()) { - plainBits.push_back(valueBits[0].get_bit(j)); - plainBits.back() += ot_multipliers[0]->c_output[j] * 2; + bool b = valueBits[0].get_bit(j); + plainBits.push_back({b, b}); + plainBits.back().first += ot_multipliers[0]->c_output[j] * two; } } } +template +void OTTripleGenerator::generateMixedTriples() +{ + assert(ot_multipliers.size() == 1); + + machine.set_passive(); + machine.output = false; + + int n = multiple_minimum(100 * nPreampTriplesPerLoop, + T::open_type::size_in_bits()); + + valueBits.resize(2); + valueBits[0].resize(n); + valueBits[0].randomize(share_prg); + valueBits[1].resize(n * T::open_type::N_BITS); + valueBits[1].randomize(share_prg); + + signal_multipliers(DATA_MIXED); + + wait_for_multipliers(); + mixedTriples.clear(); + + for (int j = 0; j < n; j++) + { + auto a = valueBits[0].get_bit(j); + auto b = valueBits[1].template get_portion(j); + auto c = a ? 
b : typename T::open_type(); + for (auto& x : ot_multipliers) + c += x->c_output[j]; + mixedTriples.push_back({{a, b, c}}); + } +} + template void OTTripleGenerator::plainTripleRound(int k) { diff --git a/OT/OTCorrelator.hpp b/OT/OTCorrelator.hpp index 00561d3c..d6c19761 100644 --- a/OT/OTCorrelator.hpp +++ b/OT/OTCorrelator.hpp @@ -188,7 +188,7 @@ template void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output, int start) { if (receiverOutputMatrix.squares.size() < nTriples + start) - throw invalid_length(); + throw invalid_length("reduce_squares"); output.resize(nTriples); for (unsigned int j = 0; j < nTriples; j++) { diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index 258e7430..409a4f99 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -9,7 +9,10 @@ #ifndef USE_KOS #include "Networking/PlayerCtSocket.h" -osuCrypto::IOService OTExtensionWithMatrix::ios; +#include +#include + +osuCrypto::IOService ot_extension_ios; #endif #include "OTCorrelator.hpp" @@ -112,7 +115,7 @@ void OTExtensionWithMatrix::extend(int nOTs_requested, const BitVector& newRecei resize(nOTs_requested); if (not channel) - channel = new osuCrypto::Channel(ios, new PlayerCtSocket(*player)); + channel = new osuCrypto::Channel(ot_extension_ios, new PlayerCtSocket(*player)); if (player->my_num()) { diff --git a/OT/OTExtensionWithMatrix.h b/OT/OTExtensionWithMatrix.h index e15ac953..e6eab6da 100644 --- a/OT/OTExtensionWithMatrix.h +++ b/OT/OTExtensionWithMatrix.h @@ -11,8 +11,9 @@ #include "Math/gf2n.h" #ifndef USE_KOS -#include -#include +namespace osuCrypto { +class Channel; +} #endif template @@ -57,7 +58,6 @@ class OTExtensionWithMatrix : public OTCorrelator int nsubloops; #ifndef USE_KOS - static osuCrypto::IOService ios; osuCrypto::Channel* channel; #endif diff --git a/OT/OTMultiplier.h b/OT/OTMultiplier.h index 21ec0622..0f86bc0c 100644 --- a/OT/OTMultiplier.h +++ b/OT/OTMultiplier.h @@ -59,6 +59,7 @@ protected: void 
multiplyForTriples(); virtual void multiplyForBits(); + virtual void multiplyForMixed(); virtual void multiplyForInputs(MultJob job) = 0; virtual void after_correlation() = 0; @@ -174,6 +175,7 @@ class SemiMultiplier : public OTMultiplier } void multiplyForBits(); + void multiplyForMixed(); void after_correlation(); diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index 24ad88a1..63f4dd08 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -128,6 +128,9 @@ void OTMultiplier::multiply() case DATA_TRIPLE: multiplyForTriples(); break; + case DATA_MIXED: + multiplyForMixed(); + break; default: throw not_implemented(); } @@ -188,6 +191,55 @@ void SemiMultiplier::multiplyForBits() this->outbox.push({}); } +template +void SemiMultiplier::multiplyForMixed() +{ + auto& rot_ext = this->rot_ext; + + typedef Square X; + OTCorrelator> otCorrelator( + this->generator.players[this->thread_num], BOTH, true); + + BitVector aBits = this->generator.valueBits[0]; + rot_ext.extend_correlated(aBits); + + auto& baseSenderOutputs = otCorrelator.matrices; + auto& baseReceiverOutput = otCorrelator.senderOutputMatrices[0]; + + rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput); + + if (this->generator.get_player().num_players() == 2) + { + c_output.clear(); + + for (unsigned j = 0; j < aBits.size(); j++) + { + this->generator.valueBits[1].set_portion(j, + BitVec(baseSenderOutputs[0][j] ^ baseSenderOutputs[1][j])); + c_output.push_back(baseReceiverOutput[j] ^ baseSenderOutputs[0][j]); + } + + this->outbox.push({}); + return; + } + + otCorrelator.setup_for_correlation(aBits, baseSenderOutputs, + baseReceiverOutput); + otCorrelator.correlate(0, otCorrelator.receiverOutputMatrix.squares.size(), + this->generator.valueBits[1], false, -1); + + c_output.clear(); + + for (unsigned j = 0; j < aBits.size(); j++) + { + c_output.push_back( + otCorrelator.receiverOutputMatrix[j] + ^ otCorrelator.senderOutputMatrices[0][j]); + } + + this->outbox.push({}); +} + 
template void OTMultiplier::multiplyForTriples() { @@ -592,3 +644,9 @@ void OTMultiplier::multiplyForBits() { throw runtime_error("bit generation not implemented in this case"); } + +template +void OTMultiplier::multiplyForMixed() +{ + throw runtime_error("mixed generation not implemented in this case"); +} diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 3a4a63b4..d49d08c7 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -67,6 +67,14 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) string threadname; for (int i=0; i> threadname; + size_t split = threadname.find(":"); + long expected = -1; + if (split != string::npos) + { + expected = atoi(threadname.substr(split + 1).c_str()); + threadname = threadname.substr(0, split); + } + string filename = "Programs/Bytecode/" + threadname + ".bc"; bc_filenames.push_back(filename); if (load_bytecode) @@ -74,8 +82,11 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) #ifdef DEBUG_FILES cerr << "Loading program " << i << " from " << filename << endl; #endif - load_program(threadname, filename); + long size = load_program(threadname, filename); + if (expected >= 0 and expected != size) + throw runtime_error("broken bytecode file"); } + } for (auto i : {1, 0, 0}) @@ -99,7 +110,8 @@ void BaseMachine::print_compiler() cerr << "Compiler: " << compiler << endl; } -void BaseMachine::load_program(const string& threadname, const string& filename) +size_t BaseMachine::load_program(const string& threadname, + const string& filename) { (void)threadname; (void)filename; diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 564affe0..6b5a029f 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -31,7 +31,8 @@ protected: string domain; string relevant_opts; - virtual void load_program(const string& threadname, const string& filename); + virtual size_t load_program(const string& threadname, + const string& 
filename); public: static thread_local int thread_num; diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 3d40e2ca..46c84903 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -7,8 +7,7 @@ #include "Protocols/dabit.h" #include "Math/Setup.h" #include "GC/BitPrepFiles.h" - -#include "Protocols/MascotPrep.hpp" +#include "Tools/benchmarking.h" template Preprocessing* Preprocessing::get_live_prep(SubProcessor* proc, @@ -44,6 +43,20 @@ Preprocessing* Preprocessing::get_new( BaseMachine::thread_num); } +template +T Preprocessing::get_random_from_inputs(int nplayers) +{ + T res; + for (int j = 0; j < nplayers; j++) + { + T tmp; + typename T::open_type _; + this->get_input_no_count(tmp, _, j); + res += tmp; + } + return res; +} + template Sub_Data_Files::Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num) : diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 1de58c99..011dcb58 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -84,6 +84,7 @@ enum SUBSI = 0x2A, SUBCFI = 0x2B, SUBSFI = 0x2C, + PREFIXSUMS = 0x2D, // Multiplication/division/other arithmetic MULC = 0x30, MULM = 0x31, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e3761f5f..da4dd01e 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -130,6 +130,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case DABIT: case SHUFFLE: case ACCEPTCLIENTCONNECTION: + case PREFIXSUMS: get_ints(r, s, 2); break; // instructions with 1 register operand @@ -458,6 +459,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case STMSDCI: case XORS: case ANDRS: + case ANDRSVEC: case ANDS: case INPUTB: case INPUTBVEC: @@ -646,6 +648,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const int offset = 0; int size_offset = 0; int size = this->size; + int n_prefix = 0; // special treatment for instructions writing to different types switch 
(opcode) @@ -731,25 +734,17 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const offset = 1; size_offset = -1; break; + case ANDRSVEC: + n_prefix = 2; + break; case INPUTB: skip = 4; offset = 3; size_offset = -2; break; case INPUTBVEC: - { - int res = 0; - auto it = start.begin(); - while (it < start.end()) - { - int n = *it - 3; - it += 3; - assert(it + n <= start.end()); - for (int i = 0; i < n; i++) - res = max(res, *it++); - } - return res + 1; - } + n_prefix = 3; + break; case ANDM: case NOTS: case NOTCB: @@ -795,6 +790,22 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const break; } + if (n_prefix > 0) + { + int res = 0; + auto it = start.begin(); + while (it < start.end()) + { + int n = *it - n_prefix; + int size = DIV_CEIL(*(it + 1), 64); + it += n_prefix; + assert(it + n <= start.end()); + for (int i = 0; i < n; i++) + res = max(res, *it++ + size); + } + return res; + } + if (skip > 0) { unsigned m = 0; @@ -1323,8 +1334,13 @@ void Program::execute(Processor& Proc) const (void) start; #ifdef COUNT_INSTRUCTIONS +#ifdef TIME_INSTRUCTIONS + RunningTimer timer; + int PC = Proc.PC; +#else Proc.stats[p[Proc.PC].get_opcode()]++; #endif +#endif #ifdef OUTPUT_INSTRUCTIONS cerr << instruction << endl; @@ -1352,6 +1368,10 @@ void Program::execute(Processor& Proc) const default: instruction.execute(Proc); } + +#if defined(COUNT_INSTRUCTIONS) and defined(TIME_INSTRUCTIONS) + Proc.stats[p[PC].get_opcode()] += timer.elapsed() * 1e9; +#endif } } diff --git a/Processor/Machine.h b/Processor/Machine.h index 8b3d018c..d3c1346b 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -44,7 +44,7 @@ class Machine : public BaseMachine Player* P; - void load_program(const string& threadname, const string& filename); + size_t load_program(const string& threadname, const string& filename); void prepare(const string& progname_str); diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 6fb9b5fc..4ff52608 100644 --- a/Processor/Machine.hpp +++ 
b/Processor/Machine.hpp @@ -199,7 +199,7 @@ Machine::~Machine() } template -void Machine::load_program(const string& threadname, +size_t Machine::load_program(const string& threadname, const string& filename) { progs.push_back(N.num_players()); @@ -208,6 +208,7 @@ void Machine::load_program(const string& threadname, M2.minimum_size(SGF2N, CGF2N, progs[i], threadname); Mp.minimum_size(SINT, CINT, progs[i], threadname); Mi.minimum_size(NONE, INT, progs[i], threadname); + return progs.back().size(); } template diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 038f28d2..0c845ccf 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -126,7 +126,8 @@ void thread_info::Sub_Main_Func() program = job.prognum; wait_timer.stop(); #ifdef DEBUG_THREADS - printf("\tRunning program %d\n",program); + printf("\tRunning program %d/job %d in thread %d\n", program, job.type, + num); #endif if (program==-1) @@ -208,6 +209,10 @@ void thread_info::Sub_Main_Func() *(vector>*) job.output, job.length, job.prognum, job.arg, Proc.Procp, job.begin, job.end, job.supply); +#ifdef DEBUG_THREADS + printf("\tSignalling I have finished with job %d in thread %d\n", + job.type, num); +#endif queues->finished(job); } else if (job.type == PERSONAL_TRIPLE_JOB) @@ -282,7 +287,8 @@ void thread_info::Sub_Main_Func() } #ifdef DEBUG_THREADS - printf("\tSignalling I have finished\n"); + printf("\tSignalling I have finished with program %d " + "in thread %d\n", program, num); #endif wait_timer.start(); queues->finished(job, P.total_comm()); diff --git a/Processor/Processor.h b/Processor/Processor.h index 3fedb3df..273b3a66 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -32,7 +32,7 @@ class SubProcessor DataPositions bit_usage; - SecureShuffle shuffler; + typename T::Protocol::Shuffler shuffler; void resize(size_t size) { C.resize(size); S.resize(size); } diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 7bd800c6..8236071e 
100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -651,8 +651,9 @@ void SubProcessor::conv2ds(const Instruction& instruction) template void SubProcessor::secure_shuffle(const Instruction& instruction) { - SecureShuffle(S, instruction.get_size(), instruction.get_n(), - instruction.get_r(0), instruction.get_r(1), *this); + typename T::Protocol::Shuffler(S, instruction.get_size(), + instruction.get_n(), instruction.get_r(0), instruction.get_r(1), + *this); } template diff --git a/Processor/Program.h b/Processor/Program.h index 8fb3df14..96a70e5e 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -36,6 +36,8 @@ class Program unknown_usage(false), writes_persistence(false) { compute_constants(); } + size_t size() const { return p.size(); } + // Read in a program void parse(string filename); void parse(istream& s); diff --git a/Processor/ThreadQueues.cpp b/Processor/ThreadQueues.cpp index de013499..ecca7bbe 100644 --- a/Processor/ThreadQueues.cpp +++ b/Processor/ThreadQueues.cpp @@ -19,6 +19,9 @@ int ThreadQueues::distribute(ThreadJob job, int n_items, int base, int ThreadQueues::find_available() { +#ifdef VERBOSE_QUEUES + cerr << available.size() << " threads in use" << endl; +#endif if (not available.empty()) return 0; for (size_t i = 1; i < size(); i++) @@ -32,7 +35,7 @@ int ThreadQueues::find_available() int ThreadQueues::get_n_per_thread(int n_items, int granularity) { - int n_per_thread = ceil(n_items / (available.size() + 1.0)) / granularity + int n_per_thread = int(ceil(n_items / (available.size() + 1.0)) / granularity) * granularity; return n_per_thread; } @@ -40,11 +43,23 @@ int ThreadQueues::get_n_per_thread(int n_items, int granularity) int ThreadQueues::distribute_no_setup(ThreadJob job, int n_items, int base, int granularity, const vector* supplies) { +#ifdef VERBOSE_QUEUES + cerr << "Distribute " << job.type << " among " << available.size() << endl; +#endif + int n_per_thread = get_n_per_thread(n_items, granularity); + + if 
(n_items and (n_per_thread == 0 or base + n_per_thread > n_items)) + { + available.clear(); + return base; + } + for (size_t i = 0; i < available.size(); i++) { if (base + (i + 1) * n_per_thread > size_t(n_items)) { + assert(i); available.resize(i); return base + i * n_per_thread; } @@ -59,7 +74,14 @@ int ThreadQueues::distribute_no_setup(ThreadJob job, int n_items, int base, void ThreadQueues::wrap_up(ThreadJob job) { +#ifdef VERBOSE_QUEUES + cerr << "Wrap up " << available.size() << " threads" << endl; +#endif for (int i : available) - assert(at(i)->result().output == job.output); + { + auto result = at(i)->result(); + assert(result.output == job.output); + assert(result.type == job.type); + } available.clear(); } diff --git a/Processor/instructions.h b/Processor/instructions.h index bf443b0f..f22fde8e 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -62,6 +62,9 @@ X(SUBCFI, auto dest = &Procp.get_C()[r[0]]; auto op1 = &Procp.get_C()[r[1]]; \ typename sint::clear op2 = int(n), \ *dest++ = op2 - *op1++) \ + X(PREFIXSUMS, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ + sint s, \ + s += *op1++; *dest++ = s) \ X(MULM, auto dest = &Procp.get_S()[r[0]]; auto op1 = &Procp.get_S()[r[1]]; \ auto op2 = &Procp.get_C()[r[2]], \ *dest++ = *op1++ * *op2++) \ @@ -380,6 +383,10 @@ X(PREP, throw not_implemented(),) \ X(GPREP, throw not_implemented(),) \ X(CISC, throw not_implemented(),) \ + X(SECSHUFFLE, throw not_implemented(),) \ + X(GENSECSHUFFLE, throw not_implemented(),) \ + X(APPLYSHUFFLE, throw not_implemented(),) \ + X(DELSHUFFLE, throw not_implemented(),) \ #define ALL_INSTRUCTIONS ARITHMETIC_INSTRUCTIONS REGINT_INSTRUCTIONS \ CLEAR_GF2N_INSTRUCTIONS REMAINING_INSTRUCTIONS diff --git a/Programs/Source/adult.mpc b/Programs/Source/adult.mpc new file mode 100644 index 00000000..373e332b --- /dev/null +++ b/Programs/Source/adult.mpc @@ -0,0 +1,54 @@ +m = 6 +n_train = 32561 +n_test = 16281 + +combo = 'combo' in program.args 
+binary = 'binary' in program.args +mixed = 'mixed' in program.args +nocap = 'nocap' in program.args + +try: + n_threads = int(program.args[2]) +except: + n_threads = None + +if combo: + n_train += n_test + +if binary: + m = 60 + attr_lengths = [1] * m +elif mixed or nocap: + cont = 6 if mixed else 3 + m = 60 + cont + attr_lengths = [0] * cont + [1] * 60 +else: + attr_lengths = None + +program.set_bit_length(32) +program.options_from_args() + +train = sint.Array(n_train), sint.Matrix(m, n_train) +test = sint.Array(n_test), sint.Matrix(m, n_test) + +for x in train + test: + x.input_from(0) + +import decision_tree, util + +#decision_tree.debug_layers = True +decision_tree.max_leaves = 3000 + +if 'nearest' in program.args: + sfix.round_nearest = True + +sfix.set_precision_from_args(program, True) + +trainer = decision_tree.TreeTrainer( + train[1], train[0], int(program.args[1]), attr_lengths=attr_lengths, + n_threads=n_threads) +trainer.debug_selection = 'debug_selection' in program.args +trainer.debug_gini = True +layers = trainer.train_with_testing(*test) + +#decision_tree.output_decision_tree(layers) diff --git a/Programs/Source/bench-dt.mpc b/Programs/Source/bench-dt.mpc new file mode 100644 index 00000000..4c8c64c9 --- /dev/null +++ b/Programs/Source/bench-dt.mpc @@ -0,0 +1,32 @@ +binary = 'binary' in program.args + +program.set_bit_length(32) + +n_train = int(program.args[1]) +m = int(program.args[2]) + +try: + n_levels = int(program.args[3]) +except: + n_levels = 1 + +try: + n_threads = int(program.args[4]) +except: + n_threads = None + +train = sint.Array(n_train), sint.Matrix(m, n_train) + +import decision_tree, util + +decision_tree.max_leaves = 2000 + +if 'nearest' in program.args: + sfix.round_nearest = True + +layers = decision_tree.TreeTrainer( + train[1], train[0], n_levels, binary=binary, n_threads=n_threads).train() + +#decision_tree.output_decision_tree(layers) + +#decision_tree.test_decision_tree('foo', layers, *train) diff --git 
a/Programs/Source/benchmark_secureNN.mpc b/Programs/Source/benchmark_secureNN.mpc index 7bba218b..9b2f9674 100644 --- a/Programs/Source/benchmark_secureNN.mpc +++ b/Programs/Source/benchmark_secureNN.mpc @@ -28,7 +28,12 @@ NetworkC = [ (500, 10, 'FC') ] -network = globals()['Network' + program.args[1]] +try: + network = globals()['Network' + program.args[1]] +except: + import sys + print('Usage: %s [A,B,C,D]' % ' '.join(sys.argv)) + sys.exit(1) # c5.9xlarge has 36 cores n_threads = 8 diff --git a/Programs/Source/gc_oram.mpc b/Programs/Source/gc_oram.mpc index fa7ac702..5ddcb5a7 100644 --- a/Programs/Source/gc_oram.mpc +++ b/Programs/Source/gc_oram.mpc @@ -7,9 +7,6 @@ from Compiler.GC.instructions import * bits.unit = 64 -program.to_merge = [ldmsdi, stmsdi, ldmsd, stmsd, stmsdci, xors, andrs] -program.stop_class = type(None) - from Compiler.circuit_oram import * from Compiler import circuit_oram diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index 37cd73d2..caca2214 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -10,6 +10,7 @@ import util program.options_from_args() sfix.set_precision_from_args(program, adapt_ring=True) +ml.use_mux = 'mux' in program.args MultiArray.disable_index_checks() if 'profile' in program.args: diff --git a/Programs/Source/spect.mpc b/Programs/Source/spect.mpc new file mode 100644 index 00000000..95fb7d12 --- /dev/null +++ b/Programs/Source/spect.mpc @@ -0,0 +1,49 @@ +m = 22 +n_train = 80 +n_test = 187 + +debug = 'debug' in program.args +combo = 'combo' in program.args + +if debug: + n_train = 7 + +if combo: + n_train += n_test + +Array.check_indices = False +MultiArray.disable_index_checks() + +train = sint.Array(n_train), sint.Matrix(m, n_train) +test = sint.Array(n_test), sint.Matrix(m, n_test) + +for x in train: + x.input_from(0) + +if not (debug or combo): + for x in test: + x.input_from(0) + +import decision_tree, util + +#decision_tree.debug = True + +if 
'nearest' in program.args: + sfix.round_nearest = True + +sfix.set_precision_from_args(program, True) + +try: + n_threads = int(program.args[3]) +except: + n_threads = None + +trainer = decision_tree.TreeTrainer( + train[1], train[0], int(program.args[1]), binary=int(program.args[2]), + n_threads=n_threads) + +if not (debug or combo): + layers = trainer.train_with_testing(*test) +else: + layers = trainer.train() + decision_tree.test_decision_tree('train', layers, *train) diff --git a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index cc6a5ea1..9792aa66 100644 --- a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -71,7 +71,7 @@ test(r * sbit(1) + sbit(1) * r, 0) test(sbits.get_type(64)(2**64 - 1).popcnt(), 64) a = [sbits.new(x, 2) for x in range(4)] -x, y = sbits.trans(a) +x, y, *z = sbits.trans(a) test(x, 0xa) test(y, 0xc) diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 9b695d0d..e24cad3a 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -27,6 +27,7 @@ protected: vector shares; vector opened; vector> triples; + vector lengths; typename vector::iterator it; typename vector>::iterator triple; Preprocessing* prep; diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index dc981487..8c89f420 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -37,6 +37,7 @@ void Beaver::init_mul() shares.clear(); opened.clear(); triples.clear(); + lengths.clear(); } template @@ -48,12 +49,19 @@ void Beaver::prepare_mul(const T& x, const T& y, int n) triple = prep->get_triple(n); shares.push_back(x - triple[0]); shares.push_back(y - triple[1]); + lengths.push_back(n); } template void Beaver::exchange() { - MC->POpen(opened, shares, P); + assert(shares.size() == 2 * lengths.size()); + MC->init_open(P, shares.size()); + for (size_t i = 0; i < shares.size(); i++) + MC->prepare_open(shares[i], lengths[i / 2]); + MC->exchange(P); + for (size_t i = 0; i < shares.size(); i++) + opened.push_back(MC->finalize_raw()); it = opened.begin(); triple 
= triples.begin(); } diff --git a/Protocols/DabitSacrifice.hpp b/Protocols/DabitSacrifice.hpp index 74d9f026..d6f485cc 100644 --- a/Protocols/DabitSacrifice.hpp +++ b/Protocols/DabitSacrifice.hpp @@ -109,7 +109,8 @@ void DabitSacrifice::sacrifice_and_check_bits(vector >& dabits, ThreadJob job(&products, &multiplicands); int start = queues->distribute(job, multiplicands.size()); protocol.multiply(products, multiplicands, start, multiplicands.size(), proc); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else protocol.multiply(products, multiplicands, 0, multiplicands.size(), proc); diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h index 4e668136..db1ed813 100644 --- a/Protocols/DealerMC.h +++ b/Protocols/DealerMC.h @@ -22,7 +22,7 @@ public: ~DealerMC(); void init_open(const Player& P, int n = 0); - void prepare_open(const T& secret); + void prepare_open(const T& secret, int n_bits = -1); void exchange(const Player& P); typename T::open_type finalize_raw(); array finalize_several(int n); diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp index 0f63b93d..08b4b458 100644 --- a/Protocols/DealerMC.hpp +++ b/Protocols/DealerMC.hpp @@ -46,10 +46,10 @@ void DealerMC::init_open(const Player& P, int n) } template -void DealerMC::prepare_open(const T& secret) +void DealerMC::prepare_open(const T& secret, int n_bits) { if (sub_player) - internal.prepare_open(secret); + internal.prepare_open(secret, n_bits); else { if (secret != T()) diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index cc010dd7..ea334257 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -7,6 +7,7 @@ #define PROTOCOLS_DEALERPREP_HPP_ #include "DealerPrep.h" +#include "GC/SemiSecret.h" template void DealerPrep::buffer_triples() diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 018ac338..c40308c5 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -13,6 +13,51 @@ #include +template +class FakeShuffle +{ 
+public: + FakeShuffle(SubProcessor&) + { + } + + FakeShuffle(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, SubProcessor&) + { + apply(a, n, unit_size, output_base, input_base, 0, 0); + } + + size_t generate(size_t) + { + return 0; + } + + void apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int, bool) + { + auto source = a.begin() + input_base; + auto dest = a.begin() + output_base; + for (size_t i = 0; i < n; i++) + // just copy + *dest++ = *source++; + + if (n > 1) + { + // swap first two to pass check + for (int i = 0; i < unit_size; i++) + swap(a[output_base + i], a[output_base + i + unit_size]); + } + } + + void del(size_t) + { + } + + void inverse_permutation(vector&, size_t, size_t, size_t) + { + } +}; + template class FakeProtocol : public ProtocolBase { @@ -31,6 +76,8 @@ class FakeProtocol : public ProtocolBase map ltz_stats; public: + typedef FakeShuffle Shuffler; + Player& P; FakeProtocol(Player& P) : diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 3446733e..062c2239 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -4,6 +4,7 @@ */ #include "HemiMatrixPrep.h" +#include "MAC_Check.h" #include "FHE/Diagonalizer.h" #include "Tools/Bundle.h" @@ -113,7 +114,8 @@ void HemiMatrixPrep::buffer_triples() job.begin = start; job.end = n_matrices; matrix_rand_mult(job); - queues.wrap_up(job); + if (start) + queues.wrap_up(job); } else { @@ -177,7 +179,8 @@ void HemiMatrixPrep::buffer_triples() #endif for (int i = start; i < n_inner; i++) products[i] = multiplicands.at(i) * multiplicands2.at(i); - queues.wrap_up(job); + if (start) + queues.wrap_up(job); #ifdef VERBOSE_HE fprintf(stderr, "adding at %f\n", timer.elapsed()); fflush(stderr); diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index b2b510aa..6db5bf43 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -30,6 +30,10 @@ class HemiPrep : public SemiHonestRingPrep map 
timers; + SemiPrep* two_party_prep; + + SemiPrep& get_two_party_prep(); + public: static void basic_setup(Player& P); static void teardown(); @@ -40,7 +44,7 @@ public: HemiPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), - SemiHonestRingPrep(proc, usage) + SemiHonestRingPrep(proc, usage), two_party_prep(0) { } ~HemiPrep(); @@ -48,6 +52,9 @@ public: vector*>& get_multipliers(); void buffer_triples(); + + void buffer_bits(); + void buffer_dabits(ThreadQueues* queues); }; #endif /* PROTOCOLS_HEMIPREP_H_ */ diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index ce55bce7..099466de 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -56,6 +56,13 @@ HemiPrep::~HemiPrep() { for (auto& x : multipliers) delete x; + + if (two_party_prep) + { + auto& usage = two_party_prep->usage; + delete two_party_prep; + delete &usage; + } } template @@ -110,4 +117,51 @@ void HemiPrep::buffer_triples() {{ a.element(i), b.element(i), c.element(i) }}); } +template +SemiPrep& HemiPrep::get_two_party_prep() +{ + assert(this->proc); + assert(this->proc->P.num_players() == 2); + + if (not two_party_prep) + { + two_party_prep = new SemiPrep(this->proc, + *new DataPositions(this->proc->P.num_players())); + two_party_prep->set_protocol(this->proc->protocol); + } + + return *two_party_prep; +} + +template +void HemiPrep::buffer_bits() +{ + assert(this->proc); + if (this->proc->P.num_players() == 2) + { + auto& prep = get_two_party_prep(); + prep.buffer_dabits(0); + for (auto& x : prep.dabits) + this->bits.push_back(x.first); + prep.dabits.clear(); + } + else + SemiHonestRingPrep::buffer_bits(); +} + +template +void HemiPrep::buffer_dabits(ThreadQueues* queues) +{ + assert(this->proc); + if (this->proc->P.num_players() == 2) + { + auto& prep = get_two_party_prep(); + prep.buffer_dabits(queues); + this->dabits = prep.dabits; + prep.dabits.clear(); + } + else + SemiHonestRingPrep::buffer_dabits(queues); +} + 
#endif diff --git a/Protocols/HighGearKeyGen.hpp b/Protocols/HighGearKeyGen.hpp index 49fa6702..4c405472 100644 --- a/Protocols/HighGearKeyGen.hpp +++ b/Protocols/HighGearKeyGen.hpp @@ -168,7 +168,7 @@ void HighGearKeyGen::run(PartSetup& setup, MachineBase& machine) timer.reset(); map timers; - SimpleEncCommit_ EC(P, setup.pk, setup.FieldD, timers, machine, 0, true); + SummingEncCommit EC(P, setup.pk, setup.FieldD, timers, machine, 0, true); Plaintext_ alpha(setup.FieldD); EC.next(alpha, setup.calpha); assert(alpha.is_diagonal()); diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index 4b9d0d05..5056c3a8 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -10,6 +10,7 @@ #include "Machines/SPDZ.hpp" #include "ShareVector.hpp" +#include "MascotPrep.hpp" template LowGearKeyGen::LowGearKeyGen(Player& P, PairwiseMachine& machine, diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index d0b062c4..311de4d9 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -92,7 +92,7 @@ class Tree_MAC_Check : public TreeSum, public MAC_Check_B virtual ~Tree_MAC_Check(); virtual void init_open(const Player& P, int n = 0); - virtual void prepare_open(const U& secret); + virtual void prepare_open(const U& secret, int = -1); virtual void exchange(const Player& P); virtual void AddToCheck(const U& share, const T& value, const Player& P); @@ -143,7 +143,7 @@ public: MAC_Check_Z2k(const T& ai, int opening_sum=10, int max_broadcast=10, int send_player=0); MAC_Check_Z2k(const T& ai, Names& Nms, int thread_num); - void prepare_open(const W& secret); + void prepare_open(const W& secret, int = -1); void prepare_open_no_mask(const W& secret); virtual void Check(const Player& P); @@ -184,7 +184,7 @@ public: ~Direct_MAC_Check(); void init_open(const Player& P, int n = 0); - void prepare_open(const T& secret); + void prepare_open(const T& secret, int = -1); void exchange(const Player& P); }; diff --git a/Protocols/MAC_Check.hpp 
b/Protocols/MAC_Check.hpp index 5798d9a4..fe6a0108 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -96,7 +96,7 @@ void Tree_MAC_Check::init_open(const Player&, int n) } template -void Tree_MAC_Check::prepare_open(const U& secret) +void Tree_MAC_Check::prepare_open(const U& secret, int) { this->values.push_back(secret.get_share()); macs.push_back(secret.get_mac()); @@ -242,7 +242,7 @@ MAC_Check_Z2k::MAC_Check_Z2k(const T& ai, Names& Nms, } template -void MAC_Check_Z2k::prepare_open(const W& secret) +void MAC_Check_Z2k::prepare_open(const W& secret, int) { prepare_open_no_mask(secret + (get_random_element() << W::clear::N_BITS)); } @@ -402,7 +402,7 @@ void Direct_MAC_Check::init_open(const Player& P, int n) } template -void Direct_MAC_Check::prepare_open(const T& secret) +void Direct_MAC_Check::prepare_open(const T& secret, int) { this->values.push_back(secret.get_share()); this->macs.push_back(secret.get_mac()); diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index b4f684bc..fed190ef 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -59,7 +59,7 @@ public: /// Initialize opening round virtual void init_open(const Player& P, int n = 0); /// Add value to be opened - virtual void prepare_open(const T& secret); + virtual void prepare_open(const T& secret, int n_bits = -1); /// Run opening protocol virtual void exchange(const Player& P) = 0; /// Get next opened value diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 47528e00..01096fa9 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -53,7 +53,7 @@ void MAC_Check_Base::init_open(const Player&, int n) } template -void MAC_Check_Base::prepare_open(const T& secret) +void MAC_Check_Base::prepare_open(const T& secret, int) { secrets.push_back(secret); } diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 6ce2e244..1be5bb9a 100644 --- a/Protocols/MalRepRingPrep.hpp +++ 
b/Protocols/MalRepRingPrep.hpp @@ -165,7 +165,8 @@ void TripleShuffleSacrifice::triple_sacrifice(vector>& triples, TripleSacrificeJob job(&triples, &check_triples); int start = queues->distribute(job, N); triple_sacrifice(triples, check_triples, P, MC, start, N); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else triple_sacrifice(triples, check_triples, P, MC, 0, N); diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index b7296748..ec26bc84 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_MALICIOUSREPPREP_HPP_ +#define PROTOCOLS_MALICIOUSREPPREP_HPP_ + #include "MaliciousRepPrep.h" #include "Tools/Subroutines.h" #include "Processor/OnlineOptions.h" @@ -232,3 +235,5 @@ void MaliciousRepPrep::buffer_inputs(int player) assert(proc); this->buffer_inputs_as_usual(player, proc); } + +#endif diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 1393bb46..f5c09941 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -108,18 +108,4 @@ void MascotInputPrep::buffer_inputs(int player) this->inputs[player].push_back(input); } -template -T Preprocessing::get_random_from_inputs(int nplayers) -{ - T res; - for (int j = 0; j < nplayers; j++) - { - T tmp; - typename T::open_type _; - this->get_input_no_count(tmp, _, j); - res += tmp; - } - return res; -} - #endif diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 7cbd483c..fccb9e0c 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -9,6 +9,7 @@ #include "Protocols/MaliciousRep3Share.h" #include "Protocols/MalRepRingShare.h" #include "Protocols/Rep3Share2k.h" +#include "GC/MaliciousRepSecret.h" template class MalRepRingPrepWithBits; template class PostSacrifice; diff --git a/Protocols/ProtocolSetup.h b/Protocols/ProtocolSetup.h index 2f417c42..63c47172 100644 --- a/Protocols/ProtocolSetup.h +++ 
b/Protocols/ProtocolSetup.h @@ -65,6 +65,14 @@ public: { return mac_key; } + + /** + * Set how much preprocessing is produced at once. + */ + static void set_batch_size(size_t batch_size) + { + OnlineOptions::singleton.batch_size = batch_size; + } }; /** diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index fb02d26f..cd321b26 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -9,11 +9,13 @@ #include "Math/FixedVec.h" #include "Math/Integer.h" #include "Protocols/Replicated.h" +#include "Protocols/Rep3Shuffler.h" #include "GC/ShareSecret.h" #include "ShareInterface.h" #include "Processor/Instruction.h" template class ReplicatedPrep; +template class SemiRep3Prep; template class ReplicatedRingPrep; template class ReplicatedPO; template class SpecificPrivateOutput; @@ -109,7 +111,8 @@ public: typedef ReplicatedInput Input; typedef ReplicatedPO PO; typedef SpecificPrivateOutput PrivateOutput; - typedef ReplicatedPrep LivePrep; + typedef typename conditional, SemiRep3Prep>::type LivePrep; typedef ReplicatedRingPrep TriplePrep; typedef Rep3Share Honest; diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index e52d160b..0fc2e50e 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -11,7 +11,7 @@ #include "Math/Z2k.h" #include "GC/square64.h" -template class ReplicatedPrep2k; +template class SemiRep3Prep; template class Rep3Share2 : public Rep3Share> @@ -26,7 +26,7 @@ public: typedef ReplicatedInput Input; typedef ReplicatedPO PO; typedef SpecificPrivateOutput PrivateOutput; - typedef ReplicatedPrep2k LivePrep; + typedef SemiRep3Prep LivePrep; typedef Rep3Share2 Honest; typedef SignedZ2 clear; diff --git a/Protocols/Rep3Shuffler.h b/Protocols/Rep3Shuffler.h new file mode 100644 index 00000000..ec80a48e --- /dev/null +++ b/Protocols/Rep3Shuffler.h @@ -0,0 +1,33 @@ +/* + * Rep3Shuffler.h + * + */ + +#ifndef PROTOCOLS_REP3SHUFFLER_H_ +#define PROTOCOLS_REP3SHUFFLER_H_ + +template +class Rep3Shuffler +{ + SubProcessor& 
proc; + + vector, 2>> shuffles; + +public: + Rep3Shuffler(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, SubProcessor& proc); + + Rep3Shuffler(SubProcessor& proc); + + int generate(int n_shuffle); + + void apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse); + + void inverse_permutation(vector& stack, size_t n, size_t output_base, + size_t input_base); + + void del(int handle); +}; + +#endif /* PROTOCOLS_REP3SHUFFLER_H_ */ diff --git a/Protocols/Rep3Shuffler.hpp b/Protocols/Rep3Shuffler.hpp new file mode 100644 index 00000000..a2edfb76 --- /dev/null +++ b/Protocols/Rep3Shuffler.hpp @@ -0,0 +1,131 @@ +/* + * Rep3Shuffler.cpp + * + */ + +#ifndef PROTOCOLS_REP3SHUFFLER_HPP_ +#define PROTOCOLS_REP3SHUFFLER_HPP_ + +#include "Rep3Shuffler.h" + +template +Rep3Shuffler::Rep3Shuffler(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc) : + proc(proc) +{ + apply(a, n, unit_size, output_base, input_base, generate(n / unit_size), + false); + shuffles.pop_back(); +} + +template +Rep3Shuffler::Rep3Shuffler(SubProcessor& proc) : + proc(proc) +{ +} + +template +int Rep3Shuffler::generate(int n_shuffle) +{ + shuffles.push_back({}); + auto& shuffle = shuffles.back(); + for (int i = 0; i < 2; i++) + { + auto& perm = shuffle[i]; + for (int j = 0; j < n_shuffle; j++) + perm.push_back(j); + for (int j = 0; j < n_shuffle; j++) + { + int k = proc.protocol.shared_prngs[i].get_uint(n_shuffle - j); + swap(perm[j], perm[k + j]); /* Fisher-Yates: slot j receives a uniformly chosen element of the untouched suffix [j, n_shuffle); swapping perm[k] instead left j = 0 a no-op and biased the permutation */ + } + } + return shuffles.size() - 1; +} + +template +void Rep3Shuffler::apply(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, int handle, bool reverse) +{ + assert(proc.P.num_players() == 3); + assert(not T::malicious); + assert(not T::dishonest_majority); + assert(n % unit_size == 0); + + auto& shuffle = shuffles.at(handle); + vector to_shuffle; + for (size_t i = 0; i < n; i++) +
to_shuffle.push_back(a[input_base + i]); + + typename T::Input input(proc); + + vector to_share(n); + + for (int ii = 0; ii < 3; ii++) + { + int i; + if (reverse) + i = 2 - ii; + else + i = ii; + + if (proc.P.get_player(i) == 0) + { + for (size_t j = 0; j < n / unit_size; j++) + for (int k = 0; k < unit_size; k++) + if (reverse) + to_share.at(j * unit_size + k) = to_shuffle.at( + shuffle[0].at(j) * unit_size + k).sum(); + else + to_share.at(shuffle[0].at(j) * unit_size + k) = + to_shuffle.at(j * unit_size + k).sum(); + } + else if (proc.P.get_player(i) == 1) + { + for (size_t j = 0; j < n / unit_size; j++) + for (int k = 0; k < unit_size; k++) + if (reverse) + to_share[j * unit_size + k] = to_shuffle[shuffle[1][j] + * unit_size + k][0]; + else + to_share[shuffle[1][j] * unit_size + k] = to_shuffle[j + * unit_size + k][0]; + } + + input.reset_all(proc.P); + + if (proc.P.get_player(i) < 2) + for (auto& x : to_share) + input.add_mine(x); + + for (int k = 0; k < 2; k++) + input.add_other((-i + 3 + k) % 3); + + input.exchange(); + to_shuffle.clear(); + + for (size_t j = 0; j < n; j++) + { + T x = input.finalize((-i + 3) % 3) + input.finalize((-i + 4) % 3); + to_shuffle.push_back(x); + } + } + + for (size_t i = 0; i < n; i++) + a[output_base + i] = to_shuffle[i]; +} + +template +void Rep3Shuffler::del(int handle) +{ + for (int i = 0; i < 2; i++) + shuffles.at(handle)[i].clear(); +} + +template +void Rep3Shuffler::inverse_permutation(vector&, size_t, size_t, size_t) +{ + throw runtime_error("inverse permutation not implemented"); +} + +#endif diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 48b01440..05f132ce 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -20,6 +20,8 @@ template class SubProcessor; template class ReplicatedMC; template class ReplicatedInput; template class Preprocessing; +template class SecureShuffle; +template class Rep3Shuffler; class Instruction; /** @@ -59,6 +61,8 @@ protected: public: typedef T share_type; + 
typedef SecureShuffle Shuffler; + int counter; ProtocolBase(); @@ -81,6 +85,7 @@ public: virtual void init_mul() = 0; /// Schedule multiplication of operand pair virtual void prepare_mul(const T& x, const T& y, int n = -1) = 0; + virtual void prepare_mult(const T& x, const T& y, int n, bool repeat); /// Run multiplication protocol virtual void exchange() = 0; /// Get next multiplication result @@ -143,6 +148,8 @@ class Replicated : public ReplicatedBase, public ProtocolBase public: static const bool uses_triples = false; + typedef Rep3Shuffler Shuffler; + Replicated(Player& P); Replicated(const ReplicatedBase& other); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index f398da7f..494a7e0d 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -41,7 +41,7 @@ inline ReplicatedBase::ReplicatedBase(Player& P) : P(P) { assert(P.num_players() == 3); if (not P.is_encrypted()) - insecure("unencrypted communication"); + insecure("unencrypted communication", false); shared_prngs[0].ReSeed(); octetStream os; @@ -121,6 +121,13 @@ T ProtocolBase::mul(const T& x, const T& y) return finalize_mul(); } +template +void ProtocolBase::prepare_mult(const T& x, const T& y, int n, + bool) +{ + prepare_mul(x, y, n); +} + template void ProtocolBase::finalize_mult(T& res, int n) { diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 9e1498df..26451f17 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -17,19 +17,17 @@ template class PrepLessInput : public InputBase { protected: - vector shares; - size_t i_share; + PointerVector shares; public: PrepLessInput(SubProcessor* proc) : - InputBase(proc ? proc->Proc : 0), i_share(0) {} + InputBase(proc ? 
proc->Proc : 0) {} virtual ~PrepLessInput() {} virtual void reset(int player) = 0; virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; virtual void add_other(int player, int n_bits = - 1) = 0; - virtual void send_mine() = 0; virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 1cfac4a1..ffc34d6f 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -19,7 +19,6 @@ void ReplicatedInput::reset(int player) if (player == P.my_num()) { this->shares.clear(); - this->i_share = 0; os.resize(2); for (auto& o : os) o.reset_write_head(); @@ -89,7 +88,7 @@ inline void ReplicatedInput::finalize_other(int player, T& target, template T PrepLessInput::finalize_mine() { - return this->shares[this->i_share++]; + return this->shares.next(); } #endif diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 8a30749c..e73d9cc2 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -38,6 +38,8 @@ class BufferPrep : public Preprocessing template void buffer_inverses(false_type) { throw runtime_error("no inverses"); } + virtual bool bits_from_dabits() { return false; } + protected: vector> triples; vector> squares; diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index a172b05b..867b844d 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -7,6 +7,7 @@ #define PROTOCOlS_REPLICATEDPREP_HPP_ #include "ReplicatedPrep.h" +#include "SemiRep3Prep.h" #include "DabitSacrifice.h" #include "Spdz2kPrep.h" @@ -64,17 +65,24 @@ BufferPrep::~BufferPrep() * T::default_length); size_t used_bits = my_usage.at(DATA_BIT); - if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) - // add dabits with computation modulo power of two but without MAC - used_bits += my_usage.at(DATA_DABIT); + size_t used_dabits = my_usage.at(DATA_DABIT); 
+ if (bits_from_dabits()) + { + if (not T::clear::invertible and field_type == DATA_INT and not T::has_mac) + // add dabits with computation modulo power of two but without MAC + used_bits += my_usage.at(DATA_DABIT); + } + else + used_dabits += used_bits; + this->print_left("bits", bits.size(), type_string, used_bits); + this->print_left("dabits", dabits.size(), type_string, used_dabits); #define X(KIND, TYPE) \ this->print_left(#KIND, KIND.size(), type_string, \ this->usage.files.at(T::clear::field_type()).at(TYPE)); X(squares, DATA_SQUARE) X(inverses, DATA_INVERSE) - X(dabits, DATA_DABIT) #undef X for (auto& x : this->edabits) @@ -549,7 +557,8 @@ void MaliciousRingPrep::buffer_personal_edabits(int n_bits, vector& wholes int start = queues->distribute(job, buffer_size, 0, BT::default_length); this->template buffer_personal_edabits_without_check<0>(n_bits, sums, bits, proc, input_player, start, buffer_size); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else this->template buffer_personal_edabits_without_check<0>(n_bits, sums, @@ -651,12 +660,78 @@ void RingPrep::buffer_dabits_without_check(vector>& dabits, int start = queues->distribute(job, buffer_size, old_size); this->buffer_dabits_without_check(dabits, start, dabits.size()); - queues->wrap_up(job); + if (start > old_size) + queues->wrap_up(job); } else buffer_dabits_without_check(dabits, old_size, dabits.size()); } +template +void SemiRep3Prep::buffer_dabits(ThreadQueues*) +{ + assert(this->protocol); + assert(this->proc); + + typedef typename T::bit_type BT; + int n_blocks = DIV_CEIL(this->buffer_size, BT::default_length); + int n_bits = n_blocks * BT::default_length; + + vector b(n_blocks); + + vector> a(n_bits); + Player& P = this->proc->P; + + for (int i = 0; i < 2; i++) + { + for (auto& x : b) + x[i].randomize(this->protocol->shared_prngs[i]); + + int j = P.get_offset(i); + + for (int k = 0; k < n_bits; k++) + a[k][j][i] = b[k / BT::default_length][i].get_bit( + k % BT::default_length); + 
} + + // the first multiplication + vector first(n_bits), second(n_bits); + typename T::Input input(P); + + if (P.my_num() == 0) + { + for (auto& x : a) + input.add_mine(x[0][0] * x[1][1]); + } + else + input.add_other(0); + + input.exchange(); + + for (int k = 0; k < n_bits; k++) + first[k] = a[k][0] + a[k][1] - 2 * input.finalize(0); + + input.reset_all(P); + + if (P.my_num() != 0) + { + for (int k = 0; k < n_bits; k++) + input.add_mine(first[k].local_mul(a[k][2])); + } + + input.add_other(1); + input.add_other(2); + input.exchange(); + + for (int k = 0; k < n_bits; k++) + { + second[k] = first[k] + a[k][2] + - 2 * (input.finalize(1) + input.finalize(2)); + this->dabits.push_back({second[k], + b[k / BT::default_length].get_bit(k % BT::default_length)}); + } +} + template void RingPrep::buffer_dabits_without_check(vector>& dabits, size_t begin, size_t end) @@ -718,7 +793,8 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, ThreadJob job(n_bits, &sums, &bits); int start = queues->distribute(job, rounded, 0, dl); buffer_edabits_without_check<0>(n_bits, sums, bits, start, rounded); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else buffer_edabits_without_check<0>(n_bits, sums, bits, 0, rounded); @@ -844,7 +920,8 @@ void RingPrep::sanitize(vector>& edabits, int n_bits, SanitizeJob job(&edabits, n_bits, player); int start = queues->distribute(job, edabits.size()); sanitize<0>(edabits, n_bits, player, start, edabits.size()); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else sanitize<0>(edabits, n_bits, player, 0, edabits.size()); @@ -1027,6 +1104,7 @@ void BufferPrep::get_dabit_no_count(T& a, typename T::bit_type& b) InScope in_scope(this->do_count, false); ThreadQueues* queues = 0; buffer_dabits(queues); + assert(not dabits.empty()); } a = dabits.back().first; b = dabits.back().second; @@ -1085,7 +1163,7 @@ template void BufferPrep::buffer_edabits_with_queues(bool strict, int n_bits) { ThreadQueues* queues = 0; - 
if (BaseMachine::thread_num == 0) + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) queues = &BaseMachine::s().queues; buffer_edabits(strict, n_bits, queues); } diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index d8c3d8e6..752798b2 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -21,7 +21,7 @@ SecureShuffle::SecureShuffle(SubProcessor& proc) : template SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, size_t output_base, size_t input_base, SubProcessor& proc) : - proc(proc), unit_size(unit_size) + proc(proc), unit_size(unit_size), n_shuffle(0), exact(false) { pre(a, n, input_base); diff --git a/Protocols/Semi.h b/Protocols/Semi.h index 5f63a9d6..903aca6d 100644 --- a/Protocols/Semi.h +++ b/Protocols/Semi.h @@ -59,7 +59,20 @@ public: for (auto& info : infos) { if (not info.big_gap()) - throw runtime_error("bit length too large"); + { + if (not T::clear::invertible) + { + int min_size = 64 * DIV_CEIL( + info.k + OnlineOptions::singleton.trunc_error, 64); + throw runtime_error( + "Bit length too large for trunc_pr. 
" + "Disable it or increase the ring size " + "during compilation using '-R " + + to_string(min_size) + "'."); + } + else + throw runtime_error("bit length too large"); + } if (this->P.my_num()) for (int i = 0; i < size; i++) proc.get_S_ref(info.dest_base + i) = -open_type( diff --git a/Protocols/SemiInput.h b/Protocols/SemiInput.h index c40d0c17..d4c864f0 100644 --- a/Protocols/SemiInput.h +++ b/Protocols/SemiInput.h @@ -6,20 +6,28 @@ #ifndef PROTOCOLS_SEMIINPUT_H_ #define PROTOCOLS_SEMIINPUT_H_ -#include "ShamirInput.h" +#include "ReplicatedInput.h" template class SemiMC; +template +class PairwiseKeyInput : public PrepLessInput +{ +protected: + vector send_prngs; + vector recv_prngs; + +public: + PairwiseKeyInput(SubProcessor* proc, PlayerBase& P); +}; + /** * Additive secret sharing input protocol */ template -class SemiInput : public InputBase +class SemiInput : public PairwiseKeyInput { - vector send_prngs; - vector recv_prngs; PlayerBase& P; - vector> shares; public: SemiInput(SubProcessor& proc, SemiMC&) : diff --git a/Protocols/SemiInput.hpp b/Protocols/SemiInput.hpp index f0fefe13..7ab4a855 100644 --- a/Protocols/SemiInput.hpp +++ b/Protocols/SemiInput.hpp @@ -12,9 +12,15 @@ template SemiInput::SemiInput(SubProcessor* proc, PlayerBase& P) : - InputBase(proc), P(P) + PairwiseKeyInput(proc, P), P(P) +{ + this->reset_all(P); +} + +template +PairwiseKeyInput::PairwiseKeyInput(SubProcessor* proc, PlayerBase& P) : + PrepLessInput(proc) { - shares.resize(P.num_players()); vector to_send(P.num_players()), to_receive; for (int i = 0; i < P.num_players(); i++) { @@ -26,13 +32,13 @@ SemiInput::SemiInput(SubProcessor* proc, PlayerBase& P) : for (int i = 0; i < P.num_players(); i++) if (i != P.my_num()) recv_prngs[i].SetSeed(to_receive[i].consume(SEED_SIZE)); - this->reset_all(P); } template void SemiInput::reset(int player) { - shares[player].clear(); + if (player == P.my_num()) + this->shares.clear(); } template @@ -43,9 +49,9 @@ void SemiInput::add_mine(const 
typename T::clear& input, int) for (int i = 0; i < P.num_players(); i++) { if (i != P.my_num()) - sum += send_prngs[i].template get(); + sum += this->send_prngs[i].template get(); } - shares[P.my_num()].push_back(input - sum); + this->shares.push_back(input - sum); } template @@ -62,13 +68,13 @@ template void SemiInput::finalize_other(int player, T& target, octetStream&, int) { - target = recv_prngs[player].template get(); + target = this->recv_prngs[player].template get(); } template T SemiInput::finalize_mine() { - return shares[P.my_num()].next(); + return this->shares.next(); } #endif diff --git a/Protocols/SemiMC.h b/Protocols/SemiMC.h index fe4d9db6..27fd3b71 100644 --- a/Protocols/SemiMC.h +++ b/Protocols/SemiMC.h @@ -15,13 +15,17 @@ template class SemiMC : public TreeSum, public MAC_Check_Base { +protected: + vector lengths; + public: // emulate MAC_Check SemiMC(const typename T::mac_key_type& _ = {}, int __ = 0, int ___ = 0) { (void)_; (void)__; (void)___; } virtual ~SemiMC() {} - virtual void prepare_open(const T& secret); + virtual void init_open(const Player& P, int n = 0); + virtual void prepare_open(const T& secret, int n_bits = -1); virtual void exchange(const Player& P); void Check(const Player& P) { (void)P; } diff --git a/Protocols/SemiMC.hpp b/Protocols/SemiMC.hpp index b5487857..75aa0c6e 100644 --- a/Protocols/SemiMC.hpp +++ b/Protocols/SemiMC.hpp @@ -11,9 +11,18 @@ #include "MAC_Check.hpp" template -void SemiMC::prepare_open(const T& secret) +void SemiMC::init_open(const Player& P, int n) +{ + MAC_Check_Base::init_open(P, n); + lengths.clear(); + lengths.reserve(n); +} + +template +void SemiMC::prepare_open(const T& secret, int n_bits) { this->values.push_back(secret); + lengths.push_back(n_bits); } template @@ -28,6 +37,8 @@ void DirectSemiMC::POpen_(vector& values, { this->values.clear(); this->values.reserve(S.size()); + this->lengths.clear(); + this->lengths.reserve(S.size()); for (auto& secret : S) this->prepare_open(secret); 
this->exchange_(P); @@ -39,10 +50,20 @@ void DirectSemiMC::exchange_(const PlayerBase& P) { Bundle oss(P); oss.mine.reserve(this->values.size()); - for (auto& x : this->values) - x.pack(oss.mine); + assert(this->values.size() == this->lengths.size()); + for (size_t i = 0; i < this->lengths.size(); i++) + this->values[i].pack(oss.mine, this->lengths[i]); P.unchecked_broadcast(oss); - direct_add_openings(this->values, P, oss); + size_t n = P.num_players(); + size_t me = P.my_num(); + for (size_t i = 0; i < this->lengths.size(); i++) + for (size_t j = 0; j < n; j++) + if (j != me) + { + T tmp; + tmp.unpack(oss[j], this->lengths[i]); + this->values[i] += tmp; + } } template diff --git a/Protocols/SemiPrep.h b/Protocols/SemiPrep.h index 3580a73b..9646e945 100644 --- a/Protocols/SemiPrep.h +++ b/Protocols/SemiPrep.h @@ -8,18 +8,26 @@ #include "MascotPrep.h" +template class HemiPrep; + /** * Semi-honest triple generation based on oblivious transfer */ template class SemiPrep : public virtual OTPrep, public virtual SemiHonestRingPrep { + friend class HemiPrep; + public: SemiPrep(SubProcessor* proc, DataPositions& usage); void buffer_triples(); - void buffer_bits(); + void buffer_dabits(ThreadQueues* queues); + + void get_one_no_count(Dtype dtype, T& a); + + bool bits_from_dabits(); }; #endif /* PROTOCOLS_SEMIPREP_H_ */ diff --git a/Protocols/SemiPrep.hpp b/Protocols/SemiPrep.hpp index bc61787d..f1ec6efd 100644 --- a/Protocols/SemiPrep.hpp +++ b/Protocols/SemiPrep.hpp @@ -6,6 +6,8 @@ #include "SemiPrep.h" #include "ReplicatedPrep.hpp" +#include "MascotPrep.hpp" +#include "OT/NPartyTripleGenerator.hpp" template SemiPrep::SemiPrep(SubProcessor* proc, DataPositions& usage) : @@ -31,16 +33,37 @@ void SemiPrep::buffer_triples() } template -void SemiPrep::buffer_bits() +bool SemiPrep::bits_from_dabits() { assert(this->proc); - if (this->proc->P.num_players() == 2 and not T::clear::characteristic_two) + return this->proc->P.num_players() == 2 and not T::clear::characteristic_two; 
+} + +template +void SemiPrep::buffer_dabits(ThreadQueues* queues) +{ + if (bits_from_dabits()) { assert(this->triple_generator); this->triple_generator->generatePlainBits(); for (auto& x : this->triple_generator->plainBits) - this->bits.push_back(x); + this->dabits.push_back({x.first, x.second}); } else - SemiHonestRingPrep::buffer_bits(); + SemiHonestRingPrep::buffer_dabits(queues); +} + +template +void SemiPrep::get_one_no_count(Dtype dtype, T& a) +{ + if (bits_from_dabits()) + { + if (dtype != DATA_BIT) + throw not_implemented(); + + typename T::bit_type b; + this->get_dabit_no_count(a, b); + } + else + SemiHonestRingPrep::get_one_no_count(dtype, a); } diff --git a/Protocols/SemiPrep2k.h b/Protocols/SemiPrep2k.h index 50311c59..49ccca47 100644 --- a/Protocols/SemiPrep2k.h +++ b/Protocols/SemiPrep2k.h @@ -49,6 +49,12 @@ public: void get_dabit_no_count(T& a, typename T::bit_type& b) { + if (this->bits_from_dabits()) + { + SemiPrep::get_dabit_no_count(a, b); + return; + } + this->get_one_no_count(DATA_BIT, a); b = a & 1; } diff --git a/Protocols/ReplicatedPrep2k.h b/Protocols/SemiRep3Prep.h similarity index 51% rename from Protocols/ReplicatedPrep2k.h rename to Protocols/SemiRep3Prep.h index da35865e..5d68f03e 100644 --- a/Protocols/ReplicatedPrep2k.h +++ b/Protocols/SemiRep3Prep.h @@ -3,8 +3,8 @@ * */ -#ifndef PROTOCOLS_REPLICATEDPREP2K_H_ -#define PROTOCOLS_REPLICATEDPREP2K_H_ +#ifndef PROTOCOLS_SEMIREP3PREP_H_ +#define PROTOCOLS_SEMIREP3PREP_H_ #include "ReplicatedPrep.h" @@ -12,11 +12,13 @@ * Preprocessing for three-party replicated secret sharing modulo a power of two */ template -class ReplicatedPrep2k : public virtual SemiHonestRingPrep, +class SemiRep3Prep : public virtual SemiHonestRingPrep, public virtual ReplicatedRingPrep { + void buffer_dabits(ThreadQueues*); + public: - ReplicatedPrep2k(SubProcessor* proc, DataPositions& usage) : + SemiRep3Prep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), 
SemiHonestRingPrep(proc, usage), ReplicatedRingPrep(proc, usage) @@ -25,11 +27,14 @@ public: void buffer_bits() { this->buffer_bits_without_check(); } - void get_dabit_no_count(T& a, typename T::bit_type& b) + void get_one_no_count(Dtype dtype, T& a) { - this->get_one_no_count(DATA_BIT, a); - b = a & 1; + if (dtype != DATA_BIT) + throw not_implemented(); + + typename T::bit_type b; + this->get_dabit_no_count(a, b); } }; -#endif /* PROTOCOLS_REPLICATEDPREP2K_H_ */ +#endif /* PROTOCOLS_SEMIREP3PREP_H_ */ diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 402173e9..db056ae4 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -49,7 +49,8 @@ public: Player& P; static U get_rec_factor(int i, int n); - static U get_rec_factor(int i, int n_total, int start, int threshold); + static U get_rec_factor(int i, int n_total, int start, int threshold, + int target = -1); Shamir(Player& P, int threshold = 0); ~Shamir(); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 8bfdf70e..89fa6853 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -20,14 +20,24 @@ typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n) template typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n_total, - int start, int n_points) + int start, int n_points, int target) { U res = 1; for (int j = 0; j < n_points; j++) { - int other = positive_modulo(start + j, n_total); + int other; + if (n_total > 0) + other = positive_modulo(start + j, n_total); + else + other = start + j; if (i != other) - res *= U(other + 1) / (U(other + 1) - U(i + 1)); + { + res *= (U(other + 1) - U(target + 1)) / (U(other + 1) - U(i + 1)); +#ifdef DEBUG_SHAMIR + cout << "res=" << res << " other+1=" << (other + 1) << " target=" + << target << " i+1=" << (i + 1) << endl; +#endif + } } return res; } @@ -43,6 +53,7 @@ Shamir::Shamir(Player& P, int t) : else threshold = ShamirMachine::s().threshold; n_mul_players = 2 * threshold + 1; + resharing = new ShamirInput(0, P); } 
template @@ -69,11 +80,6 @@ int Shamir::get_n_relevant_players() template void Shamir::reset() { - if (resharing == 0) - { - resharing = new ShamirInput(0, P); - } - for (int i = 0; i < P.num_players(); i++) resharing->reset(i); diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 91e09309..eaa72f2d 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -8,7 +8,7 @@ #include "Processor/Input.h" #include "Shamir.h" -#include "ReplicatedInput.h" +#include "SemiInput.h" #include "Machines/ShamirMachine.h" /** @@ -16,7 +16,7 @@ * to every other player */ template -class IndividualInput : public PrepLessInput +class IndividualInput : public PairwiseKeyInput { protected: Player& P; @@ -25,7 +25,7 @@ protected: public: IndividualInput(SubProcessor* proc, Player& P) : - PrepLessInput(proc), P(P), senders(P.num_players()) + PairwiseKeyInput(proc, P), P(P), senders(P.num_players()) { this->reset_all(P); } @@ -53,14 +53,14 @@ class ShamirInput : public IndividualInput { friend class Shamir; - vector> vandermonde; - - SeededPRNG secure_prng; + vector> reconstruction; vector randomness; int threshold; + void init(); + public: static vector> get_vandermonde(size_t t, size_t n); @@ -79,6 +79,7 @@ public: else threshold = ShamirMachine::s().threshold; + init(); } ShamirInput(ShamirMC&, Preprocessing&, Player& P) : @@ -87,6 +88,7 @@ public: } void add_mine(const typename T::open_type& input, int n_bits = -1); + void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); }; #endif /* PROTOCOLS_SHAMIRINPUT_H_ */ diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index 6d9992ad..41c88012 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -10,6 +10,7 @@ #include "Machines/ShamirMachine.h" #include "Protocols/ReplicatedInput.hpp" +#include "Protocols/SemiInput.hpp" template void IndividualInput::reset(int player) @@ -17,7 +18,6 @@ void IndividualInput::reset(int player) if (player == P.my_num()) { 
this->shares.clear(); - this->i_share = 0; os.reset(P); } @@ -45,6 +45,20 @@ vector> ShamirInput::get_vandermonde( return vandermonde; } +template +void ShamirInput::init() +{ + reconstruction.resize(this->P.num_players() - threshold); + for (size_t i = 0; i < reconstruction.size(); i++) + { + auto& x = reconstruction[i]; + for (int j = 0; j <= threshold; j++) + x.push_back( + Shamir::get_rec_factor(j - 1, 0, -1, threshold + 1, + i + threshold)); + } +} + template void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) { @@ -53,18 +67,20 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) int n = P.num_players(); int t = threshold; - if (vandermonde.empty()) - vandermonde = get_vandermonde(t, n); - randomness.resize(t); - for (auto& x : randomness) - x.randomize(secure_prng); - - for (int i = 0; i < n; i++) + for (int i = 0; i < t; i++) { - typename T::open_type x = input; + randomness[i].randomize(this->send_prngs[i]); + if (i == P.my_num()) + this->shares.push_back(randomness[i]); + } + + for (int i = threshold; i < n; i++) + { + typename T::open_type x = input + * reconstruction.at(i - threshold).at(0); for (int j = 0; j < t; j++) - x += randomness[j] * vandermonde[i][j]; + x += randomness[j] * reconstruction.at(i - threshold).at(j + 1); if (i == P.my_num()) this->shares.push_back(x); else @@ -74,6 +90,16 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) this->senders[P.my_num()] = true; } +template +void ShamirInput::finalize_other(int player, T& target, + octetStream& o, int n_bits) +{ + if (this->P.my_num() < threshold) + target.randomize(this->recv_prngs.at(player)); + else + IndividualInput::finalize_other(player, target, o, n_bits); +} + template void IndividualInput::add_sender(int player) { diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index c6a88f0a..bd0cc317 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -67,7 +67,7 @@ public: void POpen_End(vector& 
values,const vector& S,const Player& P); virtual void init_open(const Player& P, int n = 0); - virtual void prepare_open(const T& secret); + virtual void prepare_open(const T& secret, int = -1); virtual void exchange(const Player& P); virtual typename T::open_type finalize_raw(); diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 7238aa5e..585a6896 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -72,7 +72,7 @@ void ShamirMC::prepare(const vector& S, const Player& P) } template -void ShamirMC::prepare_open(const T& share) +void ShamirMC::prepare_open(const T& share, int) { share.pack(os->mine); } diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index bf40cb28..318f050d 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -9,6 +9,7 @@ #include "Protocols/Shamir.h" #include "Protocols/ShamirInput.h" #include "Machines/ShamirMachine.h" +#include "GC/NoShare.h" #include "ShareInterface.h" template class ReplicatedPrep; diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index 4d03dd67..150cdb61 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -141,7 +141,8 @@ void DabitShuffleSacrifice::dabit_sacrifice(vector >& output, int start = queues->distribute(job, products.size()); protocol.multiply(products, multiplicands, start, products.size(), proc); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else protocol.multiply(products, multiplicands, 0, products.size(), proc); @@ -311,7 +312,8 @@ void EdabitShuffleSacrifice::edabit_sacrifice(vector >& output, &supplies); edabit_sacrifice_buckets(to_check, n_bits, strict, player, proc, start, N, personal_prep); - queues->wrap_up(job); + if (start) + queues->wrap_up(job); } else edabit_sacrifice_buckets(to_check, n_bits, strict, player, proc, 0, N, diff --git a/Protocols/SpdzWiseMC.h b/Protocols/SpdzWiseMC.h index 9ad76198..f48adfc9 100644 --- a/Protocols/SpdzWiseMC.h +++ 
b/Protocols/SpdzWiseMC.h @@ -36,7 +36,7 @@ public: { inner_MC.init_open(P, n); } - void prepare_open(const T& secret) + void prepare_open(const T& secret, int = -1) { inner_MC.prepare_open(secret.get_share()); } diff --git a/README.md b/README.md index 5f3fa2a4..c14f41ce 100644 --- a/README.md +++ b/README.md @@ -270,8 +270,8 @@ compute the preprocessing time for a particular computation. #### Requirements - - GCC 5 or later (tested with up to 11) or LLVM/clang 5 or later - (tested with up to 12). We recommend clang because it performs + - GCC 5 or later (tested with up to 11) or LLVM/clang 6 or later + (tested with up to 14). We recommend clang because it performs better. Note that GCC 5/6 and clang 9 don't support libOTe, so you need to deactivate its use for these compilers (see the next section). @@ -694,7 +694,7 @@ Compile the virtual machine: and the high-level program: -`./compile.py -B ` +`./compile.py -G -B ` Then run as follows: @@ -874,7 +874,7 @@ three parties, change the definition of `MAX_N_PARTIES` in In order to compile a high-level program, use `./compile.py -B`: -`./compile.py -B 32 tutorial` +`./compile.py -G -B 32 tutorial` Finally, run the two parties as follows: @@ -1004,7 +1004,7 @@ you entirely delete the definition, it will be able to run for any number of parties albeit slower. Compile the virtual machine: -`make -j 8 libote` + `make -j 8 bmr` After compiling the mpc file: @@ -1020,7 +1020,7 @@ You can benchmark the ORAM implementation as follows: 1) Edit `Program/Source/gc_oram.mpc` to change size and to choose Circuit ORAM or linear scan without ORAM. -2) Run `./compile.py -D gc_oram`. The `-D` argument instructs the +2) Run `./compile.py -G -D gc_oram`. The `-D` argument instructs the compiler to remove dead code. This is useful for more complex programs such as this one. 3) Run `gc_oram` in the virtual machines as explained above. 
diff --git a/Scripts/build.sh b/Scripts/build.sh index c541152a..1c3f7286 100755 --- a/Scripts/build.sh +++ b/Scripts/build.sh @@ -6,7 +6,8 @@ function build echo GDEBUG = >> CONFIG.mine root=`pwd` cd deps/libOTe - python3 build.py --install=$root/local -- -DENABLE_SOFTSPOKEN_OT=ON -DBUILD_SHARED_LIBS=0 $3 + rm -R out + python3 build.py --install=$root/local -- -DENABLE_SOFTSPOKEN_OT=ON -DBUILD_SHARED_LIBS=0 -DCMAKE_INSTALL_LIBDIR=lib $3 cd $root make clean rm -R static diff --git a/Scripts/compile-for-emulation.sh b/Scripts/compile-for-emulation.sh new file mode 100755 index 00000000..b808ef0c --- /dev/null +++ b/Scripts/compile-for-emulation.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +pypy3 ./compile.py -CDR 64 -K '' $* diff --git a/Scripts/emulate-append.sh b/Scripts/emulate-append.sh new file mode 100755 index 00000000..55475210 --- /dev/null +++ b/Scripts/emulate-append.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +. $(dirname $0)/run-common.sh +prog=${1%.sch} +prog=${prog##*/} +shift +$prefix ./emulate.x $prog $* 2>&1 | tee -a logs/emulate-append-$prog diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index c6835069..fe3c54e7 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -41,12 +41,21 @@ run_player() { if test "$prog"; then log_prefix=$prog- fi + if test "$BENCH"; then + log_prefix=$log_prefix$bin-$(echo "$*" | sed 's/ /-/g')-N$players- + fi set -o pipefail for i in $(seq 0 $[players-1]); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params log=$SPDZROOT/logs/$log_prefix$i $prefix $SPDZROOT/$bin $i $params 2>&1 | - { if test $i = 0; then tee $log; else cat > $log; fi; } & + { + if test "$BENCH"; then + if test $i = 0; then tee -a $log; else cat >> $log; fi; + else + if test $i = 0; then tee $log; else cat > $log; fi; + fi + } & codes[$i]=$! 
done for i in $(seq 0 $[players-1]); do diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index e58edef8..60157c93 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -41,6 +41,7 @@ function test_vm run_opts="$run_opts -B 5" export PORT=$((RANDOM%10000+10000)) +export BENCH= for dabit in ${dabit:-0 1 2}; do if [[ $dabit = 1 ]]; then diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index ce906a3a..bd5b396a 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -24,6 +24,11 @@ if test "$flags"; then if $flags | grep -q avx2; then cpu=avx2 else + if test `uname -m` != x86_64; then + echo Binaries are not available for `uname -m` + echo Use the source distribution: https://github.com/data61/MP-SPDZ/#tldr-source-distribution + exit 1 + fi cpu=amd64 fi diff --git a/Tools/ExecutionStats.cpp b/Tools/ExecutionStats.cpp index bbc36dca..daa2309a 100644 --- a/Tools/ExecutionStats.cpp +++ b/Tools/ExecutionStats.cpp @@ -26,6 +26,7 @@ void ExecutionStats::print() { sorted_stats.insert({x.second, x.first}); } + size_t total = 0; for (auto& x : sorted_stats) { auto opcode = x.second; @@ -35,7 +36,7 @@ void ExecutionStats::print() switch (opcode) { #define X(NAME, PRE, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break; - ARITHMETIC_INSTRUCTIONS + ALL_INSTRUCTIONS #undef X #define X(NAME, CODE) case NAME: cerr << #NAME; n_fill -= strlen(#NAME); break; COMBI_INSTRUCTIONS @@ -48,5 +49,7 @@ void ExecutionStats::print() for (int i = 0; i < n_fill; i++) cerr << " "; cerr << dec << calls << endl; + total += calls; } + cerr << "\tTotal:" << string(9, ' ') << total << endl; } diff --git a/Tools/names.cpp b/Tools/names.cpp index 062beb02..220263c8 100644 --- a/Tools/names.cpp +++ b/Tools/names.cpp @@ -2,4 +2,4 @@ const char* DataPositions::dtype_names[N_DTYPE + 1] = { "Triples", "Squares", "Bits", "Inverses", - "daBits", "None" }; + "daBits", "Mixed triples", "None" }; diff --git a/Utils/Check-Offline.cpp b/Utils/Check-Offline.cpp index 
203328f2..aec3c595 100644 --- a/Utils/Check-Offline.cpp +++ b/Utils/Check-Offline.cpp @@ -15,6 +15,7 @@ #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/SemiSecret.h" +#include "GC/RepPrep.h" #include "Math/Setup.h" #include "Processor/Data_Files.h" diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp index d13f79d3..a07d7427 100644 --- a/Utils/binary-example.cpp +++ b/Utils/binary-example.cpp @@ -29,6 +29,7 @@ #include "Protocols/fake-stuff.hpp" #include "Machines/ShamirMachine.hpp" #include "Machines/Rep4.hpp" +#include "Machines/Rep.hpp" template void run(int argc, char** argv); diff --git a/Utils/l2h-example.cpp b/Utils/l2h-example.cpp index 475bcb8a..91ce7f0b 100644 --- a/Utils/l2h-example.cpp +++ b/Utils/l2h-example.cpp @@ -7,6 +7,7 @@ #include "Math/gfp.hpp" #include "Machines/SPDZ.hpp" +#include "Protocols/MascotPrep.hpp" int main(int argc, char** argv) { diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 5dbb1fe7..f026f300 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -22,7 +22,7 @@ steps: - script: echo MY_CFLAGS += -DFEWER_RINGS >> CONFIG.mine - script: - echo MY_CFLAGS += -DCOMP_SEC=64 >> CONFIG.mine + echo MY_CFLAGS += -DCOMP_SEC=10 >> CONFIG.mine - script: echo CXX = clang++ >> CONFIG.mine - script: diff --git a/doc/Compiler.rst b/doc/Compiler.rst index db5c1e9c..34343c51 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -77,6 +77,13 @@ Compiler.ml module :show-inheritance: .. autofunction:: approx_sigmoid +Compiler.decision_tree module +----------------------------- + +.. automodule:: Compiler.decision_tree + :members: + :no-undoc-members: + Compiler.circuit module ----------------------- @@ -112,3 +119,13 @@ Compiler.oram module TrivialORAMIndexStructure, ValueTuple, demux, get_log_value_size, get_parallel, get_value_size, gf2nBlock, intBlock + + +Compiler.sqrt_oram module +------------------------- + +.. 
automodule:: Compiler.sqrt_oram + :members: + :no-undoc-members: + :exclude-members: LinearPositionMap, PositionMap, RecursivePositionMap, + refresh, shuffle_the_shuffle diff --git a/doc/Doxyfile b/doc/Doxyfile index 8420157a..f82046eb 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -933,7 +933,7 @@ EXCLUDE_SYMLINKS = NO # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories for example use the pattern */test/* -EXCLUDE_PATTERNS = +EXCLUDE_PATTERNS = *.d # The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names # (namespaces, classes, functions, etc.) that should be excluded from the diff --git a/doc/compilation.rst b/doc/compilation.rst index f01581ff..01753edd 100644 --- a/doc/compilation.rst +++ b/doc/compilation.rst @@ -54,6 +54,11 @@ The following options influence the computation domain: Compile for binary computation using *integer length* as default. +.. cmdoption:: -G + --garbled-circuit + + Compile for garbled circuits (does not replace :option:`-B`). + For arithmetic computation (:option:`-F`, :option:`-P`, and :option:`-R`) you can set the bit length during execution using ``program.set_bit_length(length)``. For diff --git a/doc/index.rst b/doc/index.rst index 0abbac7c..648546c8 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -6,6 +6,7 @@ If you're new to MP-SPDZ, consider the following: 1. `Quickstart tutorial `_ 2. `Implemented protocols `_ 3. :ref:`troubleshooting` +4. :ref:`io` lists all the ways of getting data in and out. .. toctree:: :maxdepth: 4 diff --git a/doc/io.rst b/doc/io.rst index a4d00cee..50128d94 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -1,3 +1,5 @@ +.. _io: + Input/Output ------------ diff --git a/doc/machine-learning.rst b/doc/machine-learning.rst index 084bc1c8..54764e37 100644 --- a/doc/machine-learning.rst +++ b/doc/machine-learning.rst @@ -5,6 +5,9 @@ MP-SPDZ supports a limited subset of the Keras interface for machine learning. 
This includes the SGD and Adam optimizers and the following layer types: dense, 2D convolution, 2D max-pooling, and dropout. +The machine learning code only works with arithmetic machines, that +is, you cannot compile it with ``-B``. + In the following we will walk through the example code in ``keras_mnist_dense.mpc``, which trains a dense neural network for MNIST. It starts by defining tensors to hold data:: diff --git a/doc/non-linear.rst b/doc/non-linear.rst index 969e6d6c..4687cc63 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -1,3 +1,5 @@ +.. _nonlinear: + Non-linear Computation ---------------------- @@ -8,14 +10,14 @@ throughout MP-SPDZ: Unknown prime modulus This approach goes back to `Catrina and de Hoogh - `_. It crucially relies on + `_. It crucially relies on the use of secret random bits in the arithmetic domain. Enough such bits allow to mask a secret value so that it is secure to reveal the masked value. This can then be split in bits as it is public. The public bits and the secret mask bits are then used to compute a number of non-linear functions. The same idea has been used to implement `fixed-point - `_ and + `_ and `floating-point `_ computation. We call this method "unknown prime modulus" because it only mandates a minimum modulus size for a given cleartext range, which diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index f76d2d5b..6a32bd37 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -140,6 +140,16 @@ This indicates an error in the internal accounting of preprocessing. Please file a bug report. +Required prime bit length is not the same as ``-F`` parameter during compilation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This is related to statistical masking that requires the prime to be a +fair bit larger than the actual "payload". The technique goes back +to `Catrina and de Hoogh +`_. 
+See also the paragraph on unknown prime moduli in :ref:`nonlinear`. + + Windows/VirtualBox performance ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~