From c597554af9cba0c891c8e4cdbd01cf290acbb937 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 6 Aug 2021 18:24:58 +1000 Subject: [PATCH] ATLAS. --- BMR/Register.h | 2 + BMR/Register.hpp | 2 +- CHANGELOG.md | 15 + CONFIG | 1 + Compiler/GC/types.py | 31 +- Compiler/allocator.py | 5 +- Compiler/comparison.py | 41 +- Compiler/compilerLib.py | 4 + Compiler/floatingpoint.py | 9 +- Compiler/instructions.py | 64 ++ Compiler/instructions_base.py | 16 +- Compiler/library.py | 53 +- Compiler/ml.py | 341 +++++++- Compiler/mpc_math.py | 2 +- Compiler/non_linear.py | 23 + Compiler/program.py | 29 +- Compiler/types.py | 769 ++++++++++++++---- FHE/NTL-Subs.cpp | 32 +- FHE/NoiseBounds.cpp | 2 + FHE/PPData.cpp | 6 +- FHE/Plaintext.cpp | 4 +- FHE/Random_Coins.h | 2 +- FHE/Ring_Element.h | 2 +- GC/AtlasSecret.cpp | 33 + GC/AtlasSecret.h | 47 ++ GC/AtlasShare.h | 72 ++ GC/CcdPrep.h | 6 +- GC/FakeSecret.h | 2 + GC/NoShare.h | 3 + GC/Secret.h | 2 + GC/ShareSecret.h | 5 + GC/ShareSecret.hpp | 6 +- GC/ShareThread.hpp | 18 +- GC/TinySecret.h | 20 +- GC/TinyShare.h | 9 +- GC/VectorProtocol.h | 1 + GC/VectorProtocol.hpp | 8 +- GC/instructions.h | 3 + Machines/TripleMachine.cpp | 18 +- Machines/atlas-party.cpp | 16 + Machines/ccd-party.cpp | 5 +- Makefile | 15 +- Math/Bit.cpp | 6 +- Math/Bit.h | 5 +- Math/BitVec.h | 3 + Math/FixedVec.h | 25 +- Math/Integer.h | 4 +- Math/Square.h | 14 +- Math/Square.hpp | 14 +- Math/ValueInterface.h | 2 +- Math/Z2k.h | 3 + Math/Z2k.hpp | 14 + Math/Zp_Data.cpp | 5 + Math/Zp_Data.h | 21 +- Math/bigint.cpp | 5 - Math/bigint.h | 12 +- Math/fixint.h | 6 +- Math/gf2n.cpp | 358 +++++--- Math/gf2n.h | 212 +++-- Math/gf2nlong.cpp | 249 +----- Math/gf2nlong.h | 205 ++--- Math/gfp.h | 2 +- Math/gfpvar.cpp | 243 ++++-- Math/gfpvar.h | 116 ++- Math/modp.h | 15 +- Math/modp.hpp | 12 + Math/square128.cpp | 2 +- Networking/CryptoPlayer.cpp | 1 + Networking/Player.cpp | 1 + OT/BitMatrix.h | 11 +- OT/BitMatrix.hpp | 6 +- OT/MamaRectangle.h | 9 +- 
OT/NPartyTripleGenerator.hpp | 47 ++ OT/OTCorrelator.hpp | 16 +- OT/OTMultiplier.hpp | 10 +- OT/Rectangle.h | 20 +- OT/Rectangle.hpp | 4 - OT/TripleMachine.h | 5 +- Processor/BaseMachine.cpp | 1 + Processor/BaseMachine.h | 1 + Processor/Instruction.h | 4 + Processor/Instruction.hpp | 42 + Processor/Machine.h | 3 + Processor/Machine.hpp | 34 +- Processor/Memory.h | 25 +- Processor/Online-Thread.hpp | 1 + Processor/OnlineMachine.hpp | 2 + Processor/OnlineOptions.cpp | 6 +- Processor/Processor.h | 3 + Processor/Processor.hpp | 27 +- Programs/Source/benchmark_net.mpc | 15 +- Programs/Source/idash_train.mpc | 10 + Programs/Source/keras_mnist_dense.mpc | 48 ++ Programs/Source/keras_mnist_dense_predict.mpc | 39 + Programs/Source/keras_mnist_lenet.mpc | 44 + Programs/Source/mnist_49.mpc | 5 +- Programs/Source/mnist_A.mpc | 4 + Programs/Source/mnist_B.mpc | 73 -- Programs/Source/mnist_D.mpc | 60 -- Programs/Source/mnist_full_A.mpc | 3 + Programs/Source/mnist_full_B.mpc | 3 + Programs/Source/mnist_full_C.mpc | 6 +- Programs/Source/mnist_full_D.mpc | 3 + Programs/Source/mnist_logreg.mpc | 3 + Protocols/Atlas.h | 71 ++ Protocols/Atlas.hpp | 132 +++ Protocols/AtlasPrep.h | 39 + Protocols/AtlasShare.h | 46 ++ Protocols/CowGearOptions.cpp | 15 +- Protocols/FakeProtocol.h | 31 +- Protocols/HemiPrep.h | 1 + Protocols/HemiPrep.hpp | 8 + Protocols/HighGearKeyGen.cpp | 21 +- Protocols/HighGearKeyGen.h | 4 +- Protocols/HighGearKeyGen.hpp | 25 + Protocols/LowGearKeyGen.cpp | 13 +- Protocols/LowGearKeyGen.h | 9 +- Protocols/LowGearKeyGen.hpp | 8 +- Protocols/MaliciousRepPrep.hpp | 9 + Protocols/MascotPrep.h | 5 + Protocols/MascotPrep.hpp | 14 + Protocols/Rep3Share2k.h | 3 + Protocols/Rep4Share2k.h | 3 + Protocols/Replicated.h | 1 + Protocols/Replicated.hpp | 6 + Protocols/ReplicatedPrep.h | 5 + Protocols/ReplicatedPrep.hpp | 66 +- Protocols/Semi2kShare.h | 2 + Protocols/Shamir.h | 11 +- Protocols/Shamir.hpp | 103 ++- Protocols/ShamirInput.h | 20 +- Protocols/ShamirInput.hpp | 14 +- 
Protocols/ShamirMC.h | 20 +- Protocols/ShamirMC.hpp | 66 +- Protocols/ShamirShare.h | 6 +- Protocols/ShareInterface.h | 3 + Protocols/ShareVector.hpp | 7 +- Protocols/SpdzWiseRingShare.h | 2 + Protocols/fake-stuff.hpp | 16 +- README.md | 36 +- Scripts/atlas.sh | 14 + Scripts/build.sh | 3 + Scripts/run-common.sh | 2 +- Scripts/test_tutorial.sh | 6 +- Tools/Buffer.h | 7 + Tools/Bundle.h | 6 + Tools/MMO.h | 4 + Tools/MMO.hpp | 31 +- Tools/random.cpp | 22 + Tools/random.h | 13 + Utils/Fake-Offline.cpp | 15 +- Utils/hyper.cpp | 22 + Yao/YaoWire.h | 2 + doc/Compiler.rst | 11 +- doc/index.rst | 3 +- doc/io.rst | 22 +- doc/machine-learning.rst | 79 ++ doc/networking.rst | 12 +- doc/troubleshooting.rst | 38 +- 159 files changed, 3667 insertions(+), 1378 deletions(-) create mode 100644 GC/AtlasSecret.cpp create mode 100644 GC/AtlasSecret.h create mode 100644 GC/AtlasShare.h create mode 100644 Machines/atlas-party.cpp create mode 100644 Programs/Source/keras_mnist_dense.mpc create mode 100644 Programs/Source/keras_mnist_dense_predict.mpc create mode 100644 Programs/Source/keras_mnist_lenet.mpc delete mode 100644 Programs/Source/mnist_B.mpc delete mode 100644 Programs/Source/mnist_D.mpc create mode 100644 Protocols/Atlas.h create mode 100644 Protocols/Atlas.hpp create mode 100644 Protocols/AtlasPrep.h create mode 100644 Protocols/AtlasShare.h create mode 100755 Scripts/atlas.sh create mode 100644 Utils/hyper.cpp create mode 100644 doc/machine-learning.rst diff --git a/BMR/Register.h b/BMR/Register.h index 50a4cb67..886155d7 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -272,6 +272,8 @@ public: // only true for evaluation static const bool actual_inputs = false; + static int threshold(int) { throw not_implemented(); } + static Register new_reg(); static Register tmp_reg() { return new_reg(); } static Register and_reg() { return new_reg(); } diff --git a/BMR/Register.hpp b/BMR/Register.hpp index e4f743cb..bd214a85 100644 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ 
-107,7 +107,7 @@ void EvalRegister::store(GC::Memory& mem, //cout << "ext:" << ext << "/" << (int)reg.get_external() << " " << endl; tmp = spdz_wire.mask + U::constant(ext, (int)party.get_id() - 1, party.get_mac_key()); S.push_back(tmp); - tmp *= gf2n_long(1) << i; + tmp <<= i; dest += tmp; const Key& key = reg.external_key(party.get_id()); Key& expected_key = spdz_wire.my_keys[(int)reg.get_external()]; diff --git a/CHANGELOG.md b/CHANGELOG.md index a71b926f..7a97b0f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.6 (Aug 6, 2021) + +- [ATLAS](https://eprint.iacr.org/2021/833) +- Keras-like interface +- Iterative linear solution approximation +- Binary output +- HighGear/LowGear key generation for wider range of parameters by default +- Dabit generation for smaller primes and malicious security +- More consistent type model +- Improved local computation +- Optimized GF(2^8) for CCD +- NTL only needed for computation with GF(2^40) +- Virtual machines suggest compile-time optimizations +- Improved documentation of types + ## 0.2.5 (Jul 2, 2021) - Training of convolutional neural networks diff --git a/CONFIG b/CONFIG index 54370858..5f12d2c3 100644 --- a/CONFIG +++ b/CONFIG @@ -69,6 +69,7 @@ LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) LDLIBS += -lboost_system -lssl -lcrypto ifeq ($(USE_NTL),1) +CFLAGS += -DUSE_NTL LDLIBS := -lntl $(LDLIBS) endif diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 7ff1e8b5..77b162bd 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -118,6 +118,8 @@ class bits(Tape.Register, _structure, _bit): self.store_inst[isinstance(address, int)](self, address) @classmethod def new(cls, value=None, n=None): + if util.is_constant(value): + n = value.bit_length() return cls.get_type(n)(value) def __init__(self, value=None, 
n=None, size=None): assert n == self.n or n is None @@ -152,7 +154,7 @@ class bits(Tape.Register, _structure, _bit): and self.n == other.n: for i in range(math.ceil(self.n / self.unit)): self.mov(self[i], other[i]) - elif isinstance(other, sint): + elif isinstance(other, sint) and isinstance(self, sbits): self.mov(self, sbitvec(other, self.n).elements()[0]) else: try: @@ -214,7 +216,13 @@ class cbits(bits): cls.conv_cint_vec(cint(other, size=other.size), res) types = {} def load_int(self, value): - self.load_other(regint(value)) + if self.n <= 64: + tmp = regint(value) + elif value == self.long_one(): + tmp = cint(1, size=self.n) + else: + raise CompilerError('loading long integers to cbits not supported') + self.load_other(tmp) def store_in_dynamic_mem(self, address): inst.stmsdci(self, cbits.conv(address)) def clear_op(self, other, c_inst, ci_inst, op): @@ -227,7 +235,7 @@ class cbits(bits): else: if util.is_constant(other): if other >= 2**31 or other < -2**31: - return op(self, cbits(other)) + return op(self, cbits.new(other)) res = cbits.get_type(max(self.n, len(bin(other)) - 2))() ci_inst(res, self, other) return res @@ -269,6 +277,8 @@ class cbits(bits): res = cbits.get_type(self.n+other)() inst.shlcbi(res, self, other) return res + def __invert__(self): + return self ^ self.long_one() def print_reg(self, desc=''): inst.print_regb(self, desc) def print_reg_plain(self): @@ -527,6 +537,14 @@ class sbits(bits): return res @staticmethod def bit_adder(*args, **kwargs): + """ Binary adder in binary circuits. 
+ + :param a: summand (list of 0/1 in compatible type) + :param b: summand (list of 0/1 in compatible type) + :param carry_in: input carry (default 0) + :param get_carry: add final carry to output + :returns: list of 0/1 in relevant type + """ return sbitint.bit_adder(*args, **kwargs) @staticmethod def ripple_carry_adder(*args, **kwargs): @@ -889,7 +907,7 @@ sbits.dynamic_array = DynamicArray cbits.dynamic_array = Array def _complement_two_extend(bits, k): - return bits + [bits[-1]] * (k - len(bits)) + return bits[:k] + [bits[-1]] * (k - len(bits)) class _sbitintbase: def extend(self, n): @@ -1096,7 +1114,6 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): raise CompilerError('round to nearest not implemented') if not isinstance(other, sbitintvec): other = sbitintvec(other) - assert len(self.v) + len(other.v) <= k a = self.get_type(k).from_vec(_complement_two_extend(self.v, k)) b = self.get_type(k).from_vec(_complement_two_extend(other.v, k)) tmp = a * b @@ -1148,6 +1165,10 @@ class sbitfix(_fix): sub: 0.199997 lt: 0 + Note that the default precision (16 bits after the dot, 31 bits in + total) only allows numbers up to :math:`2^{31-16-1} \\approx + 16000`. You can increase this using :py:func:`set_precision`. 
+ """ float_type = type(None) clear_type = cbitfix diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 674ecff1..45c60e3d 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -167,8 +167,9 @@ class StraightlineAllocator: def finalize(self, options): for reg in self.alloc: - for x in reg.vector: - if x not in self.dealloc and reg not in self.dealloc: + for x in reg.get_all(): + if x not in self.dealloc and reg not in self.dealloc \ + and len(x.duplicates) == 1: print('Warning: read before write at register', x) print('\tregister trace: %s' % format_trace(x.caller, '\t\t')) diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 0fe62101..f4cf89ad 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -7,7 +7,7 @@ representing the integer bit length, and kappa the statistical security parameter. Most of these routines were implemented before the cint/sint classes, so use -the old-fasioned Register class and assembly instructions instead of operator +the old-fashioned Register class and assembly instructions instead of operator overloading. 
The PreMulC function has a few variants, depending on whether @@ -61,18 +61,13 @@ def ld2i(c, n): t1 = t2 movc(c, t1) -inverse_of_two = {} - -def divide_by_two(res, x, m=1): - """ Faster clear division by two using a cached value of 2^-1 mod p """ - tmp = program.curr_block.new_reg('c') - inv2m(tmp, m) - mulc(res, x, tmp) - def require_ring_size(k, op): if int(program.options.ring) < k: - raise CompilerError('ring size too small for %s, compile ' - 'with \'-R %d\' or more' % (op, k)) + msg = 'ring size too small for %s, compile ' \ + 'with \'-R %d\' or more' % (op, k) + if k > 64 and k < 128: + msg += ' (maybe \'-R 128\' as it is supported by default)' + raise CompilerError(msg) program.curr_tape.require_bit_length(k) @instructions_base.cisc @@ -122,20 +117,11 @@ def Trunc(d, a, k, m, kappa, signed): m: compile-time integer signed: True/False, describes a """ - t = program.curr_block.new_reg('s') - c = [program.curr_block.new_reg('c') for i in range(3)] - c2m = program.curr_block.new_reg('c') if m == 0: movs(d, a) return - elif program.options.ring: - return TruncRing(d, a, k, m, signed) else: - a_prime = program.non_linear.mod2m(a, k, m, signed) - subs(t, a, a_prime) - ldi(c[1], 1) - divide_by_two(c[2], c[1], m) - mulm(d, t, c[2]) + movs(d, program.non_linear.trunc(a, k, m, kappa, signed)) def TruncRing(d, a, k, m, signed): program.curr_tape.require_bit_length(1) @@ -489,13 +475,12 @@ def BitLTL(res, a, b, kappa): """ k = len(b) a_bits = b[0].bit_decompose_clear(a, k) - s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)] - t = [program.curr_block.new_reg('s') for i in range(1)] - for i in range(len(b)): - s[0][i] = b[0].long_one() - b[i] - CarryOut(t[0], a_bits[::-1], s[0][::-1], b[0].long_one(), kappa) - subsfi(res, t[0], 1) - return a_bits, s[0] + from .types import sint + movs(res, sint.conv(BitLTL_raw(a_bits, b))) + +def BitLTL_raw(a_bits, b): + s = [x.bit_not() for x in b] + return CarryOutRaw(a_bits[::-1], s[::-1], 
b[0].long_one()).bit_not() def PreMulC_with_inverses_and_vectors(p, a): """ diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 79f3b2a8..1e2bd351 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -15,6 +15,10 @@ def run(args, options): if options.binary: VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) VARS['sfix'] = GC_types.sbitfixvec + for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \ + 'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \ + 'squant': + del VARS[i] print('Compiling file', prog.infile) diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index 16858acb..b6b5ef83 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -264,8 +264,8 @@ def BitDec(a, k, m, kappa, bits_to_compute=None): return program.Program.prog.non_linear.bit_dec(a, k, m) def BitDecRingRaw(a, k, m): + comparison.require_ring_size(m, 'bit decomposition') n_shift = int(program.Program.prog.options.ring) - m - assert(n_shift >= 0) if program.Program.prog.use_split(): x = a.split_to_two_summands(m) bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False) @@ -504,7 +504,8 @@ def TruncPrRing(a, k, m, signed=True): return comparison.TruncLeakyInRing(a, k, m, signed=signed) else: from .types import sint - if signed: + prog = program.Program.prog + if signed and prog.use_trunc_pr != -1: a += (1 << (k - 1)) if program.Program.prog.use_trunc_pr: res = sint() @@ -530,7 +531,7 @@ def TruncPrRing(a, k, m, signed=True): overflow = msb.bit_xor(masked >> (n_ring - 1)) res = shifted - upper + \ (overflow << (k - m)) - if signed: + if signed and prog.use_trunc_pr != -1: res -= (1 << (k - m - 1)) return res @@ -672,7 +673,7 @@ def BitDecFull(a, maybe_mixed=False): t = bbits[0].bit_decompose_clear(p - c, bit_length) c = longint(c, bit_length) czero = (c==0) - q = bbits[0].long_one() - BITLT(bbits, t, bit_length) + q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t) 
fbar = [bbits[0].clear_type.conv(cint(x)) for x in ((1< 64: - bits = 127 - else: - # assume 64-bit machine - bits = 63 - if self.args[2] > bits: - raise CompilerError('Shifting by more than %d bits ' - 'not implemented' % bits) - elif self.args[2] < 0: + if self.args[2] < 0: raise CompilerError('negative shift') ### diff --git a/Compiler/library.py b/Compiler/library.py index 2a186f56..f169c796 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -875,6 +875,11 @@ def for_range(start, stop=None, step=None): a[i] = i global x x += 1 + + Note that you cannot overwrite data structures such as + :py:class:`~Compiler.types.Array` in a loop even when using + :py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign` + instead. """ def decorator(loop_body): range_loop(loop_body, start, stop, step) @@ -889,7 +894,7 @@ def for_range_parallel(n_parallel, n_loops): the optimization. :param n_parallel: compile-time (int) - :param n_loops: regint/cint/int + :param n_loops: regint/cint/int or list of int Example: @@ -898,7 +903,18 @@ def for_range_parallel(n_parallel, n_loops): @for_range_parallel(n_parallel, n_loops) def _(i): a[i] = a[i] * a[i] + + Multidimensional ranges are supported as well. The following + executes ``f(0, 0)`` to ``f(4, 2)``, two calls in parallel. + + .. code:: + + @for_range_parallel(2, [5, 3]) + def f(i, j): + ... """ + if isinstance(n_loops, (list, tuple)): + return for_range_multithread(None, n_parallel, n_loops) return map_reduce_single(n_parallel, n_loops) def for_range_opt(n_loops, budget=None): @@ -922,7 +938,18 @@ def for_range_opt(n_loops, budget=None): def _(i): ... + Multidimensional ranges are supported as well. The following + executes ``f(0, 0)`` to ``f(4, 2)`` in parallel according to + the budget. + + .. code:: + + @for_range_opt([5, 3]) + def f(i, j): + ... 
""" + if isinstance(n_loops, (list, tuple)): + return for_range_opt_multithread(None, n_loops) return map_reduce_single(None, n_loops, budget=budget) def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], @@ -961,9 +988,13 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [], @for_range(loop_rounds) def f(i): state = tuplify(initializer()) + start_block = get_block() for k in range(n_parallel): j = i * n_parallel + k state = reducer(tuplify(loop_body(j)), state) + if n_parallel > 1 and start_block != get_block(): + print('WARNING: parallelization broken ' + 'by control flow instruction') r = reducer(mem_state, state) write_state_to_memory(r) else: @@ -1222,6 +1253,24 @@ def map_sum(n_threads, n_parallel, n_loops, n_items, value_types): return tuple(a + b for a,b in zip(x,y)) return map_reduce(n_threads, n_parallel, n_loops, initializer, summer) +def map_sum_opt(n_threads, n_loops, types): + """ Multi-threaded sum reduction. The following computes a sum of + ten squares in three threads:: + + @map_sum_opt(3, 10, [sint]) + def summer(i): + return sint(i) ** 2 + + result = summer() + + :param n_threads: number of threads (int) + :param n_loops: number of loop runs (regint/cint/int) + :param types: return type, must match the return statement + in the loop + + """ + return map_sum(n_threads, None, n_loops, len(types), types) + def tree_reduce_multithread(n_threads, function, vector): inputs = vector.Array(len(vector)) inputs.assign_vector(vector) @@ -1326,6 +1375,8 @@ def _run_and_link(function, g=None): for name, var in pre.items(): if isinstance(var, (program.Tape.Register, _single, _vec)): new_var = g[name] + if util.is_constant_float(new_var): + raise CompilerError('cannot reassign constants in blocks') if id(new_var) != id(var): new_var.link(var) return res diff --git a/Compiler/ml.py b/Compiler/ml.py index d2c3d6e1..9e147a2c 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1,7 +1,7 @@ """ This module contains machine learning 
functionality. It is work in progress, so you must expect things to change. The only tested -functionality for training is using consective layers. +functionality for training is using consecutive layers. This includes logistic regression. It can be run as follows:: @@ -15,7 +15,7 @@ follows:: This loads measurements from party 0 and labels (0/1) from party 1. After running, the model is stored in :py:obj:`sgd.layers[0].W` and -:py:obj:`sgd.layers[1].b`. The :py:obj:`approx` parameter determines +:py:obj:`sgd.layers[0].b`. The :py:obj:`approx` parameter determines whether to use an approximate sigmoid function. Setting it to 5 uses a five-piece approximation instead of a three-piece one. @@ -172,6 +172,7 @@ def _no_mem_warnings(function): res = function(*args, **kwargs) get_program().warn_about_mem.pop() return res + copy_doc(wrapper, function) return wrapper class Tensor(MultiArray): @@ -265,6 +266,7 @@ class Output(NoVariableLayer): self.weights = None self.approx = approx self.compute_loss = True + self.d_out = 1 def divisor(self, divisor, size): return cfix(1.0 / divisor, size=size) @@ -303,7 +305,6 @@ class Output(NoVariableLayer): def _(base, size): diff = self.eval(size, base) - \ self.Y.get(batch.get_vector(base, size)) - assert sfix.f == cfix.f if self.weights is not None: assert N == len(self.weights) diff *= self.weights.get_vector(base, size) @@ -318,6 +319,7 @@ class Output(NoVariableLayer): print_ln('batch %s', batch.reveal_nested()) def set_weights(self, weights): + assert sfix.f == cfix.f self.weights = cfix.Array(len(weights)) self.weights.assign(weights) self.weight_total = sum(weights) @@ -1366,7 +1368,7 @@ class Conv2d(ConvBase): padding_h, padding_w = self.padding if self.use_conv2ds: - n_parts = max(1, round(self.n_threads / n_channels_out)) + n_parts = max(1, round((self.n_threads or 1) / n_channels_out)) while len(batch) % n_parts != 0: n_parts -= 1 print('Convolution in %d parts' % n_parts) @@ -1763,17 +1765,20 @@ class Optimizer: 
always_shuffle = True time_layers = False revealing_correctness = False + early_division = False @staticmethod def from_args(program, layers): if 'adam' in program.args or 'adamapprox' in program.args: - return Adam(layers, 1, approx='adamapprox' in program.args) + res = Adam(layers, 1, approx='adamapprox' in program.args) elif 'amsgrad' in program.args: - return Adam(layers, approx=True, amsgrad=True) + res = Adam(layers, approx=True, amsgrad=True) elif 'quotient' in program.args: - return Adam(layers, approx=True, amsgrad=True, normalize=True) + res = Adam(layers, approx=True, amsgrad=True, normalize=True) else: - return SGD(layers, 1) + res = SGD(layers, 1) + res.early_division = 'early_div' in program.args + return res def __init__(self, report_loss=None): self.tol = 0.000 @@ -1794,11 +1799,13 @@ class Optimizer: def layers(self, layers): """ Construct linear graph from list of layers. """ self._layers = layers + self.thetas = [] prev = None for layer in layers: if not layer.inputs and prev is not None: layer.inputs = [prev] prev = layer + self.thetas.extend(layer.thetas()) def set_layers_with_inputs(self, layers): """ Construct graph from :py:obj:`inputs` members of list of layers. """ @@ -1855,12 +1862,22 @@ class Optimizer: theta.delete() @_no_mem_warnings - def eval(self, data): - """ Compute evaluation after training. """ - N = len(data) - self.layers[0].X.assign(data) - self.forward(N) - return self.layers[-1].eval(N) + def eval(self, data, batch_size=None): + """ Compute evaluation after training. 
+ + :param data: sample data (:py:class:`Compiler.types.Matrix` with one row per sample) + """ + if isinstance(self.layers[-1].Y, Array): + res = sfix.Array(len(data)) + else: + res = sfix.Matrix(len(data), self.layers[-1].d_out) + def f(start, batch_size, batch): + batch.assign_vector(regint.inc(batch_size, start)) + self.forward(batch=batch) + part = self.layers[-1].eval(batch_size) + res.assign_part_vector(part.get_vector(), start) + self.run_in_batches(f, data, batch_size or len(self.layers[1].X)) + return res @_no_mem_warnings def backward(self, batch): @@ -1877,6 +1894,9 @@ class Optimizer: if len(layer.inputs) == 1: layer.inputs[0].nabla_Y.address = \ layer.nabla_X.address + if i == len(self.layers) - 1 and self.early_division: + layer.nabla_X.assign_vector( + layer.nabla_X.get_vector() / len(batch)) if self.time_layers: stop_timer(200 + i) @@ -1963,36 +1983,41 @@ class Optimizer: return res def reveal_correctness(self, data, truth, batch_size): - training_data = self.layers[0].X.address - training_truth = self.layers[-1].Y.address - self.layers[0].X.address = data.address - self.layers[-1].Y.address = truth.address N = data.sizes[0] - batch = regint.Array(batch_size) n_correct = MemValue(0) loss = MemValue(sfix(0)) - def f(start, batch_size): + def f(start, batch_size, batch): batch.assign_vector(regint.inc(batch_size, start)) self.forward(batch=batch) part_truth = truth.get_part(start, batch_size) n_correct.iadd( self.layers[-1].reveal_correctness(batch_size, part_truth)) loss.iadd(self.layers[-1].l * batch_size) - @for_range(N // batch_size) - def _(i): - start = i * batch_size - f(start, batch_size) - batch_size = N % batch_size - if batch_size: - start = N - batch_size - f(start, batch_size) - self.layers[0].X.address = training_data - self.layers[-1].Y.address = training_truth + self.run_in_batches(f, data, batch_size) loss = loss.reveal() if cfix.f < 31: loss = cfix._new(loss.v << (31 - cfix.f), k=63, f=31) return n_correct, loss / N + def 
run_in_batches(self, f, data, batch_size, truth=None): + training_data = self.layers[0].X.address + training_truth = self.layers[-1].Y.address + self.layers[0].X.address = data.address + if truth: + self.layers[-1].Y.address = truth.address + N = data.sizes[0] + batch = regint.Array(batch_size) + @for_range(N // batch_size) + def _(i): + start = i * batch_size + f(start, batch_size, batch) + batch_size = N % batch_size + if batch_size: + start = N - batch_size + f(start, batch_size, batch) + self.layers[0].X.address = training_data + self.layers[-1].Y.address = training_truth + @_no_mem_warnings def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, acc_batch_size=None): @@ -2048,11 +2073,14 @@ class Optimizer: print_ln('train_acc: %s (%s/%s)', cfix(self.n_correct, k=63, f=31) / n_trained, self.n_correct, n_trained) - n_test = len(test_Y) - n_correct, loss = self.reveal_correctness(test_X, test_Y, acc_batch_size) - print_ln('test loss: %s', loss) - print_ln('acc: %s (%s/%s)', cfix(n_correct, k=63, f=31) / n_test, - n_correct, n_test) + if test_X and test_Y: + n_test = len(test_Y) + n_correct, loss = self.reveal_correctness(test_X, test_Y, + acc_batch_size) + print_ln('test loss: %s', loss) + print_ln('acc: %s (%s/%s)', + cfix(n_correct, k=63, f=31) / n_test, + n_correct, n_test) if acc_first: start_timer(1) self.run(batch_size) @@ -2062,6 +2090,10 @@ class Optimizer: int(n_test // self.layers[-1].n_outputs * 1.2))) def _(): self.gamma.imul(.5) + if 'crash' in program.args: + @if_(self.gamma == 0) + def _(): + runtime_error('diverging') self.reset() print_ln('reset after reducing learning rate to %s', self.gamma) @@ -2110,7 +2142,6 @@ class Adam(Optimizer): self.ms = [] self.vs = [] self.gs = [] - self.thetas = [] self.vhats = [] for layer in layers: for nabla in layer.nablas(): @@ -2119,8 +2150,6 @@ class Adam(Optimizer): x.append(nabla.same_shape()) if amsgrad: self.vhats.append(nabla.same_shape()) - for theta in layer.thetas(): - 
self.thetas.append(theta) super(Adam, self).__init__() @@ -2177,12 +2206,10 @@ class SGD(Optimizer): self.momentum = 0.9 self.layers = layers self.n_epochs = n_epochs - self.thetas = [] self.nablas = [] self.delta_thetas = [] for layer in layers: self.nablas.extend(layer.nablas()) - self.thetas.extend(layer.thetas()) for theta in layer.thetas(): self.delta_thetas.append(theta.same_shape()) self.gamma = MemValue(cfix(0.01)) @@ -2220,10 +2247,13 @@ class SGD(Optimizer): # divide by len(batch) by truncation # increased rate if len(batch) is not a power of two pre_trunc = nabla_vector.v * rate.v - k = nabla_vector.k + rate.k + k = max(nabla_vector.k, rate.k) + rate.f m = rate.f + int(log_batch_size) - v = pre_trunc.round(k, m, signed=True, - nearest=sfix.round_nearest) + if self.early_division: + v = pre_trunc + else: + v = pre_trunc.round(k, m, signed=True, + nearest=sfix.round_nearest) new = nabla_vector._new(v) diff = red_old - new delta_theta.assign_vector(diff, base) @@ -2265,3 +2295,228 @@ class SGD(Optimizer): print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index, aa[1][index], aa[0][index], aa[2][index]) self.gamma.imul(1 - 10 ** - 6) + +def apply_padding(input_shape, kernel_size, strides, padding): + if padding == 'valid': + return (input_shape[0] - kernel_size[0] + 1) // strides[0], \ + (input_shape[1] - kernel_size[1] + 1) // strides[1], + elif padding == 'same': + return (input_shape[1]) // strides[0], \ + (input_shape[2]) // strides[1], + else: + raise Exception('invalid padding: ' + padding) + +class keras: + class layers: + Flatten = lambda *args, **kwargs: ('flatten', args, kwargs) + Dense = lambda *args, **kwargs: ('dense', args, kwargs) + + def Conv2D(filters, kernel_size, strides=(1, 1), padding='valid', + activation=None): + return 'conv2d', {'filters': filters, 'kernel_size': kernel_size, + 'strides': strides, 'padding': padding, + 'activation': activation} + + def MaxPooling2D(pool_size=2, strides=None, padding='valid'): + return 
'maxpool', {'pool_size': pool_size, 'strides': strides, + 'padding': padding} + + def Dropout(rate): + l = math.log(rate, 2) + if int(l) != l: + raise Exception('rate needs to be a power of two') + return 'dropout', rate + + class optimizers: + SGD = lambda *args, **kwargs: ('sgd', args, kwargs) + Adam = lambda *args, **kwargs: ('adam', args, kwargs) + + class models: + class Sequential: + def __init__(self, layers): + self.layers = layers + self.optimizer = None + self.opt = None + + def compile(self, optimizer): + self.optimizer = optimizer + + @property + def trainable_variables(self): + if self.opt == None: + raise Exception('need to run build() or fit() first') + return list(self.opt.thetas) + + def build(self, input_shape, batch_size=128): + if self.opt != None and \ + input_shape == self.opt.layers[0].X.sizes and \ + batch_size <= self.batch_size and \ + type(self.opt).__name__.lower() == self.optimizer[0]: + return + if self.optimizer == None: + self.optimizer = 'inference', [], {} + if input_shape == None: + raise Exception('must specify number of samples') + Layer.back_batch_size = batch_size + layers = [] + for i, layer in enumerate(self.layers): + name = layer[0] + if name == 'dense': + if len(layers) == 0: + N = input_shape[0] + n_units = reduce(operator.mul, input_shape[1:]) + else: + N = batch_size + n_units = reduce(operator.mul, + layers[-1].Y.sizes[1:]) + if i == len(self.layers) - 1: + if layer[2].get('activation', 'softmax') in \ + ('softmax', 'sigmoid'): + del layer[2]['activation'] + layers.append(Dense(N, n_units, layer[1][0], + **layer[2])) + elif name == 'conv2d': + if len(layers) != 0: + input_shape = layers[-1].Y.sizes + input_shape = list(input_shape) + \ + [1] * (4 - len(input_shape)) + print (layer[1]) + kernel_size = layer[1]['kernel_size'] + filters = layer[1]['filters'] + strides = layer[1]['strides'] + padding = layer[1]['padding'] + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if 
isinstance(strides, int): + strides = (strides, strides) + weight_shape = [filters] + list(kernel_size) + \ + [input_shape[-1]] + output_shape = [batch_size] + list( + apply_padding(input_shape[1:3], kernel_size, + strides, padding)) + [filters] + layers.append(FixConv2d(input_shape, weight_shape, + (filters,), output_shape, + strides, padding.upper())) + elif name == 'maxpool': + pool_size = layer[1]['pool_size'] + strides = layer[1]['strides'] + padding = layer[1]['padding'] + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + if strides == None: + strides = pool_size + layers.append(MaxPool(layers[-1].Y.sizes, + [1] + list(strides) + [1], + [1] + list(pool_size) + [1], + padding.upper())) + elif name == 'dropout': + layers.append(Dropout(batch_size, reduce( + operator.mul, layers[-1].Y.sizes[1:]), + alpha=layer[1])) + elif name == 'flatten': + pass + else: + raise Exception(layer[0] + ' not supported') + if layers[-1].d_out == 1: + layers.append(Output(input_shape[0])) + else: + layers.append(MultiOutput(input_shape[0], layers[-1].d_out)) + if self.optimizer[1]: + raise Exception('use keyword arguments for optimizer') + opt = self.optimizer[0] + opts = self.optimizer[2] + if opt == 'sgd': + opt = SGD(layers, 1) + momentum = opts.pop('momentum', None) + if momentum != None: + opt.momentum = momentum + elif opt == 'adam': + opt = Adam(layers, amsgrad=opts.pop('amsgrad', None), + approx=True) + beta1 = opts.pop('beta_1', None) + beta2 = opts.pop('beta_2', None) + epsilon = opts.pop('epsilon', None) + if beta1 != None: + opt.beta1 = beta1 + if beta2: + opt.beta2 = beta2 + if epsilon: + if epsilon < opt.epsilon: + print('WARNING: epsilon smaller than default might ' + 'cause overflows') + opt.epsilon = epsilon + elif opt == 'inference': + opt = Optimizer() + opt.layers = layers + else: + raise Exception(opt + ' not supported') + lr = opts.pop('learning_rate', None) + if lr != None: + 
opt.gamma = MemValue(cfix(lr)) + if opts: + raise Exception(opts + ' not supported') + self.batch_size = batch_size + self.opt = opt + + def fit(self, x, y, batch_size, epochs=1, validation_data=None): + assert len(x) == len(y) + self.build(x.sizes, batch_size) + if x.total_size() != self.opt.layers[0].X.total_size(): + raise Exception('sample data size mismatch') + if y.total_size() != self.opt.layers[-1].Y.total_size(): + print (y, layers[-1].Y) + raise Exception('label size mismatch') + if validation_data == None: + validation_data = None, None + else: + if len(validation_data[0]) != len(validation_data[1]): + raise Exception('test set size mismatch') + self.opt.layers[0].X.address = x.address + self.opt.layers[-1].Y.address = y.address + self.opt.run_by_args(get_program(), epochs, batch_size, + validation_data[0], validation_data[1], + batch_size) + return self.opt + + def predict(self, x, batch_size=None): + if self.opt == None: + raise Exception('need to run fit() or build() first') + if batch_size != None: + batch_size = min(batch_size, self.batch_size) + return self.opt.eval(x, batch_size=batch_size) + +def solve_linear(A, b, n_iterations, debug=False): + """ Iterative linear solution approximation. """ + assert len(b) == A.sizes[0] + x = sfix.Array(A.sizes[1]) + x.assign_vector(sfix.get_random(-1, 1, size=len(x))) + At = A.transpose() + @for_range(n_iterations) + def _(i): + r = At * (b - A * x) + tmp = A * r + tmp = sfix.dot_product(tmp, tmp) + alpha = (tmp == 0).if_else(0, sfix.dot_product(r, r) / tmp) + x.assign(x + alpha * r) + if debug: + print_ln('%s r=%s tmp=%s r*r=%s tmp*tmp=%s alpha=%s x=%s alpha*r=%s', i, + list(r.reveal()), list(tmp.reveal()), + sfix.dot_product(r, r).reveal(), sfix.dot_product(tmp, tmp).reveal(), + alpha.reveal(), x.reveal_list(), list((alpha * r).reveal())) + return x + +def mr(A, n_iterations): + """ Iterative matrix inverse approximation. 
""" + assert len(A.sizes) == 2 + assert A.sizes[0] == A.sizes[1] + M = A.same_shape() + n = A.sizes[0] + @for_range(n) + def _(i): + e = sfix.Array(n) + e.assign_all(0) + e[i] = 1 + M[i] = solve_linear(A, e, n_iterations) + return M.transpose() diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 87d11def..2c3ec450 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -3,7 +3,7 @@ Module for math operations. Implements trigonometric and logarithmic functions. -This has to imported explicitely. +This has to imported explicitly. """ diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 3ef709ca..43e10c2e 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -38,6 +38,12 @@ class NonLinear: signed) return res + def trunc(self, a, k, m, kappa, signed): + self.check_security(kappa) + if m == 0: + return a + return self._trunc(a, k, m, signed) + class Masking(NonLinear): def eqz(self, a, k): c, r = self._mask(a, k) @@ -71,6 +77,12 @@ class Prime(Masking): def _trunc_pr(self, a, k, m, signed=None): return TruncPrField(a, k, m, self.kappa) + def _trunc(self, a, k, m, signed=None): + a_prime = self.mod2m(a, k, m, signed) + tmp = cint() + inv2m(tmp, m) + return (a - a_prime) * tmp + def bit_dec(self, a, k, m, maybe_mixed=False): if maybe_mixed: return BitDecFieldRaw(a, k, m, self.kappa) @@ -94,6 +106,14 @@ class KnownPrime(NonLinear): # nearest truncation return self.trunc_round_nearest(a, k, m, signed) + def _trunc(self, a, k, m, signed=None): + if signed: + a += cint(1) << (k - 1) + res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) + if signed: + res -= cint(1) << (k - 1 - m) + return res + def trunc_round_nearest(self, a, k, m, signed): a += cint(1) << (m - 1) if signed: @@ -133,6 +153,9 @@ class Ring(Masking): def _trunc_pr(self, a, k, m, signed): return TruncPrRing(a, k, m, signed=signed) + def _trunc(self, a, k, m, signed=None): + return comparison.TruncRing(None, a, k, m, signed=signed) + def bit_dec(self, a, k, m, 
maybe_mixed=False): if maybe_mixed: return BitDecRingRaw(a, k, m) diff --git a/Compiler/program.py b/Compiler/program.py index b004c48c..4fe3673a 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -149,6 +149,7 @@ class Program(object): self._always_raw = False self._linear_rounds = False self.warn_about_mem = [True] + self.relevant_opts = set() Program.prog = self from . import instructions_base, instructions, types, comparison instructions.program = self @@ -329,11 +330,13 @@ class Program(object): sch_file.write(' '.join(sys.argv) + '\n') req = max(x.req_bit_length['p'] for x in self.tapes) if self.options.ring: - sch_file.write('R:%s' % (self.options.ring if req else 0)) + sch_file.write('R:%s' % self.options.ring) elif self.options.prime: sch_file.write('p:%s' % self.options.prime) else: sch_file.write('lgp:%s' % req) + sch_file.write('\n') + sch_file.write('opts: %s\n' % ' '.join(self.relevant_opts)) for tape in self.tapes: tape.write_bytes() @@ -469,6 +472,16 @@ class Program(object): self.tape_counter += 1 return res + @property + def use_trunc_pr(self): + if not self._use_trunc_pr: + self.relevant_opts.add('trunc_pr') + return self._use_trunc_pr + + @use_trunc_pr.setter + def use_trunc_pr(self, change): + self._use_trunc_pr = change + def use_edabit(self, change=None): """ Setting whether to use edaBits for non-linear functionality (default: false). @@ -477,6 +490,8 @@ class Program(object): :returns: setting if :py:obj:`change` is :py:obj:`None` """ if change is None: + if not self._edabit: + self.relevant_opts.add('edabit') return self._edabit else: self._edabit = change @@ -492,6 +507,8 @@ class Program(object): :returns: setting if :py:obj:`change` is :py:obj:`None` """ if change is None: + if not self._split: + self.relevant_opts.add('split') return self._split else: if change and not self.options.ring: @@ -527,12 +544,14 @@ class Program(object): """ Set a number of options from the command-line arguments. 
""" if 'trunc_pr' in self.args: self.use_trunc_pr = True + if 'signed_trunc_pr' in self.args: + self.use_trunc_pr = -1 if 'split' in self.args or 'split3' in self.args: self.use_split(3) - if 'split4' in self.args: - self.use_split(4) - if 'split2' in self.args: - self.use_split(2) + for arg in self.args: + m = re.match('split([0-9]+)', arg) + if m: + self.use_split(int(m.group(1))) if 'raw' in self.args: self.always_raw(True) if 'edabit' in self.args: diff --git a/Compiler/types.py b/Compiler/types.py index e0453179..8b398f61 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1,5 +1,5 @@ """ -This module defines all types availabe in high-level programs. These +This module defines all types available in high-level programs. These include basic types such as secret integers or floating-point numbers and container types. A single instance of the former uses one or more so-called registers in the virtual machine while the latter use the @@ -20,30 +20,12 @@ correct signature. Basic types ----------- -Basic types contain many special methods such as :py:func:`__add__`. This is -used for operator overloading in Python. It is not recommend to use -them, use the plain operators instead, such as ``+`` instead of -:py:func:`__add__`. See -https://docs.python.org/3/reference/datamodel.html#special-method-names -for a translation to operators. +All basic can be used as vectors, that is one instance representing +several values, with all operations being executed element-wise. For +example, the following computes ten multiplications of integers input +by party 0 and 1:: -In some operations such as -secure comparison, the secure computation protocols allows for more -parameters than just the operands which influence the performance. In -this case, we provide an alias for better code readability. For -example, :meth:`sint.greater_than` is an alias of -:py:meth:`sint.__gt__`. When using operator overloading, the -parameters default to the globally defined ones. 
- -Methods of basic types generally return instances of the respective type. - -Note that the data model of Python operates with reverse operators -such as :py:func:`__radd__`. This means that if for the usual operator of the -first operand does not support the second operand, the reverse -operator of the second operand is used. For example, -:py:meth:`_clear.__sub__` does not support secret values as second -operand but :py:meth:`_secret.__rsub__` does support clear values, so -``cint(3) - sint(2)`` will result in a secret integer of value 1. + sint.get_input_from(0, size=10) * sint.get_input_from(1, size=10) .. autosummary:: :nosignatures: @@ -193,7 +175,7 @@ def vectorized_classmethod(function): def vectorize_init(function): def vectorized_init(*args, **kwargs): size = None - if len(args) > 1 and (isinstance(args[1], Tape.Register) or \ + if len(args) > 1 and (isinstance(args[1], _register) or \ isinstance(args[1], sfloat)): size = args[1].size if 'size' in kwargs and kwargs['size'] is not None \ @@ -343,6 +325,14 @@ class _int(object): @staticmethod def bit_adder(*args, **kwargs): + """ Binary adder in arithmetic circuits. + + :param a: summand (list of 0/1 in compatible type) + :param b: summand (list of 0/1 in compatible type) + :param carry_in: input carry (default 0) + :param get_carry: add final carry to output + :returns: list of 0/1 in relevant type + """ return intbitint.bit_adder(*args, **kwargs) @staticmethod @@ -414,6 +404,10 @@ class _int(object): carry = self * other return self + other - 2 * carry, carry + @staticmethod + def long_one(): + return 1 + class _bit(object): """ Binary functionality. 
""" @@ -506,6 +500,18 @@ class _structure(object): """ return Matrix(rows, columns, cls, *args, **kwargs) + @classmethod + def Tensor(cls, shape): + """ + Type-dependent tensor of any dimension:: + + a = sfix.Tensor([10, 10]) + """ + if len(shape) == 1: + return Array(size[0], cls) + else: + return MultiArray(shape, cls) + @classmethod def row_matrix_mul(cls, row, matrix, res_params=None): return sum(row[k].mul_no_reduce(matrix[k].get_vector(), @@ -614,8 +620,11 @@ class _register(Tape.Register, _number, _structure): @set_instruction_type def __init__(self, reg_type, val, size): + from .GC.types import sbits if isinstance(val, (tuple, list)): size = len(val) + elif isinstance(val, sbits): + size = val.n super(_register, self).__init__(reg_type, program.curr_tape, size=size) if isinstance(val, int): self.load_int(val) @@ -687,14 +696,28 @@ class _clear(_register): def raw_output(self): raw_output(self) + @vectorize + def binary_output(self, player=None): + """ Write 64-bit signed integer to + ``Player-Data/Binary-Output-P-``. 
+ + :param player: only output on given player (default all) + """ + regint(self).binary_output(player) + @set_instruction_type @read_mem_value @vectorize def clear_op(self, other, c_inst, ci_inst, reverse=False): cls = self.__class__ res = self.prep_res(other) + if isinstance(other, regint): + other = cls(other) if isinstance(other, cls): - c_inst(res, self, other) + if reverse: + c_inst(res, other, self) + else: + c_inst(res, self, other) elif isinstance(other, int): if self.in_immediate_range(other): ci_inst(res, self, other) @@ -713,7 +736,7 @@ class _clear(_register): def coerce_op(self, other, inst, reverse=False): cls = self.__class__ res = cls() - if isinstance(other, int): + if isinstance(other, (int, regint)): other = cls(other) elif not isinstance(other, cls): return NotImplemented @@ -756,20 +779,6 @@ class _clear(_register): return self.coerce_op(other, divc, True) __rtruediv__.__doc__ = __truediv__.__doc__ - def __eq__(self, other): - """ Equality check of public values. - - :param other: convertible type (at least same as :py:obj:`self` and regint/int) - :return: 0/1 (regint) """ - if isinstance(other, (_clear,int)): - return regint(self) == other - else: - return NotImplemented - - def __ne__(self, other): - return 1 - (self == other) - __ne__.__doc__ = __eq__.__doc__ - def __and__(self, other): """ Bit-wise AND of public values. @@ -799,8 +808,22 @@ class _clear(_register): class cint(_clear, _int): """ - Clear integer in same domain as secure computation - (depends on protocol). + Clear integer in same domain as secure computation (depends on + protocol). A number of operators are supported (``+, -, *, /, //, **, + %, ^, &, |, ~, ==, !=, <<, >>``), returning either + :py:class:`cint` if the other operand is public (cint/regint/int) + or :py:class:`sint` if the other operand is + :py:class:`sint`. Comparison operators (``==, !=, <, <=, >, >=``) + are also supported, returning :py:func:`regint`.
Comparisons and + ``~`` require that the value is within the global bit length. The + same holds for :py:func:`abs`. ``/`` runs field division if the + modulus is a prime while ``//`` runs integer floor + division. ``**`` requires the exponent to be compile-time integer + or the base to be two. + + :param val: initialization (cint/regint/int/cgf2n or list thereof) + :param size: vector size (int), defaults to 1 or size of list + """ __slots__ = [] instruction_type = 'modp' @@ -845,10 +868,6 @@ class cint(_clear, _int): @vectorize_init def __init__(self, val=None, size=None): - """ - :param val: initialization (cint/regint/int/cgf2n or list thereof) - :param size: vector size (int), defaults to 1 or size of list - """ super(cint, self).__init__('c', val=val, size=size) @vectorize @@ -901,14 +920,23 @@ class cint(_clear, _int): :param other: cint/regint/int """ return self.coerce_op(other, modc, True) + def __floordiv__(self, other): + return self.coerce_op(other, floordivc) + + def __rfloordiv__(self, other): + return self.coerce_op(other, floordivc, True) + + @vectorize def less_than(self, other, bit_length): """ Clear comparison for particular bit length. :param other: cint/regint/int :param bit_length: signed bit length of inputs :return: 0/1 (regint), undefined if inputs outside range """ + if not isinstance(other, (cint, regint, int)): + return NotImplemented if bit_length <= 64: - return self < other + return regint(self) < regint(other) else: diff = self - other shifted = diff >> (bit_length - 1) @@ -916,18 +944,16 @@ class cint(_clear, _int): return res def __lt__(self, other): - """ Clear 64-bit comparison. + """ Clear comparison. 
:param other: cint/regint/int :return: 0/1 (regint) """ - if isinstance(other, (type(self),int)): - return regint(self) < other - else: - return NotImplemented + return self.less_than(other, program.bit_length) + @vectorize def __gt__(self, other): - if isinstance(other, (type(self),int)): - return regint(self) > other + if isinstance(other, (cint, regint, int)): + return self.conv(other) < self else: return NotImplemented @@ -947,7 +973,7 @@ class cint(_clear, _int): :param other: cint/regint/int :return: 0/1 (regint) """ - if not isinstance(other, (_clear, int)): + if not isinstance(other, (_clear, regint, int)): return NotImplemented res = 1 remaining = program.bit_length @@ -962,6 +988,9 @@ class cint(_clear, _int): remaining -= 64 return res + def __ne__(self, other): + return 1 - (self == other) + def __lshift__(self, other): """ Clear left shift. @@ -1065,8 +1094,17 @@ class cint(_clear, _int): class cgf2n(_clear, _gf2n): """ - Clear :math:`\mathrm{GF}(2^n)` value. n is 40 or 128, - depending on USE_GF2N_LONG compile-time variable. + Clear :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A + number of operators are supported (``+, -, *, /, **, ^, &, |, ~, ==, + !=, <<, >>``), returning either :py:class:`cgf2n` if the other + operand is public (cgf2n/regint/int) or :py:class:`sgf2n` if the + other operand is secret. The following operators require the other + operand to be a compile-time integer: ``**, <<, >>``. ``*, /, **`` refer + to field multiplication and division.
+ + :param val: initialization (cgf2n/cint/regint/int or list thereof) + :param size: vector size (int), defaults to 1 or size of list + """ __slots__ = [] instruction_type = 'gf2n' @@ -1097,10 +1135,6 @@ class cgf2n(_clear, _gf2n): return value < 2**32 and value >= 0 def __init__(self, val=None, size=None): - """ - :param val: initialization (cgf2n/cint/regint/int or list thereof) - :param size: vector size (int), defaults to 1 or size of list - """ super(cgf2n, self).__init__('cg', val=val, size=size) @vectorize @@ -1158,6 +1192,16 @@ class cgf2n(_clear, _gf2n): else: return NotImplemented + def __eq__(self, other): + if isinstance(other, (cgf2n, int)): + return (regint(self) == regint(other)) * \ + (regint(self >> 64) == regint(other >> 64)) + else: + return NotImplemented + + def __ne__(self, other): + return 1 - (self == other) + @vectorize def bit_decompose(self, bit_length=None, step=None): """ Clear bit decomposition. @@ -1173,7 +1217,16 @@ class cgf2n(_clear, _gf2n): class regint(_register, _int): """ Clear 64-bit integer. - Unlike :py:class:`cint` this is always a 64-bit integer. + Unlike :py:class:`cint` this is always a 64-bit integer. The type + supports the following operations with :py:class:`regint` or + Python integers, always returning :py:class:`regint`: ``+, -, *, %, + /, //, **, ^, &, |, <<, >>, ==, !=, <, <=, >, >=``. For operations + with other types, see the respective descriptions. Both ``/`` and + ``//`` stand for floor division. 
+ + :param val: initialization (cint/cgf2n/regint/int or list thereof) + :param size: vector size (int), defaults to 1 or size of list + """ __slots__ = [] reg_type = 'ci' @@ -1265,10 +1318,6 @@ class regint(_register, _int): @vectorize_init def __init__(self, val=None, size=None): - """ - :param val: initialization (cint/cgf2n/regint/int or list thereof) - :param size: vector size (int), defaults to 1 or size of list - """ super(regint, self).__init__(self.reg_type, val=val, size=size) def load_int(self, val): @@ -1298,9 +1347,9 @@ class regint(_register, _int): @vectorize @read_mem_value def int_op(self, other, inst, reverse=False): - try: + if isinstance(other, (int, regint)): other = self.conv(other) - except: + else: return NotImplemented res = regint() if reverse: @@ -1352,6 +1401,8 @@ class regint(_register, _int): """ Clear modulo computation. :param other: regint/cint/int """ + if util.is_constant(other) and other >= 2 ** 64: + return self return self - (self / other) * other def __rmod__(self, other): @@ -1372,16 +1423,16 @@ class regint(_register, _int): :param other: regint/cint/int :return: 0/1 """ - return self.int_op(other, eqc) + return self.int_op(other, eqc, False) def __ne__(self, other): return 1 - (self == other) def __lt__(self, other): - return self.int_op(other, ltc) + return self.int_op(other, ltc, False) def __gt__(self, other): - return self.int_op(other, gtc) + return self.int_op(other, gtc, False) def __le__(self, other): return 1 - (self > other) @@ -1393,6 +1444,12 @@ class regint(_register, _int): op.__doc__ = __eq__.__doc__ del op + def cint_op(self, other, op): + if isinstance(other, regint): + return regint(op(cint(self), other)) + else: + return NotImplemented + def __lshift__(self, other): """ Clear shift. 
@@ -1400,13 +1457,13 @@ class regint(_register, _int): if isinstance(other, int): return self * 2**other else: - return regint(cint(self) << other) + return self.cint_op(other, operator.lshift) def __rshift__(self, other): if isinstance(other, int): return self / 2**other else: - return regint(cint(self) >> other) + return self.cint_op(other, operator.rshift) def __rlshift__(self, other): return regint(other << cint(self)) @@ -1422,19 +1479,19 @@ class regint(_register, _int): """ Clear bit-wise AND. :param other: regint/cint/int """ - return regint(other & cint(self)) + return self.cint_op(other, operator.and_) def __or__(self, other): """ Clear bit-wise OR. :param other: regint/cint/int """ - return regint(other | cint(self)) + return self.cint_op(other, operator.or_) def __xor__(self, other): """ Clear bit-wise XOR. :param other: regint/cint/int """ - return regint(other ^ cint(self)) + return self.cint_op(other, operator.xor) __rand__ = __and__ __ror__ = __or__ @@ -1494,12 +1551,24 @@ class regint(_register, _int): def output_if(self, cond): cint(self).output_if(cond) + def binary_output(self, player=None): + """ Write 64-bit signed integer to + ``Player-Data/Binary-Output-P-``. + + :param player: only output on given player (default all) + """ + if player == None: + player = -1 + intoutput(player, self) + class localint(object): """ Local integer that must prevented from leaking into the secure - computation. Uses regint internally. """ + computation. Uses regint internally. 
+ + :param value: initialization, convertible to regint + """ def __init__(self, value=None): - """ :param value: initialization, convertible to regint """ self._v = regint(value) self.size = 1 @@ -1517,9 +1586,59 @@ class localint(object): class personal(object): def __init__(self, player, value): + assert value is not NotImplemented + assert not isinstance(value, _secret) + while isinstance(value, personal): + assert player == value.player + value = value._v self.player = player self._v = value + def binary_output(self): + self._v.binary_output(self.player) + + def _san(self, other): + if isinstance(other, personal): + assert self.player == other.player + return self._v + + def _div_san(self): + return self._v.conv((library.get_player_id() == self.player)._v).if_else(self._v, 1) + + __add__ = lambda self, other: personal(self.player, self._san(other) + other) + __sub__ = lambda self, other: personal(self.player, self._san(other) - other) + __mul__ = lambda self, other: personal(self.player, self._san(other) * other) + __pow__ = lambda self, other: personal(self.player, self._san(other) ** other) + __truediv__ = lambda self, other: personal(self.player, self._san(other) / other) + __floordiv__ = lambda self, other: personal(self.player, self._san(other) // other) + __mod__ = lambda self, other: personal(self.player, self._san(other) % other) + __lt__ = lambda self, other: personal(self.player, self._san(other) < other) + __gt__ = lambda self, other: personal(self.player, self._san(other) > other) + __le__ = lambda self, other: personal(self.player, self._san(other) <= other) + __ge__ = lambda self, other: personal(self.player, self._san(other) >= other) + __eq__ = lambda self, other: personal(self.player, self._san(other) == other) + __ne__ = lambda self, other: personal(self.player, self._san(other) != other) + __and__ = lambda self, other: personal(self.player, self._san(other) & other) + __xor__ = lambda self, other: personal(self.player, self._san(other) ^ 
other) + __or__ = lambda self, other: personal(self.player, self._san(other) | other) + __lshift__ = lambda self, other: personal(self.player, self._san(other) << other) + __rshift__ = lambda self, other: personal(self.player, self._san(other) >> other) + + __neg__ = lambda self: personal(self.player, -self._v) + + __radd__ = lambda self, other: personal(self.player, other + self._v) + __rsub__ = lambda self, other: personal(self.player, other - self._v) + __rmul__ = lambda self, other: personal(self.player, other * self._v) + __rand__ = lambda self, other: personal(self.player, other & self._v) + __rxor__ = lambda self, other: personal(self.player, other ^ self._v) + __ror__ = lambda self, other: personal(self.player, other | self._v) + __rlshift__ = lambda self, other: personal(self.player, other << self._v) + __rrshift__ = lambda self, other: personal(self.player, other >> self._v) + + __rtruediv__ = lambda self, other: personal(self.player, other / self._div_san()) + __rfloordiv__ = lambda self, other: personal(self.player, other // self._div_san()) + __rmod__ = lambda self, other: personal(self.player, other % self._div_san()) + class longint: def __init__(self, value, length=None, n_limbs=None): assert length is None or n_limbs is None @@ -1870,7 +1989,24 @@ class _secret(_register): class sint(_secret, _int): """ - Secret integer in the protocol-specific domain. + Secret integer in the protocol-specific domain. It supports + operations with :py:class:`sint`, :py:class:`cint`, + :py:class:`regint`, and Python integers. Operations where one of + the operands is an :py:class:`sint` either result in an + :py:class:`sint` or an :py:class:`sintbit`, the latter for + comparisons. + + The following operations work as expected in the computation + domain (modulo a prime or a power of two): ``+, -, *``. ``/`` + denotes the field division modulo a prime. It will reveal if the + divisor is zero. 
Comparison operators (``==, !=, <, <=, >, >=``) + assume that the element in the computation domain represents a + signed integer in a restricted range, see below. The same holds + for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and + exponentiation (``**``). Modulo only works if the right-hand + operand is a compile-time power of two, and exponentiation only + works if the base is two or if the exponent is a compile-time + integer. Most non-linear operations require compile-time parameters for bit length and statistical security. They default to the global @@ -1884,6 +2020,11 @@ class sint(_secret, _int): parameter does not matter. Modulo prime, the behaviour is undefined and potentially insecure if the operands are longer than the bit length. + + :param val: initialization (sint/cint/regint/int/cgf2n or list + thereof or sbits/sbitvec) + :param size: vector size (int), defaults to 1 or size of list + """ __slots__ = [] instruction_type = 'modp' @@ -1978,10 +2119,6 @@ class sint(_secret, _int): edabit(whole, *bits) return whole, bits - @staticmethod - def long_one(): - return 1 - @staticmethod @vectorize def bit_decompose_clear(a, n_bits): @@ -2084,12 +2221,14 @@ class sint(_secret, _int): len(indices[3]), *(list(indices) + [m, l])) return res + @vectorize_init def __init__(self, val=None, size=None): - """ - :param val: initialization (sint/cint/regint/int/cgf2n or list thereof) - :param size: vector size (int), defaults to 1 or size of list - """ - super(sint, self).__init__('s', val=val, size=size) + if isinstance(val, personal): + size = val._v.size + super(sint, self).__init__('s', size=size) + inputpersonal(size, val.player, self, self.clear_type.conv(val._v)) + else: + super(sint, self).__init__('s', val=val, size=size) @vectorize def __neg__(self): @@ -2107,7 +2246,8 @@ class sint(_secret, _int): """ Secret comparison (signed).
:param other: sint/cint/regint/int - :return: 0/1 (sint) """ + :param bit_length: bit length of input (default: global bit length) + :return: 0/1 (sintbit) """ res = sintbit() comparison.LTZ(res, self - other, (bit_length or program.bit_length) + 1, @@ -2166,7 +2306,9 @@ class sint(_secret, _int): def mod2m(self, m, bit_length=None, security=None, signed=True): """ Secret modulo power of two. - :param m: secret or public integer (sint/cint/regint/int) """ + :param m: secret or public integer (sint/cint/regint/int) + :param bit_length: bit length of input (default: global bit length) + """ bit_length = bit_length or program.bit_length security = security or program.security if isinstance(m, int): @@ -2191,14 +2333,19 @@ class sint(_secret, _int): @vectorize def pow2(self, bit_length=None, security=None): - """ Secret power of two. """ + """ Secret power of two. + + :param bit_length: bit length of input (default: global bit length) + """ return floatingpoint.Pow2(self, bit_length or program.bit_length, \ security or program.security) def __lshift__(self, other, bit_length=None, security=None): """ Secret left shift. - :param other: secret or public integer (sint/cint/regint/int) """ + :param other: secret or public integer (sint/cint/regint/int) + :param bit_length: bit length of input (default: global bit length) + """ return self * util.pow2_value(other, bit_length, security) @vectorize @@ -2206,7 +2353,9 @@ class sint(_secret, _int): def __rshift__(self, other, bit_length=None, security=None, signed=True): """ Secret right shift. 
- :param other: secret or public integer (sint/cint/regint/int) """ + :param other: secret or public integer (sint/cint/regint/int) + :param bit_length: bit length of input (default: global bit length) + """ bit_length = bit_length or program.bit_length security = security or program.security if isinstance(other, int): @@ -2286,7 +2435,9 @@ class sint(_secret, _int): def int_div(self, other, bit_length=None, security=None): """ Secret integer division. - :param other: sint """ + :param other: sint + :param bit_length: bit length of input (default: global bit length) + """ k = bit_length or program.bit_length kappa = security or program.security tmp = library.IntDiv(self, other, k, kappa) @@ -2339,13 +2490,45 @@ class sint(_secret, _int): if not util.is_constant(player) or self.size > 1: secret_mask = sint() player_mask = cint() - inputmaskreg(secret_mask, player_mask, player) + inputmaskreg(secret_mask, player_mask, regint.conv(player)) return personal(player, (self + secret_mask).reveal() - player_mask) else: return super(sint, self).reveal_to(player) + def private_division(self, divisor, active=True, dividend_length=None, + divisor_length=None): + assert active == False + + d = divisor + l = divisor_length or program.bit_length + m = dividend_length or program.bit_length + sigma = program.security + + min_length = m + l + 2 * sigma + 1 + if program.options.ring: + comparison.require_ring_size(min_length, 'private division') + else: + program.curr_tape.require_bit_length(min_length) + + r = sint.get_random_int(l + sigma) + r_prime = sint.get_random_int(m + sigma) + r_pprime = sint.get_random_int(l + sigma) + + h = (r + (r_prime << (l + sigma))) * sint(d) + z = ((self << (l + sigma)) + h + r_pprime).reveal_to(0) + + y = sint(z // (d << (l + sigma))) + y_prime = sint((z // d) % (2 ** (l + sigma))) + + b = r.greater_than(y_prime, l + sigma) + w = y - b - r_prime + + return w + class sintbit(sint): + """ :py:class:`sint` holding a bit, supporting binary operations + 
(``&, |, ^``). """ @classmethod def prep_res(cls, other): return sint() @@ -2405,7 +2588,18 @@ class sintbit(sint): return super(sintbit, self).__rsub__(other) class sgf2n(_secret, _gf2n): - """ Secret :math:`\mathrm{GF}(2^n)` value. """ + """ + Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A + number of operators are supported (``+, -, *, /, **, ^, ~, ==, !=, + <<``), returning :py:class:`sgf2n`. Operators generally work with + cgf2n/regint/cint/int, except ``**, <<``, which require a + compile-time integer. ``/`` refers to field division. ``*, /, + **`` refer to field multiplication and division. + + :param val: initialization (sgf2n/cgf2n/regint/int/cint or list thereof) + :param size: vector size (int), defaults to 1 or size of list + + """ __slots__ = [] instruction_type = 'gf2n' clear_type = cgf2n @@ -2449,10 +2643,6 @@ class sgf2n(_secret, _gf2n): self._store_in_mem(address, gstms, gstmsi) def __init__(self, val=None, size=None): - """ - :param val: initialization (sgf2n/cgf2n/regint/int/cint or list thereof) - :param size: vector size (int), defaults to 1 or size of list - """ super(sgf2n, self).__init__('sg', val=val, size=size) def __neg__(self): @@ -3082,7 +3272,17 @@ def parse_type(other, k=None, f=None): return other class cfix(_number, _structure): - """ Clear fixed-point number represented as clear integer. """ + """ + Clear fixed-point number represented as clear integer. It supports + basic arithmetic (``+, -, *, /``), returning either + :py:class:`cfix` if the other operand is public + (cfix/regint/cint/int) or :py:class:`sfix` if the other operand is + an sfix. It also supports comparisons (``==, !=, <, <=, >, >=``), + returning either :py:class:`regint` or :py:class:`sbitint`.
+ + :param v: cfix/float/int + + """ __slots__ = ['value', 'f', 'k'] reg_type = 'c' scalars = (int, float, regint, cint) @@ -3181,7 +3381,6 @@ class cfix(_number, _structure): @vectorize_init @read_mem_value def __init__(self, v=None, k=None, f=None, size=None): - """ :param v: cfix/float/int """ f = self.f if f is None else f k = self.k if k is None else k self.f = f @@ -3404,6 +3603,17 @@ class cfix(_number, _structure): def output_if(self, cond): cond_print_plain(cond, self.v, cint(-self.f)) + @vectorize + def binary_output(self, player=None): + """ Write double-precision floating-point number to + ``Player-Data/Binary-Output-P-``. + + :param player: only output on given player (default all) + """ + if player == None: + player = -1 + floatoutput(player, self.v, cint(-self.f), cint(0), cint(0)) + class _single(_number, _structure): """ Representation as single integer preserving the order """ """ E.g. fixed-point numbers """ @@ -3494,6 +3704,30 @@ class _single(_number, _structure): res = A.unreduced(CC, B, res_params, n).reduce_after_mul() return res + @classmethod + def read_from_file(cls, *args, **kwargs): + """ Read shares from ``Persistence/Transactions-P.data``. + Precision must be the same as when storing. + + :param start: starting position in number of shares from beginning + (int/regint/cint) + :param n_items: number of items (int) + :returns: destination for final position, -1 for eof reached, + or -2 for file not found (regint) + :returns: list of shares + """ + stop, shares = cls.int_type.read_from_file(*args, **kwargs) + return stop, [cls._new(x) for x in shares] + + @classmethod + def write_to_file(cls, shares): + """ Write shares of integer representation to + ``Persistence/Transactions-P.data`` (appending at the end). + + :param: shares (list or iterable of sfix) + """ + cls.int_type.write_to_file([x.v for x in shares]) + def store_in_mem(self, address): """ Store in memory by public address. 
""" self.v.store_in_mem(address) @@ -3581,6 +3815,9 @@ class _single(_number, _structure): def link(self, other): self.v.link(other.v) + def get_vector(self): + return self + class _fix(_single): """ Secret fixed point type. """ __slots__ = ['v', 'f', 'k'] @@ -3638,6 +3875,13 @@ class _fix(_single): res.load_int(cls.int_type.conv(other)) return res + @classmethod + def conv(cls, other): + if isinstance(other, _fix) and (cls.k, cls.f) == (other.k, other.f): + return other + else: + return cls(other) + @classmethod def _new(cls, other, k=None, f=None): res = cls(k=k, f=f) @@ -3646,7 +3890,6 @@ class _fix(_single): @vectorize_init def __init__(self, _v=None, k=None, f=None, size=None): - """ :params _v: int/float/regint/cint/sint/sfloat """ if k is None: k = self.k else: @@ -3721,13 +3964,16 @@ class _fix(_single): except: return NotImplemented if isinstance(other, (_fix, self.clear_type)): - val = self.v.TruncMul(other.v, self.k + other.k, other.f, + k = max(self.k, other.k) + max_f = max(self.f, other.f) + min_f = min(self.f, other.f) + val = self.v.TruncMul(other.v, k + min_f, min_f, self.kappa, self.round_nearest) if 'vec' not in self.__dict__: - return self._new(val, k=self.k, f=self.f) + return self._new(val, k=k, f=max_f) else: - return self.vec._new(val, k=self.k, f=self.f) + return self.vec._new(val, k=k, f=max_f) elif isinstance(other, cfix.scalars): scalar_fix = cfix(other) return self * scalar_fix @@ -3792,7 +4038,20 @@ class _fix(_single): class sfix(_fix): """ Secret fixed-point number represented as secret integer. This uses integer operations internally, see :py:class:`sint` for security - considerations. """ + considerations. + + It supports basic arithmetic (``+, -, *, /``), returning + :py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``), + returning :py:class:`sbitint`. The other operand can be any of + sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()`` + and ``**``, the latter for integer exponents. 
+ + Note that the default precision (16 bits after the dot, 31 bits in + total) only allows numbers up to :math:`2^{31-16-1} \\approx + 16000`. You can increase this using :py:func:`set_precision`. + + :params _v: int/float/regint/cint/sint/sfloat + """ int_type = sint clear_type = cfix @@ -3866,7 +4125,7 @@ class sfix(_fix): return self.v def unreduced(self, v, other=None, res_params=None, n_summands=1): - return unreduced_sfix(v, self.k * 2, self.f, self.kappa) + return unreduced_sfix(v, self.k + self.f, self.f, self.kappa) @staticmethod def multipliable(v, k, f, size): @@ -3889,7 +4148,7 @@ class unreduced_sfix(_single): @classmethod def _new(cls, v): - return cls(v, 2 * sfix.k, sfix.f, sfix.kappa) + return cls(v, sfix.k + sfix.f, sfix.f, sfix.kappa) def __init__(self, v, k, m, kappa): self.v = v @@ -3913,7 +4172,7 @@ class unreduced_sfix(_single): def reduce_after_mul(self): v = sfix.int_type.round(self.v, self.k, self.m, self.kappa, nearest=sfix.round_nearest, signed=True) - return sfix._new(v, k=self.k // 2, f=self.m) + return sfix._new(v, k=self.k - self.m, f=self.m) sfix.unreduced_type = unreduced_sfix @@ -3934,6 +4193,13 @@ class squant(_single): def from_sint(cls, other): raise CompilerError('sint to squant conversion not implemented') + @classmethod + def conv(cls, other): + if isinstance(other, squant): + return other + else: + return cls(other) + @classmethod def _new(cls, value, params=None): res = cls(params=params) @@ -4126,7 +4392,15 @@ class sfloat(_number, _structure): s: sign bit This uses integer operations internally, see :py:class:`sint` for security - considerations. """ + considerations. + + The type supports basic arithmetic (``+, -, *, /``), returning + :py:class:`sfloat`, and comparisons (``==, !=, <, <=, >, >=``), + returning :py:class:`sint`. The other operand can be any of + sint/cfix/regint/cint/int/float. 
+ + :param v: initialization (sfloat/sfix/float/int/sint/cint/regint) + """ __slots__ = ['v', 'p', 'z', 's', 'size'] # single precision @@ -4223,9 +4497,6 @@ class sfloat(_number, _structure): @vectorize_init @read_mem_value def __init__(self, v, p=None, z=None, s=None, size=None): - """ - :param v: initialization (sfloat/sfix/float/int/sint/cint/regint) - """ self.size = get_global_vector_size() if p is None: if isinstance(v, sfloat): @@ -4547,6 +4818,16 @@ class cfloat(object): """ Output. """ print_float_plain(self.v, self.p, self.z, self.s, self.nan) + def binary_output(self, player=None): + """ Write double-precision floating-point number to + ``Player-Data/Binary-Output-P-``. + + :param player: only output on given player (default all) + """ + if player == None: + player = -1 + floatoutput(player, self.v, self.p, self.z, self.s) + sfix.float_type = sfloat _types = { @@ -4564,24 +4845,39 @@ def _get_type(t): return t class Array(object): - """ Array accessible by public index. """ + """ + Array accessible by public index. That is, ``a[i]`` works for an + array ``a`` and ``i`` being a :py:class:`regint`, + :py:class:`cint`, or a Python integer. ``a[start:stop:step]`` + works as well, and so does iteration over an array. + + Arrays support a number of element-wise operations if the + underlying basic type does so. These are ``+, -, *, **, /``. The + return type of these is a vector, which can be assigned to an + array of a compatible type using :py:func:`assign`. + + :param length: compile-time integer (int) or :py:obj:`None` for unknown length + :param value_type: basic type + :param address: if given (regint/int), the array will not be allocated + """ @classmethod def create_from(cls, l): - """ Convert Python iterator to array. Basic type will be taken + """ Convert Python iterator or vector to array. Basic type will be taken from first element, further elements must to be convertible to that. 
""" if isinstance(l, cls): return l - tmp = list(l) - res = cls(len(tmp), type(tmp[0])) + if isinstance(l, _number): + tmp = l + t = type(l) + else: + tmp = list(l) + t = type(tmp[0]) + res = cls(len(tmp), t) res.assign(tmp) return res def __init__(self, length, value_type, address=None, debug=None, alloc=True): - """ - :param length: compile-time integer (int) or :py:obj:`None` for unknown length - :param value_type: basic type - :param address: if given (regint/int), the array will not be allocated """ value_type = _get_type(value_type) self.address = address self.length = length @@ -4662,9 +4958,20 @@ class Array(object): self._store(value, self.get_address(index)) def maybe_get(self, condition, index): + """ Return entry if condition is true. + + :param condition: 0/1 (regint/cint/int) + :param index: regint/cint/int + """ return condition * self[condition * index] def maybe_set(self, condition, index, value): + """ Change entry if condition is true. + + :param condition: 0/1 (regint/cint/int) + :param index: regint/cint/int + :param value: updated value + """ if self.sink is None: self.sink = self.value_type.Array( 1, address=self.value_type.malloc(1, creator_tape=program.tapes[0])) @@ -4699,9 +5006,16 @@ class Array(object): yield self[i] def same_shape(self): + """ Array of same length and type. """ return Array(self.length, self.value_type) def assign(self, other, base=0): + """ Assignment. 
+ + :param other: vector/Array/Matrix/MultiArray/iterable of + compatible type and smaller size + :param base: index to start assignment at + """ try: other = other.get_vector() except: @@ -4714,10 +5028,10 @@ class Array(object): if isinstance(other, Array): @library.for_range_opt(len(other)) def _(i): - self[i] = other[i] + self[base + i] = other[i] else: for i,j in enumerate(other): - self[i] = j + self[base + i] = j return self assign_vector = assign @@ -4750,9 +5064,19 @@ class Array(object): get_part_vector = get_vector def get_part(self, base, size): + """ Part array. + + :param base: start index (regint/cint/int) + :param size: integer + :returns: Array of same type + """ return Array(size, self.value_type, self.get_address(base)) def get(self, indices): + """ Vector from arbitrary indices. + + :param indices: regint vector or array + """ return self.value_type.load_mem( regint.inc(len(indices), self.address, 0) + indices, size=len(indices)) @@ -4766,6 +5090,11 @@ class Array(object): return self.value_type.load_mem(self.address + addresses) def expand_to_vector(self, index, size): + """ Create vector from single entry. + + :param index: regint/cint/int + :param size: int + """ assert self.value_type.n_elements() == 1 address = regint(size=size) incint(address, regint(self.get_address(index), size=1), 0) @@ -4789,6 +5118,25 @@ class Array(object): def _(i): self[i] = input_from(player) + def read_from_file(self, start): + """ Read content from ``Persistence/Transactions-P.data``. + Precision must be the same as when storing if applicable. 
+ + :param start: starting position in number of shares from beginning + (int/regint/cint) + :returns: destination for final position, -1 for eof reached, + or -2 for file not found (regint) + """ + stop, shares = self.value_type.read_from_file(start, len(self)) + self.assign(shares) + return stop + + def write_to_file(self): + """ Write shares of integer representation to + ``Persistence/Transactions-P.data`` (appending at the end). + """ + self.value_type.write_to_file(list(self)) + def __add__(self, other): """ Vector addition. @@ -4826,6 +5174,9 @@ class Array(object): __radd__ = __add__ __rmul__ = __mul__ + def __neg__(self): + return -self.get_vector() + def shuffle(self): """ Insecure shuffle in place. """ if self.value_type == regint: @@ -4850,6 +5201,23 @@ class Array(object): reveal_nested = reveal_list + def reveal_to_binary_output(self, player=None): + """ Reveal to binary output if supported by type. + + :param: player to reveal to (default all) + """ + if player == None: + self.get_vector().reveal().binary_output() + else: + self.get_vector().reveal_to(player).binary_output() + + def binary_output(self, player=None): + """ Binary output if supported by type. + + :param: player (default all) + """ + self.get_vector().binary_output(player) + def sort(self, n_threads=None): """ Sort in place using Batchers' odd-even merge mergesort @@ -4869,9 +5237,9 @@ sgf2n.dynamic_array = Array class SubMultiArray(object): - """ Multidimensional array functionality. """ + """ Multidimensional array functionality. Don't construct this + directly, use :py:class:`MultiArray` instead. """ def __init__(self, sizes, value_type, address, index, debug=None): - """ Do not call this, use :py:class:`MultiArray` instead. 
""" self.sizes = tuple(sizes) self.value_type = _get_type(value_type) if address is not None: @@ -4946,7 +5314,7 @@ class SubMultiArray(object): :param base: compile-time (int) """ assert self.value_type.n_elements() == 1 assert vector.size <= self.total_size() - vector.store_in_mem(self.address + base) + self.value_type.conv(vector).store_in_mem(self.address + base) def assign(self, other): """ Assign container to content. Not implemented for floating-point. @@ -4957,6 +5325,12 @@ class SubMultiArray(object): self.assign_vector(other.get_vector()) def get_part_vector(self, base=0, size=None): + """ Vector from range of the first dimension, including all + entries in further dimensions. + + :param base: index in first dimension (regint/cint/int) + :param size: size in first dimension (int) + """ assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) size = (size or 1) * part_size @@ -4965,12 +5339,23 @@ class SubMultiArray(object): size=size) def assign_part_vector(self, vector, base=0): + """ Assign vector from range of the first dimension, including all + entries in further dimensions. + + :param vector: updated entries + :param base: index in first dimension (regint/cint/int) + """ assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) assert vector.size <= self.total_size() vector.store_in_mem(self.address + base * part_size) def get_slice_vector(self, slice): + """ Vector from range of indicies of the first dimension, including + all entries in further dimensions. + + :param slice: regint array + """ assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) assert len(slice) * part_size <= self.total_size() @@ -5001,10 +5386,22 @@ class SubMultiArray(object): return res def get_vector_by_indices(self, *indices): + """ + Vector with potential asterisks. 
The potential retrieves + all entry where the first dimension index is 0, and the third + dimension index is 1:: + + a.get_vector_by_indices(0, None, 1) + + """ addresses = self.get_addresses(*indices) return self.value_type.load_mem(addresses) def assign_vector_by_indices(self, vector, *indices): + """ + Assign vector to entries with potential asterisks. See + :py:func:`get_vector_by_indices` for an example. + """ addresses = self.get_addresses(*indices) vector.store_in_mem(addresses) @@ -5013,6 +5410,12 @@ class SubMultiArray(object): return MultiArray(self.sizes, self.value_type) def get_part(self, start, size): + """ Part multi-array. + + :param start: first-dimension index (regint/cint/int) + :param size: int + + """ return MultiArray([size] + list(self.sizes[1:]), self.value_type, address=self[start].address) @@ -5034,6 +5437,29 @@ class SubMultiArray(object): def _(i): self[i].input_from(player, budget=budget, raw=raw) + def write_to_file(self): + """ Write shares of integer representation to + ``Persistence/Transactions-P.data`` (appending at the end). + """ + @library.for_range(len(self)) + def _(i): + self[i].write_to_file() + + def read_from_file(self, start): + """ Read content from ``Persistence/Transactions-P.data``. + Precision must be the same as when storing if applicable. + + :param start: starting position in number of shares from beginning + (int/regint/cint) + :returns: destination for final position, -1 for eof reached, + or -2 for file not found (regint) + """ + start = MemValue(start) + @library.for_range(len(self)) + def _(i): + start.write(self[i].read_from_file(start)) + return start + def schur(self, other): """ Element-wise product. @@ -5064,6 +5490,21 @@ class SubMultiArray(object): __radd__ = __add__ + def __sub__(self, other): + """ Element-wise subtraction. 
+ + :param other: container of matching size and type + :return: container of same shape and type as :py:obj:`self` """ + if is_zero(other): + return self + assert self.sizes == other.sizes + if len(self.sizes) == 2: + res = Matrix(self.sizes[0], self.sizes[1], self.value_type) + else: + res = MultiArray(self.sizes, self.value_type) + res.assign_vector(self.get_vector() - other.get_vector()) + return res + def iadd(self, other): """ Element-wise addition in place. @@ -5131,6 +5572,8 @@ class SubMultiArray(object): def _(k): res_matrix[i][j] += self[i][k] * other[k][j] return res_matrix + elif isinstance(other, self.value_type): + return self * Array.create_from(other) else: raise NotImplementedError @@ -5154,10 +5597,14 @@ class SubMultiArray(object): regint.inc(5))) """ assert len(self.sizes) == 2 - assert len(other.sizes) == 2 - assert self.sizes[1] == other.sizes[0] + if isinstance(other, Array): + other_sizes = [len(other), 1] + else: + other_sizes = other.sizes + assert len(other.sizes) == 2 + assert self.sizes[1] == other_sizes[0] return self.value_type.direct_matrix_mul(self.address, other.address, - self.sizes[0], *other.sizes, + self.sizes[0], *other_sizes, reduce=reduce, indices=indices) def direct_mul_trans(self, other, reduce=True, indices=None): @@ -5310,6 +5757,12 @@ class SubMultiArray(object): library.break_point() return res + def trace(self): + """ Matrix trace. """ + assert len(self.sizes) == 2 + assert self.sizes[0] == self.sizes[1] + return sum(self[i][i] for i in range(self.sizes[0])) + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -5325,17 +5778,34 @@ class SubMultiArray(object): return [f(sizes[1:]) for i in range(sizes[0])] return f(self.sizes) + def reveal_to_binary_output(self, player=None): + """ Reveal to binary output if supported by type. 
+ + :param: player to reveal to (default all) + """ + if player == None: + self.get_vector().reveal().binary_output() + else: + self.get_vector().reveal_to(player).binary_output() + def __str__(self): return '%s multi-array of lengths %s at %s' % (self.value_type, self.sizes, self.address) class MultiArray(SubMultiArray): - """ Multidimensional array. """ + """ + Multidimensional array. The access operator (``a[i]``) allows to a + multi-dimensional array of dimension one less or a simple array + for a two-dimensional array. Element-wise addition and subtraction + is supported, returning a vector, which can be assigned using + :py:func:`assign`. Matrix-vector and matrix-matrix multiplication + is supported as well. + + :param sizes: shape (compile-time list of integers) + :param value_type: basic type of entries + + """ def __init__(self, sizes, value_type, debug=None, address=None, alloc=True): - """ - :param sizes: shape (compile-time list of integers) - :param value_type: basic type of entries - """ if isinstance(address, Array): self.array = address else: @@ -5361,21 +5831,27 @@ class MultiArray(SubMultiArray): self.array.delete() class Matrix(MultiArray): - """ Matrix. """ + """ Matrix. + + :param rows: compile-time (int) + :param columns: compile-time (int) + :param value_type: basic type of entries + + """ def __init__(self, rows, columns, value_type, debug=None, address=None): - """ - :param rows: compile-time (int) - :param columns: compile-time (int) - :param value_type: basic type of entries - """ MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) def set_column(self, index, vector): + """ Change column. 
+ + :param index: regint/cint/int + :param vector: short enought vector of compatible type + """ assert self.value_type.n_elements() == 1 addresses = regint.inc(self.sizes[0], self.address + index, self.sizes[1]) - vector.store_in_mem(addresses) + self.value_type.conv(vector).store_in_mem(addresses) class VectorArray(object): def __init__(self, length, value_type, vector_size, address=None): @@ -5456,7 +5932,11 @@ class MemValue(_mem): """ Single value in memory. This is useful to transfer information between threads. Operations are automatically read from memory if required, this means you can use any operation with - :py:class:`MemValue` objects as if they were a basic type. """ + :py:class:`MemValue` objects as if they were a basic type. + + :param value: basic type or int (will be converted to regint) + + """ __slots__ = ['last_write_block', 'reg_type', 'register', 'address', 'deleted'] @classmethod @@ -5467,7 +5947,6 @@ class MemValue(_mem): return cls(value) def __init__(self, value, address=None): - """ :param value: basic type or int (will be converted to regint) """ self.last_write_block = None if isinstance(value, int): self.value_type = regint diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index e9a29269..cb5daa38 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -16,11 +16,13 @@ #include using namespace std; +#ifdef USE_NTL #include #include #include #include NTL_CLIENT +#endif #include "FHEOffline/DataSetup.h" @@ -288,7 +290,7 @@ void Parameters::SPDZ_Data_Setup(FHE_Params& params, FFT_Data& FTD) FTD.init(R,Zp); } - +#ifdef USE_NTL /* Compute Phi(N) */ int phi_N(int N) { @@ -342,6 +344,17 @@ ZZX Cyclotomic(int N) F=Num/Den; return F; } +#else +int phi_N(int N) +{ + if (((N - 1) & N) != 0) + throw runtime_error("compile with NTL support"); + else if (N == 1) + return 1; + else + return N / 2; +} +#endif void init(Ring& Rg, int m, bool generate_poly) @@ -370,6 +383,7 @@ void init(Ring& Rg, int m, bool generate_poly) } else { +#ifdef USE_NTL int 
k=0; for (int i=1; i& elem) const void PPData::reset_iteration() { - pow=1; theta = (root); thetaPow=theta; + pow = 1; + theta = {root, prData}; + thetaPow = theta; } void PPData::next_iteration() diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index b22b5b94..84cbb9d1 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -33,7 +33,7 @@ void Plaintext::from_poly() const e.change_rep(evaluation); a.resize(n_slots); for (unsigned int i=0; i::from_poly() const (*Field_Data).to_eval(aa); a.resize(n_slots); for (unsigned int i=0; iget_prD()}; type=Both; } diff --git a/FHE/Random_Coins.h b/FHE/Random_Coins.h index ad0d9fdc..2c0eedb5 100644 --- a/FHE/Random_Coins.h +++ b/FHE/Random_Coins.h @@ -10,7 +10,7 @@ class FHE_PK; #ifndef N_LIMBS_RAND -#define N_LIMBS_RAND 0 +#define N_LIMBS_RAND 1 #endif class Int_Random_Coins : public AddableMatrix> diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index f221d0d2..5cc93ca9 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -75,7 +75,7 @@ class Ring_Element FFTD = &prd; rep = r; for (auto& x : other) - element.push_back(x); + element.push_back({x, FFTD->get_prD()}); } /* Functional Operators */ diff --git a/GC/AtlasSecret.cpp b/GC/AtlasSecret.cpp new file mode 100644 index 00000000..13e6e6eb --- /dev/null +++ b/GC/AtlasSecret.cpp @@ -0,0 +1,33 @@ +/* + * AtlasSecret.cpp + * + */ + +#include "AtlasSecret.h" +#include "TinyMC.h" + +#include "Protocols/ShamirMC.hpp" +#include "Protocols/MAC_Check_Base.hpp" +#include "Secret.hpp" + +namespace GC +{ + +typename AtlasSecret::MC* AtlasSecret::new_mc(typename AtlasSecret::mac_key_type mac_key) +{ + return new MC(mac_key); +} + +AtlasShare::AtlasShare(const AtlasSecret& other) : + AtlasShare(other.get_bit(0)) +{ +} + +void AtlasShare::random() +{ + AtlasSecret tmp; + this->get_party().DataF.get_one(DATA_BIT, tmp); + *this = tmp.get_reg(0); +} + +} diff --git a/GC/AtlasSecret.h b/GC/AtlasSecret.h new file mode 100644 index 00000000..8f6372ed --- /dev/null +++ 
b/GC/AtlasSecret.h @@ -0,0 +1,47 @@ +/* + * AtlasSecret.h + * + */ + +#ifndef GC_ATLASSECRET_H_ +#define GC_ATLASSECRET_H_ + +#include "TinySecret.h" +#include "AtlasShare.h" + +namespace GC +{ + +class AtlasSecret : public VectorSecret +{ + typedef AtlasSecret This; + typedef VectorSecret super; + +public: + typedef TinyMC MC; + typedef MC MAC_Check; + typedef VectorProtocol Protocol; + typedef VectorInput Input; + typedef CcdPrep LivePrep; + + static string type_short() + { + return "atlas"; + } + + static MC* new_mc(typename super::mac_key_type); + + AtlasSecret() + { + } + + template + AtlasSecret(const T& other) : + super(other) + { + } +}; + +} + +#endif /* GC_ATLASSECRET_H_ */ diff --git a/GC/AtlasShare.h b/GC/AtlasShare.h new file mode 100644 index 00000000..bad9e10e --- /dev/null +++ b/GC/AtlasShare.h @@ -0,0 +1,72 @@ +/* + * AtlasShare.h + * + */ + +#ifndef GC_ATLASSHARE_H_ +#define GC_ATLASSHARE_H_ + +#include "Protocols/AtlasShare.h" +#include "Protocols/ShamirMC.h" +#include "Math/Bit.h" + +namespace GC +{ + +class AtlasSecret; + +class AtlasShare : public ::AtlasShare>, public ShareSecret +{ + typedef AtlasShare This; + +public: + typedef ::AtlasShare> super; + + typedef Atlas Protocol; + typedef ShamirMC MAC_Check; + typedef MAC_Check Direct_MC; + typedef ShamirInput Input; + typedef ReplicatedPrep LivePrep; + + typedef This small_type; + + typedef Bit clear; + + static MAC_Check* new_mc(mac_key_type) + { + return new MAC_Check; + } + + static This new_reg() + { + return {}; + } + + AtlasShare() + { + } + + template + AtlasShare(const U& other) : + super(other) + { + } + + AtlasShare(const AtlasSecret& other); + + void XOR(const This& a, const This& b) + { + *this = a + b; + } + + void public_input(bool input) + { + *this = input; + } + + void random(); +}; + +} + +#endif /* GC_ATLASSHARE_H_ */ diff --git a/GC/CcdPrep.h b/GC/CcdPrep.h index f3da07ca..00387375 100644 --- a/GC/CcdPrep.h +++ b/GC/CcdPrep.h @@ -61,7 +61,11 @@ public: { assert(part_proc); 
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) - this->bits.push_back(part_prep.get_bit()); + { + typename T::part_type tmp; + part_prep.get_one_no_count(DATA_BIT, tmp); + this->bits.push_back(tmp); + } } void buffer_squares() diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index a23c303b..f650d4f6 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -70,6 +70,8 @@ public: static const true_type invertible; static const true_type characteristic_two; + static int threshold(int) { return 0; } + static MC* new_mc(mac_key_type key) { return new MC(key); } static void store_clear_in_dynamic(Memory& mem, diff --git a/GC/NoShare.h b/GC/NoShare.h index 9cea3fa0..e027e95c 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -76,6 +76,8 @@ public: bool operator!=(NoValue) const { fail(); return 0; } + bool operator==(int) { fail(); return false; } + bool get_bit(int) { fail(); return 0; } void randomize(PRNG&) { fail(); } @@ -173,6 +175,7 @@ public: void invert(int, NoShare) { fail(); } NoShare mask(int) const { fail(); return {}; } + void mask(NoShare, int) const { fail(); } void input(istream&, bool) { fail(); } void output(ostream&, bool) { fail(); } diff --git a/GC/Secret.h b/GC/Secret.h index 4e687d23..f97ad7b3 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -75,6 +75,8 @@ public: static const bool actual_inputs = T::actual_inputs; + static int threshold(int nplayers) { return T::threshold(nplayers); } + static Secret input(party_id_t from, const int128& input, int n_bits = -1); static Secret input(Processor>& processor, const InputArgs& args); void random(int n_bits, int128 share); diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index a2bb5958..498b620c 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -46,6 +46,11 @@ public: static const bool is_real = true; static const bool actual_inputs = true; + static ShareThread& get_party() + { + return ShareThread::s(); + } + static void store_clear_in_dynamic(Memory& mem, const vector& accesses); diff --git 
a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index c8b5a327..1a508828 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -243,9 +243,9 @@ void ShareSecret::reveal_inst(Processor& processor, assert(U::default_length == Clear::N_BITS); for (int j = 0; j < DIV_CEIL(n, U::default_length); j++) { - shares.push_back( - processor.S[r1 + j].mask( - min(U::default_length, n - j * U::default_length))); + shares.push_back({}); + processor.S[r1 + j].mask(shares.back(), + min(U::default_length, n - j * U::default_length)); } } assert(party.MC); diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 0280fb50..c484e9e2 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -91,21 +91,21 @@ void ShareThread::and_(Processor& processor, auto& protocol = this->protocol; processor.check_args(args, 4); protocol->init_mul(DataF, *this->MC); + T x_ext, y_ext; for (size_t i = 0; i < args.size(); i += 4) { int n_bits = args[i]; int left = args[i + 2]; int right = args[i + 3]; - T y_ext; for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++) { - if (repeat) - y_ext = processor.S[right].extend_bit(); - else - y_ext = processor.S[right + j]; int n = min(T::default_length, n_bits - j * T::default_length); - protocol->prepare_mul(processor.S[left + j].mask(n), - y_ext.mask(n), n); + if (repeat) + processor.S[right].extend_bit(y_ext, n); + else + processor.S[right + j].mask(y_ext, n); + processor.S[left + j].mask(x_ext, n); + protocol->prepare_mul(x_ext, y_ext, n); } } @@ -118,7 +118,9 @@ void ShareThread::and_(Processor& processor, for (int j = 0; j < DIV_CEIL(n_bits, T::default_length); j++) { int n = min(T::default_length, n_bits - j * T::default_length); - processor.S[out + j] = protocol->finalize_mul(n).mask(n); + auto& res = processor.S[out + j]; + protocol->finalize_mult(res, n); + res.mask(res, n); } } } diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 9f48474d..632630f1 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -126,18 +126,20 @@ public: return 
*this * BitVec(other); } - This extend_bit() const + void extend_bit(This& res, int n_bits) const { - This res; - res.get_regs().resize(BitVec::N_BITS, this->get_reg(0)); - return res; + auto& regs = res.get_regs(); + regs.assign(n_bits, this->get_reg(0)); } - This mask(int n_bits) const + void mask(This& res, int n_bits) const { - This res = *this; - res.get_regs().resize(n_bits); - return res; + if (this != &res) + res.get_regs().assign(this->get_regs().begin(), + this->get_regs().begin() + + max(size_t(n_bits), this->get_regs().size())); + + res.resize_regs(n_bits); } T get_bit(int i) const @@ -172,7 +174,7 @@ public: template void finalize_input(U& inputter, int from, int n_bits) { - *this = inputter.finalize(from, n_bits).mask(n_bits); + inputter.finalize(from, n_bits).mask(*this, n_bits); } }; diff --git a/GC/TinyShare.h b/GC/TinyShare.h index 4fbb6092..0562e7f1 100644 --- a/GC/TinyShare.h +++ b/GC/TinyShare.h @@ -46,11 +46,6 @@ public: return "tiny share"; } - static ShareThread>& get_party() - { - return ShareThread>::s(); - } - static This new_reg() { return {}; @@ -75,7 +70,7 @@ public: void public_input(bool input) { - auto& party = get_party(); + auto& party = this->get_party(); *this = super::constant(input, party.P->my_num(), party.MC->get_alphai()); } @@ -83,7 +78,7 @@ public: void random() { TinySecret tmp; - get_party().DataF.get_one(DATA_BIT, tmp); + this->get_party().DataF.get_one(DATA_BIT, tmp); *this = tmp.get_reg(0); } }; diff --git a/GC/VectorProtocol.h b/GC/VectorProtocol.h index 1292c221..3f7e203c 100644 --- a/GC/VectorProtocol.h +++ b/GC/VectorProtocol.h @@ -25,6 +25,7 @@ public: void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); typename T::clear prepare_mul(const T& x, const T& y, int n = -1); void exchange(); + void finalize_mult(T& res, int n = -1); T finalize_mul(int n = -1); typename T::part_type::Protocol& get_part() diff --git a/GC/VectorProtocol.hpp b/GC/VectorProtocol.hpp index 072cb71f..cae46181 100644 --- 
a/GC/VectorProtocol.hpp +++ b/GC/VectorProtocol.hpp @@ -50,10 +50,16 @@ template T VectorProtocol::finalize_mul(int n) { T res; + finalize_mult(res, n); + return res; +} + +template +void VectorProtocol::finalize_mult(T& res, int n) +{ res.resize_regs(n); for (int i = 0; i < n; i++) res.get_reg(i) = part_protocol.finalize_mul(1); - return res; } } /* namespace GC */ diff --git a/GC/instructions.h b/GC/instructions.h index 31dc0592..ae395e91 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -137,6 +137,9 @@ X(JOIN_TAPE, MACH->join_tape(R0)) \ X(USE, ) \ X(USE_INP, ) \ + X(NPLAYERS, I0 = Thread::s().P->num_players()) \ + X(THRESHOLD, I0 = T::threshold(Thread::s().P->num_players())) \ + X(PLAYERID, I0 = Thread::s().P->my_num()) \ #define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS diff --git a/Machines/TripleMachine.cpp b/Machines/TripleMachine.cpp index f3020be8..5bb4a6e9 100644 --- a/Machines/TripleMachine.cpp +++ b/Machines/TripleMachine.cpp @@ -73,9 +73,9 @@ TripleMachine::TripleMachine(int argc, const char** argv) : opt.add( "", // Default. 0, // Required? - 0, // Number of args expected. + 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "GF(p) items", // Help description. + "GF(p) items for chosen prime", // Help description. "-P", // Flag token. "--prime-field" // Flag token. 
); @@ -124,7 +124,12 @@ TripleMachine::TripleMachine(int argc, const char** argv) : correlation_check = opt.get("-c")->isSet; generateMACs = opt.get("-m")->isSet || check; amplify = opt.get("-a")->isSet || generateMACs; - primeField = opt.get("-P")->isSet; + if (opt.isSet("-P")) + { + string tmp; + opt.get("-P")->getString(tmp); + prime = tmp; + } bonding = opt.get("-b")->isSet; opt.get("-Z")->getInt(z2k); check |= z2k; @@ -133,8 +138,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) : opt.get("-S")->getInt(z2s); // doesn't work with Montgomery multiplication - gfp1::init_default(gfp0::MAX_N_BITS, false); - gfp0::init_default(gfp0::MAX_N_BITS, true); + gfpvar1::init_field(prime, false); gf2n_long::init_field(128); gf2n_short::init_field(40); @@ -179,8 +183,8 @@ void TripleMachine::run() for (int i = 0; i < nthreads; i++) { - if (primeField) - generators[i] = new_generator>(setup, i, mac_keyp); + if (prime) + generators[i] = new_generator>(setup, i, mac_keyp); else if (z2k) { if (z2k == 32 and z2s == 32) diff --git a/Machines/atlas-party.cpp b/Machines/atlas-party.cpp new file mode 100644 index 00000000..6e754c7f --- /dev/null +++ b/Machines/atlas-party.cpp @@ -0,0 +1,16 @@ +/* + * atlas-party.cpp + * + */ + +#include "Protocols/AtlasShare.h" +#include "Protocols/AtlasPrep.h" +#include "GC/AtlasSecret.h" + +#include "ShamirMachine.hpp" +#include "Protocols/Atlas.hpp" + +int main(int argc, const char** argv) +{ + ShamirMachineSpec(argc, argv); +} diff --git a/Machines/ccd-party.cpp b/Machines/ccd-party.cpp index 9945f40b..433aaf26 100644 --- a/Machines/ccd-party.cpp +++ b/Machines/ccd-party.cpp @@ -17,8 +17,9 @@ int main(int argc, const char** argv) { - gf2n_short::init_field(40); + gf2n_::init_field(8); ez::ezOptionParser opt; ShamirOptions::singleton = {opt, argc, argv}; - GC::ShareParty>(argc, argv, opt); + assert(ShamirOptions::singleton.nparties < (1 << gf2n_::length())); + GC::ShareParty>>(argc, argv, opt); } diff --git a/Makefile b/Makefile index 
c44f6291..930897ac 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,8 @@ LIB = libSPDZ.a LIBRELEASE = librelease.a ifeq ($(AVX_OT), 0) -LIBSIMPLEOT = ECDSA/P256Element.o +VM += ECDSA/P256Element.o +OT += ECDSA/P256Element.o else LIBSIMPLEOT = SimpleOT/libsimpleot.a endif @@ -42,6 +43,7 @@ DEPS := $(wildcard */*.d */*/*.d) all: arithmetic binary gen_input online offline externalIO bmr ecdsa doc +vm: arithmetic binary .PHONY: doc doc: @@ -50,10 +52,8 @@ doc: arithmetic: rep-ring rep-field shamir semi2k-party.x semi-party.x mascot sy binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x malicious-ccd-party.x real-bmr -ifeq ($(USE_NTL),1) all: overdrive she-offline arithmetic: hemi-party.x soho-party.x gear -endif -include $(DEPS) include $(wildcard *.d static/*.d) @@ -105,7 +105,7 @@ ifeq ($(MACHINE), aarch64) tldr: simde/simde endif -shamir: shamir-party.x malicious-shamir-party.x galois-degree.x +shamir: shamir-party.x malicious-shamir-party.x atlas-party.x galois-degree.x sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x @@ -179,12 +179,6 @@ secure.x: Utils/secure.o %gear-party.x: Machines/%gear-party.o $(VM) OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) -lntl -hemi-party.x: Machines/hemi-party.o $(VM) - $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) -lntl - -soho-party.x: Machines/soho-party.o $(VM) - $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) -lntl - %-ecdsa-party.x: ECDSA/%-ecdsa-party.o ECDSA/P256Element.o $(VM) $(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) @@ -207,6 +201,7 @@ cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o +atlas-party.x: GC/AtlasSecret.o static/hemi-party.x: $(FHEOFFLINE) static/soho-party.x: $(FHEOFFLINE) static/cowgear-party.x: 
$(FHEOFFLINE) diff --git a/Math/Bit.cpp b/Math/Bit.cpp index a0e90a32..6d47b9be 100644 --- a/Math/Bit.cpp +++ b/Math/Bit.cpp @@ -6,8 +6,12 @@ #include "Bit.h" #include "gf2n.h" -Bit::Bit(const gf2n_short& other) : +template +Bit::Bit(const gf2n_& other) : super(other.get_bit(0)) { assert(other.is_one() or other.is_zero()); } + +template Bit::Bit(const gf2n_& other); +template Bit::Bit(const gf2n_& other); diff --git a/Math/Bit.h b/Math/Bit.h index c62b26f2..10c4e018 100644 --- a/Math/Bit.h +++ b/Math/Bit.h @@ -8,7 +8,7 @@ #include "BitVec.h" -class gf2n_short; +template class gf2n_; class Bit : public BitVec_ { @@ -37,7 +37,8 @@ public: throw runtime_error("never call this"); } - Bit(const gf2n_short& other); + template + Bit(const gf2n_& other); Bit operator*(const Bit& other) const { diff --git a/Math/BitVec.h b/Math/BitVec.h index c7f35fca..fd25e134 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -54,6 +54,9 @@ public: BitVec_ extend_bit() const { return -(this->a & 1); } BitVec_ mask(int n) const { return n < n_bits ? 
*this & ((1L << n) - 1) : *this; } + void extend_bit(BitVec_& res, int) const { res = extend_bit(); } + void mask(BitVec_& res, int n) const { res = mask(n); } + void add(octetStream& os) { *this += os.get(); } void mul(const BitVec_& a, const BitVec_& b) { *this = a * b; } diff --git a/Math/FixedVec.h b/Math/FixedVec.h index e2dfb38a..de936066 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -22,6 +22,7 @@ template class Replicated; template class FixedVec { + typedef FixedVec This; array v; public: @@ -302,19 +303,29 @@ public: return res; } - FixedVec extend_bit() const + void extend_bit(This& res, int n_bits) const { - FixedVec res; for (int i = 0; i < L; i++) - res[i] = v[i].extend_bit(); + v[i].extend_bit(res[i], n_bits); + } + + void mask(This& res, int n_bits) const + { + for (int i = 0; i < L; i++) + v[i].mask(res[i], n_bits); + } + + This extend_bit() const + { + This res; + extend_bit(res, T::N_BITS); return res; } - FixedVec mask(int n_bits) const + This mask(int n_bits) const { - FixedVec res; - for (int i = 0; i < L; i++) - res[i] = v[i].mask(n_bits); + This res; + mask(res, n_bits); return res; } diff --git a/Math/Integer.h b/Math/Integer.h index 006d20e1..59395fff 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -15,6 +15,7 @@ using namespace std; #include "field_types.h" #include "Z2k.h" #include "ValueInterface.h" +#include "gf2nlong.h" // Functionality shared between integers and bit vectors @@ -124,6 +125,7 @@ class Integer : public IntBase Integer(const Z2& x) : Integer(x.get_limb(0)) {} template Integer(const gfp_& x); + Integer(int128 x) : Integer(x.get_lower()) {} Integer(const Integer& x, int n_bits); @@ -186,7 +188,7 @@ Integer Integer::convert_unsigned(const gfp_& other) template Integer Integer::convert_unsigned(const Z2& other) { - return bigint::tmp = other; + return other; } // slight misnomer diff --git a/Math/Square.h b/Math/Square.h index 5804aec9..86dee80b 100644 --- a/Math/Square.h +++ b/Math/Square.h @@ -11,15 +11,18 @@ 
template class Square { +protected: + static const int N_ROWS = U::MAX_N_BITS; + public: typedef U RowType; - static const int N_ROWS = U::MAX_N_BITS; - static const int N_ROWS_ALLOCATED = N_ROWS; - static const int N_COLUMNS = N_ROWS; - static const int N_ROW_BYTES = N_ROWS / 8; + static int n_rows() { return U::size_in_bits(); } + static int n_rows_allocated() { return n_rows(); } + static int n_columns() { return n_rows(); } + static int n_row_bytes() { return U::size(); } - static size_t size() { return N_ROWS * U::size(); } + static size_t size() { return U::length() * U::size(); } U rows[N_ROWS]; @@ -32,6 +35,7 @@ public: int offset); void to(U& result); void to(U& result, false_type); + void to(U& result, true_type); template void to(gfp_& result, true_type); diff --git a/Math/Square.hpp b/Math/Square.hpp index 7541f66d..e62b09a7 100644 --- a/Math/Square.hpp +++ b/Math/Square.hpp @@ -36,7 +36,7 @@ void Square::conditional_add(BitVector& conditions, Square& other, int offset) { for (int i = 0; i < U::length(); i++) - if (conditions.get_bit(N_ROWS * offset + i)) + if (conditions.get_bit(n_rows() * offset + i)) rows[i] += other.rows[i]; } @@ -61,22 +61,22 @@ void Square::to(U& result) } template -template -void Square::to(gfp_& result, true_type) +void Square::to(U& result, true_type) { + int L = U::get_ZpD().get_t(); mp_limb_t product[2 * L], sum[2 * L], tmp[L][2 * L]; memset(tmp, 0, sizeof(tmp)); memset(sum, 0, sizeof(sum)); - for (int i = 0; i < gfp_::length(); i++) + for (int i = 0; i < U::length(); i++) { - memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i])); + memcpy(&(tmp[i/64][i/64]), &(rows[i]), U::size()); if (i % 64 == 0) memcpy(product, tmp[i/64], sizeof(product)); else mpn_lshift(product, tmp[i/64], 2 * L, i % 64); - mpn_add_fixed_n<2 * L>(sum, product, sum); + mpn_add_n(sum, product, sum, 2 * L); } mp_limb_t q[2 * L], ans[2 * L]; - mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp_::get_ZpD().get_prA(), L); + mpn_tdiv_qr(q, ans, 0, sum, 2 * L, 
U::get_ZpD().get_prA(), L); result.assign((void*) ans); } diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index d56a94f8..d15af24c 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -24,7 +24,7 @@ public: template static void init(bool mont = true) { (void) mont; } static void init_default(int, bool = true) {} - static void init_field(const bigint& = {}) {} + static void init_field(const bigint& = {}, bool = true) {} static void read_or_generate_setup(const string&, const OnlineOptions&) {} template diff --git a/Math/Z2k.h b/Math/Z2k.h index f70ef0ad..3e653044 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -39,6 +39,7 @@ public: static const int N_BITS = K; static const int MAX_EDABITS = K; + static const int MAX_N_BITS = K; static const int N_BYTES = (K + 7) / 8; static const mp_limb_t UPPER_MASK = mp_limb_t(-1LL) >> (N_LIMB_BITS - 1 - (K - 1) % N_LIMB_BITS); @@ -97,6 +98,8 @@ public: void convert_destroy(bigint& a) { *this = a; } + int bit_length() const; + Z2 operator+(const Z2& other) const; Z2 operator-(const Z2& other) const; diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 5df91806..6734d2cc 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -9,6 +9,8 @@ #include #include "Math/Integer.h" +#include + template const int Z2::N_BITS; template @@ -67,6 +69,18 @@ bool Z2::get_bit(int i) const return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); } +template +int Z2::bit_length() const +{ + if (is_zero()) + return 1; + size_t max_limb = 0; + for (int i = 1; i < N_WORDS; i++) + if (a[i] != 0) + max_limb = i; + return log2(mp_limb_t(a[max_limb])) + 1 + 64 * max_limb; +} + template Z2 Z2::operator&(const Z2& other) const { diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index bb95dc05..63c279a2 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -148,3 +148,8 @@ bool Zp_Data::operator!=(const Zp_Data& other) const else return false; } + +bool Zp_Data::operator==(const Zp_Data& other) const +{ + return not (*this != other); +} diff --git 
a/Math/Zp_Data.h b/Math/Zp_Data.h index 6dca5f02..335ac9d7 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -84,6 +84,7 @@ class Zp_Data void Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) const; bool operator!=(const Zp_Data& other) const; + bool operator==(const Zp_Data& other) const; template friend void to_modp(modp_& ans,int x,const Zp_Data& ZpD); template friend void to_modp(modp_& ans,const mpz_class& x,const Zp_Data& ZpD); @@ -169,12 +170,9 @@ inline void Zp_Data::Add(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) c { switch (t) { - case 4: - return Add<4>(ans, x, y); - case 2: - return Add<2>(ans, x, y); - case 1: - return Add<1>(ans, x, y); +#define X(L) case L: Add(ans, x, y); break; + X(1) X(2) X(3) X(4) X(5) +#undef X default: return Add<0>(ans, x, y); } @@ -203,14 +201,9 @@ inline void Zp_Data::Sub(mp_limb_t* ans,const mp_limb_t* x,const mp_limb_t* y) c { switch (t) { - /* - case 2: - Sub<2>(ans, x, y); - break; - case 1: - Sub<1>(ans, x, y); - break; - */ +#define X(L) case L: Sub(ans, x, y); break; + X(1) X(2) X(3) X(4) X(5) +#undef X default: Sub<0>(ans, x, y); break; diff --git a/Math/bigint.cpp b/Math/bigint.cpp index 1107c969..1952859d 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -148,11 +148,6 @@ bigint::bigint(const GC::Clear& x) : bigint(SignedZ2<64>(x)) { } -bigint::bigint(const gfpvar& other) -{ - to_bigint(*this, other.get(), other.get_ZpD()); -} - bigint::bigint(const mp_limb_t* data, size_t n_limbs) { mpz_import(get_mpz_t(), n_limbs, -1, 8, -1, 0, data); diff --git a/Math/bigint.h b/Math/bigint.h index e62b39f3..5cd31981 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -23,7 +23,8 @@ enum ReportType template class gfp_; -class gfpvar; +template +class gfpvar_; class gmp_random; class Integer; template class Z2; @@ -54,7 +55,8 @@ public: bigint(const T& x) : mpz_class(x) {} template bigint(const gfp_& x); - bigint(const gfpvar& x); + template + bigint(const gfpvar_& x); template bigint(const Z2& x); template 
@@ -189,6 +191,12 @@ bigint::bigint(const gfp_& x) *this = x; } +template +bigint::bigint(const gfpvar_& other) +{ + to_bigint(*this, other.get(), other.get_ZpD()); +} + template bigint& bigint::operator=(const gfp_& x) { diff --git a/Math/fixint.h b/Math/fixint.h index c10d3c26..33a8d80b 100644 --- a/Math/fixint.h +++ b/Math/fixint.h @@ -72,9 +72,9 @@ public: int n_bits = this->size_in_bits(); if (numBits(limit) - N_OVERFLOW > n_bits) { - cerr << "maybe change N_LIMBS_RAND to at least " - << ((numBits(limit) - N_OVERFLOW) / 64) << endl; - throw runtime_error("fixed-length integer too small"); + throw runtime_error("Fixed-length integer too small. " + "Maybe change N_LIMBS_RAND to at least " + + to_string((numBits(limit) - N_OVERFLOW) / 64)); } } }; diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 987b8a4a..797d60db 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -11,37 +11,25 @@ const false_type ValueInterface::characteristic_two; const false_type ValueInterface::prime_field; const false_type ValueInterface::invertible; -const true_type gf2n_short::characteristic_two; -const true_type gf2n_long::characteristic_two; - -const true_type gf2n_short::invertible; -const true_type gf2n_long::invertible; - -int gf2n_short::n = 0; -int gf2n_short::t1; -int gf2n_short::t2; -int gf2n_short::t3; -int gf2n_short::l0; -int gf2n_short::l1; -int gf2n_short::l2; -int gf2n_short::l3; -int gf2n_short::nterms; -word gf2n_short::mask; -bool gf2n_short::useC; +template +int gf2n_::l[4]; +template +bool gf2n_::useC; word gf2n_short_table[256][256]; -#define num_2_fields 4 +#define num_2_fields 6 /* Require * 2*(n-1)-64+t1<64 */ int fields_2[num_2_fields][4] = { - {4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10} + {4,1,0,0},{8,4,3,1},{28,1,0,0},{40,20,15,10},{63,1,0,0},{128,7,2,1}, }; -void gf2n_short::init_tables() +template +void gf2n_::init_tables() { if (sizeof(word)!=8) { cout << "Word size is wrong" << endl; @@ -61,12 +49,17 @@ void gf2n_short::init_tables() } } - void 
gf2n_short::init_field(int nn) +{ + super::init_field(nn == 0 ? DEFAULT_LENGTH : nn); +} + +template +void gf2n_::init_field(int nn) { if (nn == 0) { - nn = DEFAULT_LENGTH; + nn = MAX_N_BITS; #ifdef VERBOSE cerr << "Using GF(2^" << nn << ")" << endl; #endif @@ -77,34 +70,38 @@ void gf2n_short::init_field(int nn) assert(n == 0); - gf2n_short::init_tables(); + init_tables(); int i,j=-1; for (i=0; i MAX_N_BITS) + throw runtime_error("Bit length not supported.\n" + "You might need to compile with USE_GF2N_LONG = 1.\n" + "Remember to run 'make clean'."); + if (j==-1) { - if (nn == 128) - throw runtime_error("need to compile with USE_GF2N_LONG = 1; " - "remember to make clean"); - else - throw runtime_error("field size not supported"); + throw runtime_error("field size not supported"); } n=nn; - nterms=1; - l0=64-n; - t1=fields_2[j][1]; - l1=64+t1-n; - if (fields_2[j][2]!=0) - { nterms=3; - t2=fields_2[j][2]; - l2=64+t2-n; - t3=fields_2[j][3]; - l3=64+t3-n; + l[0] = MAX_N_BITS - n; + for (int i = 1; i < 4; i++) + { + if (fields_2[j][i] == 0) + break; + nterms = i; + t[i] = fields_2[j][i]; + l[i] = MAX_N_BITS + t[i] - n; } - if (2*(n-1)-64+t1>=64) { throw invalid_params(); } + assert(nterms > 0); - mask=(1ULL< +void gf2n_::init_multiplication() +{ + if (n <= 8) + { + word red = 1; + for (int i = 1; i <= nterms; i++) + red ^= (1 << t[i]); + memset(mult_table, 0, sizeof(mult_table)); + for (int i = 1; i < 1 << n; i++) + { + for (int j = 1; j <= i; j++) + { + word tmp = mult_table[i / 2][j]; + tmp <<= 1; + if (i & 1) + tmp ^= j; + if (tmp >> n) + tmp ^= red; + tmp &= Integer(mask).get(); + mult_table[i][j] = tmp; + mult_table[j][i] = tmp; + } + } + } +} + + +/* Takes 8bit x and y and returns the 16 bit product in c1 and c0 + ans = (c1<<8)^c0 + where c1 and c0 are 8 bit +*/ +void mul(octet x, octet y, octet& c0, octet& c1) +{ + auto full = gf2n_short_table[octet(x)][octet(y)]; + c0 = full; + c1 = full >> 8; +} + /* Takes 16bit x and y and returns the 32 bit product in c1 and 
c0 ans = (c1<<16)^c0 where c1 and c0 are 16 bit @@ -147,44 +182,47 @@ inline word mul16(word x,word y) -void gf2n_short::reduce_trinomial(word xh,word xl) +template +void gf2n_::reduce(U xh, U xl) { - // Deal with xh first - a=xl; - a^=(xh<>n; - while (hi!=0) - { a&=mask; + if (2 * (n - 1) - MAX_N_BITS + t[1] < MAX_N_BITS) + { + // Deal with xh first + a = xl; + for (int i = 0; i < nterms + 1; i++) + a ^= (xh << l[i]); - a^=hi; - a^=(hi<>n; + // Now deal with last word + U hi = a >> n; + while (hi != 0) + { + a &= mask; + + a ^= hi; + for (int i = 1; i < nterms + 1; i++) + a ^= (hi << t[i]); + + hi = a >> n; + } } -} - -void gf2n_short::reduce_pentanomial(word xh,word xl) -{ - // Deal with xh first - a=xl; - a^=(xh<>n; - while (hi!=0) - { a&=mask; - - a^=hi; - a^=(hi<>n; + else + { + a = xl; + U upper, lower; + upper = xh & uppermask; + lower = xh & lowermask; + // Upper part + U tmp = 0; + for (int i = 0; i < nterms + 1; i++) + tmp ^= (upper >> (n - t[1] - l[i])); + lower ^= (tmp >> (l[1])); + a ^= (tmp << (n - l[1])); + // Lower part + for (int i = 0; i < nterms + 1; i++) + a ^= (lower << l[i]); } } @@ -209,7 +247,7 @@ void mul32(word x,word y,word& ans) } -void mul64(word x, word y, word& lo, word& hi) +void mul(word x, word y, word& lo, word& hi) { word c,d,e,t; word xl=x&0xFFFFFFFF,yl=y&0xFFFFFFFF; @@ -223,52 +261,90 @@ void mul64(word x, word y, word& lo, word& hi) } -void gf2n_short::mul(const gf2n_short& x,const gf2n_short& y) +word to_word(word x) { - word hi,lo; - - if (gf2n_short::useC) - { /* Uses Karatsuba */ - mul64(x.a, y.a, lo, hi); + return x; +} + +word to_word(int128 x) +{ + return x.get_lower(); +} + +template +gf2n_& gf2n_::mul(const gf2n_& x,const gf2n_& y) +{ + U hi,lo; + + if (n <= 8) + { + *this = mult_table[octet(to_word(x.a))][octet(to_word(y.a))]; + return *this; + } + else if (useC or n > 64) + { + ::mul(x.a, y.a, lo, hi); } else - { /* Use Intel Instructions */ -#ifdef __PCLMUL__ - __m128i xx,yy,zz; - uint64_t c[] 
__attribute__((aligned (16))) = { 0,0 }; - xx=_mm_set1_epi64x(x.a); - yy=_mm_set1_epi64x(y.a); - zz=_mm_clmulepi64_si128(xx,yy,0); - _mm_store_si128((__m128i*)c,zz); - lo=c[0]; - hi=c[1]; -#else - throw runtime_error("need to compile with PCLMUL support"); -#endif + { + int128 res = clmul<0>(int128(x.a).a, int128(y.a).a); + + if (MAX_N_BITS <= 64) + { + hi = res.get_upper(); + lo = res.get_lower(); + } + else + { + res.to(lo); + hi = 0; + } } reduce(hi,lo); + return *this; } -gf2n_short gf2n_short::operator*(const Bit& x) const +template +gf2n_ gf2n_::operator*(const Bit& x) const { - return x.get() * a; + return x.get() ? a : 0; } -gf2n_short gf2n_short::invert() const +template +gf2n_ gf2n_::invert() const { - if (is_one()) { return *this; } + if (n < 64) + return U(invert(a)); + else + return invert>(a).get_lower(); +} + +template<> +gf2n_ gf2n_::invert() const +{ + if (n < 64) + return int128(invert(a.get_lower())); + if (n < 128) + return invert(a); + else + return invert>(a).get_lower(); +} + +template +template +T gf2n_::invert(T a) const +{ + if (is_one()) { return a; } if (is_zero()) { throw division_by_zero(); } - word u,v=a,B=0,D=1,mod=1; + T u,v=a,B=0,D=1,mod=1; + + mod ^= (T(1) << n); + for (int i = 1; i <= nterms; i++) + mod ^= (1ULL << t[i]); - mod^=(1ULL< +gf2n_ gf2n_::operator <<(int i) const +{ + if (i < 0) + throw runtime_error("cannot shift by negative"); + else if (i >= n) + return 0; + else + return a << i; +} + +template +gf2n_ gf2n_::operator >>(int i) const +{ + if (i < 0) + throw runtime_error("cannot shift by negative"); + else if (i >= n) + return 0; + else + return a >> i; +} + + +template +void gf2n_::randomize(PRNG& G, int n) { (void) n; - a=G.get_uint(); - a=(a<<32)^G.get_uint(); + a=G.get(); a&=mask; } - -void gf2n_short::output(ostream& s,bool human) const +template<> +void gf2n_::output(ostream& s,bool human) const { if (human) - { s << hex << showbase << a << dec << " "; } + s << hex << showbase << word(a) << dec; else - { 
s.write((char*) &a,sizeof(word)); } + s.write((char*) &a, sizeof(octet)); } -void gf2n_short::input(istream& s,bool human) +template +void gf2n_::output(ostream& s,bool human) const +{ + if (human) + { s << hex << showbase << a << dec; } + else + { s.write((char*) &a, (sizeof(U))); } +} + +template +void gf2n_::input(istream& s,bool human) { if (s.peek() == EOF) { if (s.tellg() == 0) { cout << "IO problem. Empty file?" << endl; - throw file_error("gf2n_short input"); + throw file_error("gf2n input"); } - throw end_of_file("gf2n_short"); + throw end_of_file("gf2n"); } if (human) { s >> hex >> a >> dec; } else - { s.read((char*) &a,sizeof(word)); } + { s.read((char*) &a, sizeof(U)); } a &= mask; } +gf2n_short gf2n_short::cut(int128 x) +{ + return x.get_lower(); +} + +gf2n_short::gf2n_short(const int128& a) +{ + reduce(a.get_upper(), a.get_lower()); +} + + // Expansion is by x=y^5+1 (as we embed GF(256) into GF(2^40) void expand_byte(gf2n_short& a,int b) { @@ -371,3 +491,7 @@ void collapse_byte(int& b,const gf2n_short& aa) b+=a[i]; } } + +template class gf2n_; +template class gf2n_; +template class gf2n_ ; diff --git a/Math/gf2n.h b/Math/gf2n.h index 8eba20ae..13c3e3e7 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -9,12 +9,14 @@ using namespace std; #include "Tools/random.h" -#include "Math/gf2nlong.h" #include "Math/field_types.h" +#include "Math/bigint.h" +#include "Math/ValueInterface.h" class gf2n_short; class P2Data; class Bit; +class int128; template class Square; typedef Square gf2n_short_square; @@ -31,88 +33,77 @@ void collapse_byte(int& b,const gf2n_short& a); Arithmetic in Gf_{2^n} with n<64 */ -class gf2n_short : public ValueInterface +template +class gf2n_ : public ValueInterface { +protected: friend class gf2n_long; - word a; + U a; - static int n,t1,t2,t3,nterms; - static int l0,l1,l2,l3; - static word mask; + static int n, nterms; + static int t[4], l[4]; + static U mask; + static U uppermask, lowermask; static bool useC; - /* Assign x[0..2*nwords] to 
a and reduce it... */ - void reduce_trinomial(word xh,word xl); - void reduce_pentanomial(word xh,word xl); - - void reduce(word xh,word xl) - { if (nterms==3) - { reduce_pentanomial(xh,xl); } - else - { reduce_trinomial(xh,xl); } - } + static octet mult_table[256][256]; static void init_tables(); + static void init_multiplication(); + + template + T invert(T a) const; public: - typedef gf2n_short value_type; - typedef word internal_type; - typedef gf2n_short next; - typedef ::Square Square; - typedef P2Data FD; - typedef gf2n_short Scalar; + typedef U internal_type; + typedef gf2n_ Scalar; - static const int MAX_N_BITS = 64; static const int N_BYTES = sizeof(a); - static const int DEFAULT_LENGTH = 40; + static const int MAX_N_BITS = 8 * N_BYTES; static void init_field(int nn = 0); + static void init_default(int, bool = false) { init_field(); } + static void reset() { n = 0; } static int degree() { return n; } - static int default_degree() { return 40; } static int get_nterms() { return nterms; } - static int get_t(int i) - { if (i==0) { return t1; } - else if (i==1) { return t2; } - else if (i==2) { return t3; } - return -1; - } + static int get_t(int i) { return (i < 3) ? t[i + 1] : -1; } static DataFieldType field_type() { return DATA_GF2N; } static char type_char() { return '2'; } static string type_short() { return "2"; } - static string type_string() { return "gf2n"; } + static string type_string() { return "gf2n_"; } static int size() { return sizeof(a); } static int size_in_bits() { return sizeof(a) * 8; } - static int length() { return n == 0 ? DEFAULT_LENGTH : n; } + static int length() { return n == 0 ? 
MAX_N_BITS : n; } static bool allows(Dtype type) { (void) type; return true; } static const true_type invertible; static const true_type characteristic_two; - static gf2n_short cut(int128 x) { return x.get_lower(); } + static gf2n_ Mul(gf2n_ a, gf2n_ b) { return a * b; } - static gf2n_short Mul(gf2n_short a, gf2n_short b) { return a * b; } - - word get() const { return a; } - word get_word() const { return a; } + U get() const { return a; } const void* get_ptr() const { return &a; } void assign_zero() { a=0; } void assign_one() { a=1; } void assign_x() { a=2; } - void assign(const void* aa) { a = *(word*) aa & mask; } + void assign(const void* aa) { memcpy(&a, aa, sizeof(a)); } void normalize() { a &= mask; } + /* Assign x[0..2*nwords] to a and reduce it... */ + void reduce(U xh, U xl); + int get_bit(int i) const - { return (a>>i)&1; } + { return ((a>>i)&1) != 0; } void set_bit(int i,unsigned int b) { if (b==1) { a |= (1UL< - gf2n_short(IntBase a) : a(a.get()) {} + gf2n_(IntBase a) : a(a.get()) {} int is_zero() const { return (a==0); } int is_one() const { return (a==1); } - bool operator==(const gf2n_short& y) const { return a==y.a; } - bool operator!=(const gf2n_short& y) const { return a!=y.a; } + bool operator==(const gf2n_& y) const { return a==y.a; } + bool operator!=(const gf2n_& y) const { return a!=y.a; } // x+y - void add(const gf2n_short& x,const gf2n_short& y) + void add(const gf2n_& x,const gf2n_& y) { a=x.a^y.a; } void add(octet* x) - { a^=*(word*)(x); } + { a^=*(U*)(x); } void add(octetStream& os) { add(os.consume(size())); } - void sub(const gf2n_short& x,const gf2n_short& y) + void sub(const gf2n_& x,const gf2n_& y) { a=x.a^y.a; } // = x * y - void mul(const gf2n_short& x,const gf2n_short& y); + gf2n_& mul(const gf2n_& x,const gf2n_& y); - gf2n_short lazy_add(const gf2n_short& x) const { return *this + x; } - gf2n_short lazy_mul(const gf2n_short& x) const { return *this * x; } + gf2n_ lazy_add(const gf2n_& x) const { return *this + x; } + gf2n_ 
lazy_mul(const gf2n_& x) const { return *this * x; } - gf2n_short operator+(const gf2n_short& x) const { gf2n_short res; res.add(*this, x); return res; } - gf2n_short operator*(const gf2n_short& x) const { gf2n_short res; res.mul(*this, x); return res; } - gf2n_short& operator+=(const gf2n_short& x) { add(*this, x); return *this; } - gf2n_short& operator*=(const gf2n_short& x) { mul(*this, x); return *this; } - gf2n_short operator-(const gf2n_short& x) const { gf2n_short res; res.add(*this, x); return res; } - gf2n_short& operator-=(const gf2n_short& x) { sub(*this, x); return *this; } - gf2n_short operator/(const gf2n_short& x) const { return *this * x.invert(); } + gf2n_ operator+(const gf2n_& x) const { gf2n_ res; res.add(*this, x); return res; } + gf2n_ operator*(const gf2n_& x) const { gf2n_ res; res.mul(*this, x); return res; } + gf2n_& operator+=(const gf2n_& x) { add(*this, x); return *this; } + gf2n_& operator*=(const gf2n_& x) { mul(*this, x); return *this; } + gf2n_ operator-(const gf2n_& x) const { gf2n_ res; res.add(*this, x); return res; } + gf2n_& operator-=(const gf2n_& x) { sub(*this, x); return *this; } + gf2n_ operator/(const gf2n_& x) const { return *this * x.invert(); } - gf2n_short operator*(const Bit& x) const; + gf2n_ operator*(const Bit& x) const; + gf2n_ operator*(int x) const { return *this * gf2n_(x); } - gf2n_short invert() const; + gf2n_ invert() const; void negate() { return; } /* Bitwise Ops */ - gf2n_short operator&(const gf2n_short& x) const { return a & x.a; } - gf2n_short operator^(const gf2n_short& x) const { return a ^ x.a; } - gf2n_short operator|(const gf2n_short& x) const { return a | x.a; } - gf2n_short operator~() const { return ~a; } - gf2n_short operator<<(int i) const { return a << i; } - gf2n_short operator>>(int i) const { return a >> i; } + gf2n_ operator&(const gf2n_& x) const { return a & x.a; } + gf2n_ operator^(const gf2n_& x) const { return a ^ x.a; } + gf2n_ operator|(const gf2n_& x) const { return a | x.a; } + 
gf2n_ operator~() const { return ~a; } + gf2n_ operator<<(int i) const; + gf2n_ operator>>(int i) const; - gf2n_short& operator&=(const gf2n_short& x) { *this = *this & x; return *this; } - gf2n_short& operator>>=(int i) { *this = *this >> i; return *this; } - gf2n_short& operator<<=(int i) { *this = *this << i; return *this; } + gf2n_& operator&=(const gf2n_& x) { *this = *this & x; return *this; } + gf2n_& operator^=(const gf2n_& x) { *this = *this ^ x; return *this; } + gf2n_& operator>>=(int i) { *this = *this >> i; return *this; } + gf2n_& operator<<=(int i) { *this = *this << i; return *this; } /* Crap RNG */ void randomize(PRNG& G, int n = -1); @@ -183,12 +175,16 @@ class gf2n_short : public ValueInterface void output(ostream& s,bool human) const; void input(istream& s,bool human); - friend ostream& operator<<(ostream& s,const gf2n_short& x) - { s << hex << showbase << x.a << dec; + friend ostream& operator<<(ostream& s,const gf2n_& x) + { + x.output(s, true); return s; } - friend istream& operator>>(istream& s,gf2n_short& x) - { s >> hex >> x.a >> dec; + friend istream& operator>>(istream& s,gf2n_& x) + { + word tmp; + s >> hex >> tmp >> dec; + x = tmp; return s; } @@ -196,15 +192,75 @@ class gf2n_short : public ValueInterface // Pack and unpack in native format // i.e. Dont care about conversion to human readable form void pack(octetStream& o, int n = -1) const - { (void) n; o.append((octet*) &a,sizeof(word)); } + { (void) n; o.append((octet*) &a,sizeof(U)); } void unpack(octetStream& o, int n = -1) - { (void) n; o.consume((octet*) &a,sizeof(word)); } + { (void) n; o.consume((octet*) &a,sizeof(U)); } }; +class gf2n_short : public gf2n_ +{ + typedef gf2n_ super; + +public: + typedef gf2n_short value_type; + typedef gf2n_short next; + typedef ::Square Square; + typedef P2Data FD; + typedef gf2n_short Scalar; + + static const int DEFAULT_LENGTH = 40; + + static int length() { return n == 0 ? 
DEFAULT_LENGTH : n; } + static int default_degree() { return 40; } + + static void init_field(int nn = 0); + + static gf2n_short cut(int128 x); + + gf2n_short() {} + template + gf2n_short(const T& other) : super(other) {} + gf2n_short(const int128& a); + + word get_word() const { return a; } +}; + +#include "gf2nlong.h" + +class gf2n_long; + #ifdef USE_GF2N_LONG typedef gf2n_long gf2n; #else typedef gf2n_short gf2n; #endif +template +const true_type gf2n_::characteristic_two; +template +const true_type gf2n_::invertible; + +template +int gf2n_::n = 0; +template +U gf2n_::mask; +template +int gf2n_::nterms; +template +int gf2n_::t[4]; + +template +U gf2n_::uppermask; +template +U gf2n_::lowermask; + +template +octet gf2n_::mult_table[256][256]; + +template<> +inline gf2n_& gf2n_::mul(const gf2n_& x, const gf2n_& y) +{ + return *this = mult_table[octet(x.a)][octet(y.a)]; +} + #endif diff --git a/Math/gf2nlong.cpp b/Math/gf2nlong.cpp index 44369d9d..c2555681 100644 --- a/Math/gf2nlong.cpp +++ b/Math/gf2nlong.cpp @@ -27,241 +27,26 @@ ostream& operator<<(ostream& s, const int128& a) { word* tmp = (word*)&a.a; s << hex; - s << noshowbase; - s.width(16); - s.fill('0'); - s << tmp[1]; - s.width(16); + + if (tmp[1] != 0) + { + s << noshowbase; + s.width(16); + s.fill('0'); + s << tmp[1]; + s.width(16); + } + else + s << showbase; + s << tmp[0] << dec; return s; } - -int gf2n_long::n; -int gf2n_long::t1; -int gf2n_long::t2; -int gf2n_long::t3; -int gf2n_long::l0; -int gf2n_long::l1; -int gf2n_long::l2; -int gf2n_long::l3; -int gf2n_long::nterms; -int128 gf2n_long::mask; -int128 gf2n_long::lowermask; -int128 gf2n_long::uppermask; - -#define num_2_fields 1 - -/* Require - * 2*(n-1)-64+t1<64 - */ -int long_fields_2[num_2_fields][4] = { - {128,7,2,1}, - }; - - -void gf2n_long::init_field(int nn) +istream& operator>>(istream& s, int128& a) { - if (nn == 0) - { - nn = MAX_N_BITS; -#ifdef VERBOSE - cerr << "Using GF(2^" << nn << ")" << endl; -#endif - } - - if (nn!=128) { - throw 
runtime_error("Compiled for GF(2^128) only. Change parameters or compile " - "without USE_GF2N_LONG"); - } - - int i,j=-1; - for (i=0; i=128) { throw not_implemented(); } - // if (nterms==3 && n!=128) { throw not_implemented(); } - - mask=_mm_set_epi64x(-1,-1); - lowermask=_mm_set_epi64x((1LL<<(64-7))-1,-1); - uppermask=_mm_set_epi64x(((word)-1)<<(64-7),0); - - // for CPUs without PCLMUL - gf2n_short::init_tables(); -} - - - -void gf2n_long::reduce_trinomial(int128 xh,int128 xl) -{ - // Deal with xh first - a=xl; - a^=(xh<>n; - while (hi==0) - { a&=mask; - - a^=hi; - a^=(hi<>n; - } -} - -void gf2n_long::reduce_pentanomial(int128 xh, int128 xl) -{ - // Deal with xh first - a=xl; - int128 upper, lower; - upper=xh&uppermask; - lower=xh&lowermask; - // Upper part - int128 tmp = 0; - tmp^=(upper>>(n-t1-l0)); - tmp^=(upper>>(n-t1-l1)); - tmp^=(upper>>(n-t1-l2)); - tmp^=(upper>>(n-t1-l3)); - lower^=(tmp>>(l1)); - a^=(tmp<<(n-l1)); - // Lower part - a^=(lower<>n; - while (hi!=0) - { a&=mask; - - a^=hi; - a^=(hi<>n; - } -*/ -} - - -class int129 -{ - int128 lower; - bool msb; - -public: - int129() : lower(_mm_setzero_si128()), msb(false) { } - int129(int128 lower, bool msb) : lower(lower), msb(msb) { } - int129(int128 a) : lower(a), msb(false) { } - int129(word a) - { *this = a; } - int128 get_lower() { return lower; } - int129& operator=(const __m128i& other) - { lower = other; msb = false; return *this; } - int129& operator=(const word& other) - { lower = _mm_set_epi64x(0, other); msb = false; return *this; } - bool operator==(const int129& other) - { return (lower == other.lower) && (msb == other.msb); } - bool operator!=(const int129& other) - { return !(*this == other); } - bool operator>=(const int129& other) - { //cout << ">= " << msb << other.msb << (msb > other.msb) << is_ge(lower.a, other.lower.a) << endl; - return msb == other.msb ? 
is_ge(lower.a, other.lower.a) : msb > other.msb; } - int129 operator<<(int other) - { return int129(lower << other, _mm_cvtsi128_si32(((lower >> (128-other)) & 1).a)); } - int129& operator>>=(int other) - { lower >>= other; lower |= (int128(msb) << (128-other)); msb = !other; return *this; } - int129 operator^(const int129& other) - { return int129(lower ^ other.lower, msb ^ other.msb); } - int129& operator^=(const int129& other) - { lower ^= other.lower; msb ^= other.msb; return *this; } - int129 operator&(const word& other) - { return int129(lower & other, false); } - friend ostream& operator<<(ostream& s, const int129& a) - { s << a.msb << a.lower; return s; } -}; - -gf2n_long gf2n_long::invert() const -{ - if (is_one()) { return *this; } - if (is_zero()) { throw division_by_zero(); } - - int129 u,v=a,B=0,D=1,mod=1; - - mod^=(int129(1)<>=1; - if ((B&1)!=0) { B^=mod; } - B>>=1; - } - while ((v&1)==0 && v!=0) - { v>>=1; - if ((D&1)!=0) { D^=mod; } - D>>=1; - } - - if (u>=v) { u=u^v; B=B^D; } - else { v=v^u; D=D^B; } - } - - return D.get_lower(); -} - - -void gf2n_long::randomize(PRNG& G, int n) -{ - (void) n; - a=G.get_doubleword(); - a&=mask; -} - - -void gf2n_long::output(ostream& s,bool human) const -{ - if (human) - { s << *this; } - else - { s.write((char*) &a,sizeof(__m128i)); } -} - -void gf2n_long::input(istream& s,bool human) -{ - if (s.peek() == EOF) - { if (s.tellg() == 0) - { cout << "IO problem. Empty file?" 
<< endl; - throw file_error("gf2n_long input"); - } - throw end_of_file("gf2n_long"); - } - - if (human) - { s >> *this; } - else - { s.read((char*) &a,sizeof(__m128i)); } + gf2n_long tmp; + s >> tmp; + a = tmp.get(); + return s; } diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index a4d73432..9848a862 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -16,8 +16,11 @@ using namespace std; #include "Tools/intrinsics.h" #include "Math/field_types.h" #include "Math/bigint.h" +#include "Math/gf2n.h" +bool is_ge(__m128i a, __m128i b); + class int128 { public: @@ -39,6 +42,8 @@ public: #endif bool operator!=(const int128& other) const { return !(*this == other); } + bool operator>=(const int128& other) const { return is_ge(a, other.a); } + int128 operator<<(const int& other) const; int128 operator>>(const int& other) const; @@ -56,10 +61,52 @@ public: int128& operator&=(const int128& other) { a &= other.a; return *this; } friend ostream& operator<<(ostream& s, const int128& a); + friend istream& operator>>(istream& s, int128& a); bool get_bit(int i) const; + + void randomize(PRNG& G) { *this = G.get_doubleword(); } + + void to(int128& other) { other = *this; } + void to(word& other) { other = get_lower(); } }; +template +class bit_plus +{ + static const int N_BITS = 8 * sizeof(T); + + T lower; + bool msb; + +public: + bit_plus() : msb(false) { } + bit_plus(T lower, bool msb) : lower(lower), msb(msb) { } + template + bit_plus(U a) : lower(a), msb(false) { } + T get_lower() { return lower; } + bool operator==(const bit_plus& other) + { return (lower == other.lower) && (msb == other.msb); } + bool operator!=(const bit_plus& other) + { return !(*this == other); } + bool operator>=(const bit_plus& other) + { return msb == other.msb ? 
lower >= other.lower : msb > other.msb; } + bit_plus operator<<(int other) + { return bit_plus(lower << other, ((lower >> (N_BITS-other)) & 1) != 0); } + bit_plus& operator>>=(int other) + { lower >>= other; lower |= (T(msb) << (N_BITS-other)); msb = !other; return *this; } + bit_plus operator^(const bit_plus& other) + { return bit_plus(lower ^ other.lower, msb ^ other.msb); } + bit_plus& operator^=(const bit_plus& other) + { lower ^= other.lower; msb ^= other.msb; return *this; } + bit_plus operator&(const word& other) + { return bit_plus(lower & other, false); } + friend ostream& operator<<(ostream& s, const bit_plus& a) + { s << a.msb << a.lower; return s; } +}; + +typedef bit_plus int129; + template class Input; template class PrivateOutput; @@ -82,167 +129,48 @@ class NoValue; Arithmetic in Gf_{2^n} with n<=128 */ -class gf2n_long : public ValueInterface +class gf2n_long : public gf2n_ { - int128 a; - - static int n,t1,t2,t3,nterms; - static int l0,l1,l2,l3; - static int128 mask,lowermask,uppermask; - - /* Assign x[0..2*nwords] to a and reduce it... 
*/ - void reduce_trinomial(int128 xh,int128 xl); - void reduce_pentanomial(int128 xh,int128 xl); + typedef gf2n_ super; public: typedef gf2n_long value_type; - typedef int128 internal_type; - typedef gf2n_long next; typedef ::Square Square; - const static int MAX_N_BITS = 128; - const static int N_BYTES = sizeof(a); - typedef gf2n_long Scalar; - void reduce(int128 xh,int128 xl) - { - if (nterms==3) - { reduce_pentanomial(xh,xl); } - else - { reduce_trinomial(xh,xl); } - } - - static void init_field(int nn); - static int degree() { return n; } - static int length() { return n; } static int default_degree() { return 128; } - static int get_nterms() { return nterms; } - static int get_t(int i) - { if (i==0) { return t1; } - else if (i==1) { return t2; } - else if (i==2) { return t3; } - return -1; - } - static DataFieldType field_type() { return DATA_GF2N; } - static char type_char() { return '2'; } - static string type_short() { return "2"; } static string type_string() { return "gf2n_long"; } - - static int size() { return sizeof(a); } - static int size_in_bits() { return sizeof(a) * 8; } - - static bool allows(Dtype type) { (void) type; return true; } - - static const true_type invertible; - static const true_type characteristic_two; + word get_word() const { return this->a.get_lower(); } static gf2n_long cut(int128 x) { return x; } - static gf2n_long Mul(gf2n_long a, gf2n_long b) { return a * b; } - - int128 get() const { return a; } - word get_word() const { return _mm_cvtsi128_si64(a.a); } - - const void* get_ptr() const { return &a.a; } - - void assign_zero() { a=_mm_setzero_si128(); } - void assign_one() { a=int128(0,1); } - void assign_x() { a=int128(0,2); } - void assign(const void* buffer) { a = _mm_loadu_si128((__m128i*)buffer); } - - int get_bit(int i) const - { return ((a>>i)&1).get_lower(); } - gf2n_long() { assign_zero(); } - gf2n_long(const int128& g) : a(g & mask) {} + gf2n_long(const super& g) : super(g) {} + gf2n_long(const int128& g) : super(g) {} 
gf2n_long(int g) : gf2n_long(int128(unsigned(g))) {} template - gf2n_long(IntBase g) : a(g.get()) {} - - int is_zero() const { return a==int128(0); } - int is_one() const { return a==int128(1); } - int equal(const gf2n_long& y) const { return (a==y.a); } - bool operator==(const gf2n_long& y) const { return a==y.a; } - bool operator!=(const gf2n_long& y) const { return a!=y.a; } - - // x+y - void add(const gf2n_long& x,const gf2n_long& y) - { a=x.a^y.a; } - void add(octet* x) - { a^=int128(_mm_loadu_si128((__m128i*)x)); } - void add(octetStream& os) - { add(os.consume(size())); } - void sub(const gf2n_long& x,const gf2n_long& y) - { a=x.a^y.a; } - // = x * y - gf2n_long& mul(const gf2n_long& x,const gf2n_long& y); - - gf2n_long lazy_add(const gf2n_long& x) const { return *this + x; } - gf2n_long lazy_mul(const gf2n_long& x) const { return *this * x; } - - gf2n_long operator+(const gf2n_long& x) const { gf2n_long res; res.add(*this, x); return res; } - gf2n_long operator*(const gf2n_long& x) const { gf2n_long res; res.mul(*this, x); return res; } - gf2n_long& operator+=(const gf2n_long& x) { add(*this, x); return *this; } - gf2n_long& operator*=(const gf2n_long& x) { mul(*this, x); return *this; } - gf2n_long operator-(const gf2n_long& x) const { gf2n_long res; res.add(*this, x); return res; } - gf2n_long& operator-=(const gf2n_long& x) { sub(*this, x); return *this; } - gf2n_long operator/(const gf2n_long& x) const { return *this * x.invert(); } - - gf2n_long invert() const; - void negate() { return; } - - /* Bitwise Ops */ - gf2n_long operator&(const gf2n_long& x) const { return a & x.a; } - gf2n_long operator^(const gf2n_long& x) const { return a ^ x.a; } - gf2n_long operator|(const gf2n_long& x) const { return a | x.a; } - gf2n_long operator~() const { return ~a; } - gf2n_long operator<<(int i) const { return a << i; } - gf2n_long operator>>(int i) const { return a >> i; } - - gf2n_long& operator&=(const gf2n_long& x) { *this = *this & x; return *this; } - 
gf2n_long& operator^=(const gf2n_long& x) { *this = *this ^ x; return *this; } - gf2n_long& operator>>=(int i) { *this = *this >> i; return *this; } - gf2n_long& operator<<=(int i) { *this = *this << i; return *this; } - - /* Crap RNG */ - void randomize(PRNG& G, int n = -1); - // compatibility with gfp - void almost_randomize(PRNG& G) { randomize(G); } - - void force_to_bit() { a &= 1; } - - void output(ostream& s,bool human) const; - void input(istream& s,bool human); + gf2n_long(IntBase g) : super(g.get()) {} friend ostream& operator<<(ostream& s,const gf2n_long& x) - { s << hex << x.a << dec; + { s << hex << x.get() << dec; return s; } friend istream& operator>>(istream& s,gf2n_long& x) { bigint tmp; s >> hex >> tmp >> dec; - x.a = 0; + x = 0; auto size = tmp.get_mpz_t()->_mp_size; assert(size >= 0); assert(size <= 2); - mpn_copyi((mp_limb_t*)&x.a.a, tmp.get_mpz_t()->_mp_d, size); + mpn_copyi((mp_limb_t*)x.get_ptr(), tmp.get_mpz_t()->_mp_d, size); return s; } - - - // Pack and unpack in native format - // i.e. 
Dont care about conversion to human readable form - void pack(octetStream& o, int n = -1) const - { (void) n; o.append((octet*) &a,sizeof(__m128i)); } - void unpack(octetStream& o, int n = -1) - { (void) n; o.consume((octet*) &a,sizeof(__m128i)); } }; - inline int128 int128::operator<<(const int& other) const { int128 res(_mm_slli_epi64(a, other)); @@ -267,12 +195,12 @@ inline int128 int128::operator>>(const int& other) const return res; } -void mul64(word x, word y, word& lo, word& hi); +void mul(word x, word y, word& lo, word& hi); inline __m128i software_clmul(__m128i a, __m128i b, int choice) { word lo, hi; - mul64(int128(a).get_half(choice & 1), + mul(int128(a).get_half(choice & 1), int128(b).get_half((choice & 0x10) >> 4), lo, hi); return int128(hi, lo).a; } @@ -309,6 +237,11 @@ inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2) *res2 = tmp6; } +inline void mul(int128 a, int128 b, int128& lo, int128& hi) +{ + mul128(a.a, b.a, &lo.a, &hi.a); +} + inline bool int128::get_bit(int i) const { if (i < 64) @@ -317,16 +250,4 @@ inline bool int128::get_bit(int i) const return (get_upper() >> (i - 64)) & 1; } -inline gf2n_long& gf2n_long::mul(const gf2n_long& x,const gf2n_long& y) -{ - __m128i res[2]; - memset(res,0,sizeof(res)); - - mul128(x.a.a,y.a.a,res,res+1); - - reduce(res[1],res[0]); - - return *this; -} - #endif /* MATH_GF2NLONG_H_ */ diff --git a/Math/gfp.h b/Math/gfp.h index 5b493ce7..7b257b5f 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -159,7 +159,7 @@ class gfp_ : public ValueInterface void add(void* x) { ZpD.Add(a.x,a.x,(mp_limb_t*)x); } void sub(const gfp_& x,const gfp_& y) - { Sub(a,x.a,y.a,ZpD); } + { ZpD.Sub(a.x,x.a.x,y.a.x); } // = x * y void mul(const gfp_& x,const gfp_& y) { a.template mul(x.a,y.a,ZpD); } diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index 8b065195..45db4e5d 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -1,5 +1,5 @@ /* - * gfpvar.cpp + * gfpvar_.cpp * */ @@ -9,207 +9,302 @@ #include "gfp.hpp" -const 
true_type gfpvar::invertible; -const true_type gfpvar::prime_field; -const false_type gfpvar::characteristic_two; +template +Zp_Data gfpvar_::ZpD; -Zp_Data gfpvar::ZpD; - -string gfpvar::type_string() +template +string gfpvar_::type_string() { return "gfpvar"; } -string gfpvar::type_short() +template +string gfpvar_::type_short() { return "p"; } -char gfpvar::type_char() +template +char gfpvar_::type_char() { return 'p'; } -int gfpvar::length() +template +int gfpvar_::length() { return ZpD.pr_bit_length; } -int gfpvar::size() +template +int gfpvar_::size() { - return ZpD.pr_byte_length; + return ZpD.get_t() * sizeof(mp_limb_t); } -bool gfpvar::allows(Dtype dtype) +template +int gfpvar_::size_in_bits() +{ + return size() * 8; +} + +template +bool gfpvar_::allows(Dtype dtype) { return gfp_<0, 0>::allows(dtype); } -DataFieldType gfpvar::field_type() +template +DataFieldType gfpvar_::field_type() { return gfp_<0, 0>::field_type(); } -void gfpvar::init_field(bigint prime, bool montgomery) +template +void gfpvar_::init_field(bigint prime, bool montgomery) { ZpD.init(prime, montgomery); if (ZpD.get_t() > N_LIMBS) - throw wrong_gfp_size("gfpvar", prime, "MAX_MOD_SZ", ZpD.get_t() * 2); + throw wrong_gfp_size("gfpvar_", prime, "MAX_MOD_SZ", ZpD.get_t() * 2); } -void gfpvar::init_default(int lgp, bool montgomery) +template +void gfpvar_::init_default(int lgp, bool montgomery) { init_field(SPDZ_Data_Setup_Primes(lgp), montgomery); } -const Zp_Data& gfpvar::get_ZpD() +template +const Zp_Data& gfpvar_::get_ZpD() { return ZpD; } -const bigint& gfpvar::pr() +template +const bigint& gfpvar_::pr() { return ZpD.pr; } -template<> -void gfpvar::generate_setup>(string prep_data_prefix, - int nplayers, int lgp) -{ - generate_prime_setup>(prep_data_prefix, nplayers, lgp); -} - -void gfpvar::check_setup(string dir) +template +void gfpvar_::check_setup(string dir) { ::check_setup(dir, pr()); } -void gfpvar::write_setup(string dir) +template +void gfpvar_::write_setup(string dir) { 
write_online_setup(dir, pr()); } -gfpvar::gfpvar() +template +gfpvar_::gfpvar_() { } -gfpvar::gfpvar(int other) +template +gfpvar_::gfpvar_(int other) { to_modp(a, other, ZpD); } -gfpvar::gfpvar(const bigint& other) +template +gfpvar_::gfpvar_(const bigint& other) { to_modp(a, other, ZpD); } -gfpvar::gfpvar(const modp& other) : - a(other) +template +gfpvar_::gfpvar_(int128 other) : + gfpvar_( + (bigint::tmp = other.get_lower() + + ((bigint::tmp2 = other.get_upper()) << 64))) { } -void gfpvar::assign(const char* buffer) +template +gfpvar_::gfpvar_(BitVec_ other) : + gfpvar_(bigint::tmp = other.get()) +{ +} + +template +void gfpvar_::assign(const void* buffer) { a.assign(buffer, ZpD.get_t()); } -void gfpvar::assign_zero() +template +void gfpvar_::assign_zero() { *this = {}; } -void gfpvar::assign_one() +template +void gfpvar_::assign_one() { assignOne(a, ZpD); } -bool gfpvar::is_zero() +template +bool gfpvar_::is_zero() { return isZero(a, ZpD); } -bool gfpvar::is_one() +template +bool gfpvar_::is_one() { return isOne(a, ZpD); } -gfpvar::modp_type gfpvar::get() const +template +bool gfpvar_::is_bit() +{ + return is_zero() or is_one(); +} + +template +typename gfpvar_::modp_type gfpvar_::get() const { return a; } -gfpvar gfpvar::operator +(const gfpvar& other) const +template +const void* gfpvar_::get_ptr() const { - gfpvar res; + return a.get(); +} + +template +void* gfpvar_::get_ptr() +{ + return &a; +} + +template +void gfpvar_::zero_overhang() +{ + a.zero_overhang(ZpD); +} + +template +void gfpvar_::check() +{ + assert(mpn_cmp(a.get(), ZpD.get_prA(), ZpD.get_t()) < 0); +} + +template +gfpvar_ gfpvar_::operator +(const gfpvar_& other) const +{ + gfpvar_ res; Add(res.a, a, other.a, ZpD); return res; } -gfpvar gfpvar::operator -(const gfpvar& other) const +template +gfpvar_ gfpvar_::operator -(const gfpvar_& other) const { - gfpvar res; + gfpvar_ res; Sub(res.a, a, other.a, ZpD); return res; } -gfpvar gfpvar::operator *(const gfpvar& other) const +template +gfpvar_ 
gfpvar_::operator *(const gfpvar_& other) const { - gfpvar res; + gfpvar_ res; Mul(res.a, a, other.a, ZpD); return res; } -gfpvar gfpvar::operator /(const gfpvar& other) const +template +gfpvar_ gfpvar_::operator /(const gfpvar_& other) const { return *this * other.invert(); } -gfpvar& gfpvar::operator +=(const gfpvar& other) +template +gfpvar_ gfpvar_::operator <<(int other) const +{ + return bigint::tmp = (bigint::tmp = *this) << other; +} + +template +gfpvar_ gfpvar_::operator >>(int other) const +{ + return bigint::tmp = (bigint::tmp = *this) >> other; +} + +template +gfpvar_& gfpvar_::operator +=(const gfpvar_& other) { Add(a, a, other.a, ZpD); return *this; } -gfpvar& gfpvar::operator -=(const gfpvar& other) +template +gfpvar_& gfpvar_::operator -=(const gfpvar_& other) { Sub(a, a, other.a, ZpD); return *this; } -gfpvar& gfpvar::operator *=(const gfpvar& other) +template +gfpvar_& gfpvar_::operator *=(const gfpvar_& other) { Mul(a, a, other.a, ZpD); return *this; } -bool gfpvar::operator ==(const gfpvar& other) const +template +gfpvar_& gfpvar_::operator &=(const gfpvar_& other) +{ + *this = bigint::tmp = (bigint::tmp = *this) & (bigint::tmp2 = other); + return *this; +} + +template +gfpvar_& gfpvar_::operator >>=(int other) +{ + return *this = *this >> other; +} + +template +bool gfpvar_::operator ==(const gfpvar_& other) const { return areEqual(a, other.a, ZpD); } -bool gfpvar::operator !=(const gfpvar& other) const +template +bool gfpvar_::operator !=(const gfpvar_& other) const { return not (*this == other); } -void gfpvar::add(octetStream& other) +template +void gfpvar_::add(octetStream& other) { - *this += other.get(); + *this += other.get>(); } -void gfpvar::negate() +template +void gfpvar_::negate() { - *this = gfpvar() - *this; + *this = gfpvar_() - *this; } -gfpvar gfpvar::invert() const +template +gfpvar_ gfpvar_::invert() const { - gfpvar res; + gfpvar_ res; Inv(res.a, a, ZpD); return res; } -gfpvar gfpvar::sqrRoot() const +template +gfpvar_ 
gfpvar_::sqrRoot() const { bigint ti = *this; ti = sqrRootMod(ti, ZpD.pr); @@ -218,44 +313,42 @@ gfpvar gfpvar::sqrRoot() const return ti; } -void gfpvar::randomize(PRNG& G, int) +template +void gfpvar_::randomize(PRNG& G, int) { a.randomize(G, ZpD); } -void gfpvar::almost_randomize(PRNG& G) +template +void gfpvar_::almost_randomize(PRNG& G) { randomize(G); } -void gfpvar::pack(octetStream& os, int) const +template +void gfpvar_::pack(octetStream& os, int) const { a.pack(os, ZpD); } -void gfpvar::unpack(octetStream& os, int) +template +void gfpvar_::unpack(octetStream& os, int) { a.unpack(os, ZpD); } -void gfpvar::output(ostream& o, bool human) const +template +void gfpvar_::output(ostream& o, bool human) const { a.output(o, ZpD, human); } -void gfpvar::input(istream& i, bool human) +template +void gfpvar_::input(istream& i, bool human) { a.input(i, ZpD, human); } -ostream& operator <<(ostream& o, const gfpvar& x) -{ - x.output(o, true); - return o; -} - -istream& operator >>(istream& i, gfpvar& x) -{ - x.input(i, true); - return i; -} +template class gfpvar_<0, MAX_MOD_SZ / 2>; +template class gfpvar_<1, MAX_MOD_SZ>; +template class gfpvar_<2, MAX_MOD_SZ>; diff --git a/Math/gfpvar.h b/Math/gfpvar.h index 319ee55e..41d1ce2b 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -9,27 +9,32 @@ #include "modp.h" #include "Zp_Data.h" #include "Setup.h" +#include "Square.h" class FFT_Data; +template class BitVec_; -class gfpvar +template +class gfpvar_ { - typedef modp_ modp_type; + typedef modp_ modp_type; static Zp_Data ZpD; modp_type a; public: - typedef gfpvar Scalar; + typedef gfpvar_ Scalar; typedef FFT_Data FD; - typedef void Square; - typedef void next; + typedef ::Square Square; + typedef gfpvar_ next; + typedef gfpvar_ value_type; static const int MAX_N_BITS = modp_type::MAX_N_BITS; static const int MAX_EDABITS = modp_type::MAX_N_BITS; static const int N_LIMBS = modp_type::N_LIMBS; + static const int N_BITS = -1; static const true_type invertible; static const 
true_type prime_field; @@ -41,12 +46,18 @@ public: static int length(); static int size(); + static int size_in_bits(); static bool allows(Dtype dtype); static DataFieldType field_type(); static void init_field(bigint prime, bool montgomery = true); static void init_default(int lgp, bool montgomery = true); + template + static void init(bool montgomery) + { + init_field(T::pr(), montgomery); + } static const Zp_Data& get_ZpD(); static const bigint& pr(); @@ -61,47 +72,72 @@ public: write_setup(get_prep_sub_dir(nplayers)); } - gfpvar(); - gfpvar(int other); - gfpvar(const bigint& other); - gfpvar(const modp& other); + gfpvar_(); + gfpvar_(int other); + gfpvar_(int128 other); + gfpvar_(BitVec_ other); + gfpvar_(const bigint& other); - template - gfpvar(const gfp_& other) + template + gfpvar_(const modp_& other, const Zp_Data& ZpD) + { + if (get_ZpD() == ZpD) + a = other; + else + { + to_bigint(bigint::tmp, other, ZpD); + *this = bigint::tmp; + } + } + + template + gfpvar_(const gfp_& other) { assert(pr() == other.pr()); a = other.get(); } - void assign(const char* buffer); + void assign(const void* buffer); void assign_zero(); void assign_one(); bool is_zero(); bool is_one(); + bool is_bit(); modp_type get() const; + const void* get_ptr() const; + void* get_ptr(); - gfpvar operator+(const gfpvar& other) const; - gfpvar operator-(const gfpvar& other) const; - gfpvar operator*(const gfpvar& other) const; - gfpvar operator/(const gfpvar& other) const; + void zero_overhang(); + void check(); - gfpvar& operator+=(const gfpvar& other); - gfpvar& operator-=(const gfpvar& other); - gfpvar& operator*=(const gfpvar& other); + gfpvar_ operator+(const gfpvar_& other) const; + gfpvar_ operator-(const gfpvar_& other) const; + gfpvar_ operator*(const gfpvar_& other) const; + gfpvar_ operator/(const gfpvar_& other) const; - bool operator==(const gfpvar& other) const; - bool operator!=(const gfpvar& other) const; + gfpvar_ operator<<(int other) const; + gfpvar_ operator>>(int other) 
const; + + gfpvar_& operator+=(const gfpvar_& other); + gfpvar_& operator-=(const gfpvar_& other); + gfpvar_& operator*=(const gfpvar_& other); + gfpvar_& operator&=(const gfpvar_& other); + + gfpvar_& operator>>=(int other); + + bool operator==(const gfpvar_& other) const; + bool operator!=(const gfpvar_& other) const; void add(octetStream& other); void negate(); - gfpvar invert() const; + gfpvar_ invert() const; - gfpvar sqrRoot() const; + gfpvar_ sqrRoot() const; void randomize(PRNG& G, int n_bits = -1); void almost_randomize(PRNG& G); @@ -113,9 +149,39 @@ public: void input(istream& o, bool human); }; -ostream& operator<<(ostream& o, const gfpvar& x); -istream& operator>>(istream& i, gfpvar& x); +typedef gfpvar_<0, MAX_MOD_SZ / 2> gfpvar; +typedef gfpvar_<1, MAX_MOD_SZ> gfpvar1; +typedef gfpvar_<2, MAX_MOD_SZ> gfpvar2; typedef gfpvar gfp; +template +const true_type gfpvar_::invertible; +template +const true_type gfpvar_::prime_field; +template +const false_type gfpvar_::characteristic_two; + +template +template +void gfpvar_::generate_setup(string prep_data_prefix, + int nplayers, int lgp) +{ + generate_prime_setup(prep_data_prefix, nplayers, lgp); +} + +template +ostream& operator <<(ostream& o, const gfpvar_& x) +{ + x.output(o, true); + return o; +} + +template +istream& operator >>(istream& i, gfpvar_& x) +{ + x.input(i, true); + return i; +} + #endif /* MATH_GFPVAR_H_ */ diff --git a/Math/modp.h b/Math/modp.h index 9bf94205..f84da1d6 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -46,13 +46,24 @@ class modp_ } template - modp_(const gfp_& other) : + modp_(const gfp_& other, const Zp_Data& ZpD) : modp_() { + assert(other.get_ZpD() == ZpD); assert(M <= L); inline_mpn_copyi(x, other.get().get(), M); } + template + modp_(const gfpvar_& other, const Zp_Data& ZpD) : + modp_() + { + if (other.get_ZpD() == ZpD) + *this = other.get(); + else + to_modp(*this, bigint(other), ZpD); + } + const mp_limb_t* get() const { return x; } void assign(const void* buffer, int t) 
{ memcpy(x, buffer, t * sizeof(mp_limb_t)); } @@ -64,6 +75,8 @@ class modp_ template void convert_destroy(const fixint& source, const Zp_Data& ZpD); + void zero_overhang(const Zp_Data& ZpD); + void randomize(PRNG& G, const Zp_Data& ZpD); // Pack and unpack in native format diff --git a/Math/modp.hpp b/Math/modp.hpp index e4570f94..50f93cae 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -1,6 +1,10 @@ +#ifndef MATH_MODP_HPP_ +#define MATH_MODP_HPP_ + #include "Zp_Data.h" #include "modp.h" #include "Z2k.hpp" +#include "gfpvar.h" #include "Tools/Exceptions.h" @@ -246,6 +250,12 @@ void modp_::convert(const mp_limb_t* source, mp_size_t size, const Zp_Data& Z ZpD.Mont_Mult(x, x, ZpD.R2); } +template +void modp_::zero_overhang(const Zp_Data& ZpD) +{ + x[ZpD.get_t() - 1] &= ZpD.overhang_mask(); +} + template @@ -350,3 +360,5 @@ void modp_::input(istream& s,const Zp_Data& ZpD,bool human) else { s.read((char*) x,ZpD.t*sizeof(mp_limb_t)); } } + +#endif diff --git a/Math/square128.cpp b/Math/square128.cpp index fc1941fe..fadbe21f 100644 --- a/Math/square128.cpp +++ b/Math/square128.cpp @@ -279,7 +279,7 @@ template <> void Square::to(gf2n_long& result, false_type) { int128 high, low; - for (int i = 0; i < 128; i++) + for (int i = 0; i < gf2n_long::degree(); i++) { low ^= rows[i].get() << i; high ^= rows[i].get() >> (128 - i); diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 6f947cc6..aee390bb 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -21,6 +21,7 @@ void ssl_error(string side, string pronoun, string other, string server) << " have the necessary certificate (" << PREP_DIR << server << ".pem in the default configuration)," << " and run `c_rehash ` on its location." << endl + << "The certificates should be the same on every host. " << "Also make sure that it's still valid. Certificates generated " << "with `Scripts/setup-ssl.sh` expire after a month." 
<< endl; } diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 37807273..51a14312 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -347,6 +347,7 @@ void Player::send_all(const octetStream& o) const void Player::receive_all(vector& os) const { + os.resize(num_players()); for (int j = 0; j < num_players(); j++) if (j != my_num()) receive_player(j, os[j]); diff --git a/OT/BitMatrix.h b/OT/BitMatrix.h index 311f069e..87f99ec9 100644 --- a/OT/BitMatrix.h +++ b/OT/BitMatrix.h @@ -24,9 +24,11 @@ union square128 { typedef gf2n_long RowType; const static int N_ROWS = 128; - const static int N_ROWS_ALLOCATED = 128; - const static int N_COLUMNS = 128; - const static int N_ROW_BYTES = 128 / 8; + + static int n_rows() { return 128; } + static int n_rows_allocated() { return 128; } + static int n_columns() { return 128; } + static int n_row_bytes() { return 128 / 8; } static size_t size() { return N_ROWS * sizeof(__m128i); } @@ -124,7 +126,8 @@ public: size_t vertical_size(); - void resize_vertical(int length) { squares.resize(DIV_CEIL(length, U::N_ROWS)); } + void resize_vertical(int length) + { squares.resize(DIV_CEIL(length, U::n_rows())); } bool operator==(Matrix& other); bool operator!=(Matrix& other); diff --git a/OT/BitMatrix.hpp b/OT/BitMatrix.hpp index 7e20f800..23f0d84d 100644 --- a/OT/BitMatrix.hpp +++ b/OT/BitMatrix.hpp @@ -118,12 +118,12 @@ Slice& Slice::rsub(Slice& other) template Slice& Slice::sub(BitVector& other, int repeat) { - if (end * U::PartType::N_COLUMNS > other.size() * repeat) - throw invalid_length(to_string(U::PartType::N_COLUMNS)); + if (end * U::PartType::n_columns() > other.size() * repeat) + throw invalid_length(to_string(U::PartType::n_columns())); for (size_t i = start; i < end; i++) { bm.squares[i].sub(other.get_ptr_to_byte(i / repeat, - U::PartType::N_ROW_BYTES)); + U::PartType::n_row_bytes())); } return *this; } diff --git a/OT/MamaRectangle.h b/OT/MamaRectangle.h index 6c7f4b00..44fb35e5 100644 --- 
a/OT/MamaRectangle.h +++ b/OT/MamaRectangle.h @@ -18,9 +18,9 @@ class MamaRectangle typename T::Square squares[N]; public: - static const int N_ROWS = T::Square::N_ROWS; - static const int N_COLUMNS = T::Square::N_COLUMNS; - static const int N_ROW_BYTES = T::Square::N_ROW_BYTES; + static int n_rows() { return T::Square::n_rows(); } + static int n_columns() { return T::Square::n_columns(); } + static int n_row_bytes() { return T::Square::n_row_bytes(); } static int size() { @@ -58,7 +58,8 @@ public: void randomize(int row, PRNG& G) { - squares[row / T::Square::N_ROWS].randomize(row % T::Square::N_ROWS, G); + squares[row / T::Square::n_rows()].randomize( + row % T::Square::n_rows(), G); } void pack(octetStream& os) const diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 59d303ce..6ccf7da3 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -14,6 +14,7 @@ #include "OT/OTMultiplier.hpp" #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiInput.hpp" +#include "Protocols/SemiMC.hpp" #include #include @@ -523,7 +524,32 @@ void OTTripleGenerator::plainTripleRound(int k) { plainTriples[j] = {{a, b, c}}; } + +#ifdef DEBUG_MASCOT + cout << "lengths "; + for (int i = 0; i < 3; i++) + cout << valueBits[i].size() << " "; + cout << endl; + + auto& P = globalPlayer; + SemiMC> MC; + + auto aa = MC.open(a, P); + auto bb = MC.open(b, P); + auto cc = MC.open(c, P); + if (cc != aa * bb) + { + cout << j << " " << cc << " != " << aa << " * " << bb << ", diff " << + (cc - aa * bb) << endl; + cout << "OT output " << ot_multipliers[0]->c_output[j] << endl; + assert(cc == aa * bb); + } +#endif } + +#ifdef DEBUG_MASCOT + cout << "plain triple round done" << endl; +#endif } template @@ -655,6 +681,27 @@ void MascotTripleGenerator::sacrifice(typename T::MAC_Check& MC, PRNG& G) MC.POpen_Begin(openedAs, maskedAs, globalPlayer); MC.POpen_End(openedAs, maskedAs, globalPlayer); +#ifdef DEBUG_MASCOT + MC.Check(globalPlayer); + auto& P = 
globalPlayer; + + for (int j = 0; j < nTriplesPerLoop; j++) + for (int i = 0; i < 2; i++) + { + auto a = MC.open(uncheckedTriples[j].a[i], P); + auto b = MC.open(uncheckedTriples[j].b, P); + auto c = MC.open(uncheckedTriples[j].c[i], P); + if (c != a * b) + { + cout << c << " != " << a << " * " << b << ", diff " << hex << + (c - a * b) << endl; + assert(c == a * b); + } + } + + MC.Check(globalPlayer); +#endif + for (int j = 0; j < nTriplesPerLoop; j++) { MC.AddToCheck(maskedTriples[j].computeCheckShare(openedAs[j]), 0, globalPlayer); diff --git a/OT/OTCorrelator.hpp b/OT/OTCorrelator.hpp index ab22a9e4..00561d3c 100644 --- a/OT/OTCorrelator.hpp +++ b/OT/OTCorrelator.hpp @@ -135,8 +135,8 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector& senderOutput, gettimeofday(&startv, NULL); #endif - int n_rows = V::PartType::N_ROWS_ALLOCATED; - int n = (nOTs + n_rows - 1) / n_rows * V::PartType::N_ROWS; + int n_rows = V::PartType::n_rows_allocated(); + int n = (nOTs + n_rows - 1) / n_rows * V::PartType::n_rows(); for (int i = 0; i < 2; i++) senderOutput[i].resize_vertical(n); receiverOutput.resize_vertical(n); @@ -192,8 +192,20 @@ void OTCorrelator::reduce_squares(unsigned int nTriples, vector& output, i output.resize(nTriples); for (unsigned int j = 0; j < nTriples; j++) { +#ifdef DEBUG_MASCOT + T a, b; + receiverOutputMatrix.squares[j + start].to(a); + senderOutputMatrices[0].squares[j + start].to(b); +#endif + receiverOutputMatrix.squares[j + start].sub( senderOutputMatrices[0].squares[j + start]).to(output[j]); + +#ifdef DEBUG_MASCOT + cout << output[j] << " ?= " << a << " - " << b << endl; + cout << "first row " << receiverOutputMatrix.squares[j + start].rows[0] << endl; + assert(output[j] == a - b); +#endif } } diff --git a/OT/OTMultiplier.hpp b/OT/OTMultiplier.hpp index 9c878e59..a6842cb8 100644 --- a/OT/OTMultiplier.hpp +++ b/OT/OTMultiplier.hpp @@ -141,9 +141,9 @@ void OTMultiplier::multiplyForTriples() { typedef typename W::Rectangle X; - 
otCorrelator.resize(X::N_COLUMNS * generator.nPreampTriplesPerLoop); + otCorrelator.resize(X::n_columns() * generator.nPreampTriplesPerLoop); - rot_ext.resize(X::N_ROWS * generator.nPreampTriplesPerLoop + 2 * 128); + rot_ext.resize(X::n_rows() * generator.nPreampTriplesPerLoop + 2 * 128); vector >& baseSenderOutputs = otCorrelator.matrices; Matrix& baseReceiverOutput = otCorrelator.senderOutputMatrices[0]; @@ -283,7 +283,7 @@ void MascotMultiplier::after_correlation() for (int j = 0; j < 3; j++) { bits.append(generator.valueBits[j], - n_vals[j] * T::Square::N_COLUMNS); + n_vals[j] * T::Square::n_columns()); total += n_vals[j]; } this->auth_ot_ext.resize(bits.size()); @@ -298,7 +298,7 @@ void MascotMultiplier::after_correlation() } else { - this->auth_ot_ext.resize(n_vals[0] * T::Square::N_COLUMNS); + this->auth_ot_ext.resize(n_vals[0] * T::Square::n_columns()); for (int j = 0; j < 3; j++) { int nValues = n_vals[j]; @@ -459,7 +459,7 @@ void MascotMultiplier::multiplyForBits(U) { int128 r = auth_ot_ext.receiverOutputMatrix.squares[j/128].rows[j%128]; int128 s = auth_ot_ext.senderOutputMatrices[0].squares[j/128].rows[j%128]; - macs[0][j] = T::clear::cut(r ^ s); + macs[0][j] = typename T::clear(r ^ s); } outbox.push(job); diff --git a/OT/Rectangle.h b/OT/Rectangle.h index a5db5c5d..ab0bd5de 100644 --- a/OT/Rectangle.h +++ b/OT/Rectangle.h @@ -15,16 +15,19 @@ template class Rectangle { -public: - typedef V RowType; - static const int N_ROWS = U::N_BITS; - static const int N_COLUMNS = V::N_BITS; - static const int N_ROW_BYTES = V::N_BYTES; // make sure number of allocated rows matches the number of bytes static const int N_ROWS_ALLOCATED = 8 * U::N_BYTES; +public: + typedef V RowType; + + static int n_rows() { return U::N_BITS; } + static int n_columns() { return V::N_BITS; } + static int n_row_bytes() { return V::N_BYTES; } + static int n_rows_allocated() { return N_ROWS_ALLOCATED; } + V rows[N_ROWS_ALLOCATED]; static size_t size() { return N_ROWS * RowType::size(); } 
@@ -68,4 +71,11 @@ using Z2kRectangle = Rectangle, Z2 >; template using Z2kSquare = Rectangle, Z2>; +template +ostream& operator<<(ostream& o, const Rectangle&) +{ + throw not_implemented(); + return o; +} + #endif /* OT_RECTANGLE_H_ */ diff --git a/OT/Rectangle.hpp b/OT/Rectangle.hpp index eae2a880..f43bb5eb 100644 --- a/OT/Rectangle.hpp +++ b/OT/Rectangle.hpp @@ -15,10 +15,6 @@ template const int Rectangle::N_ROWS; template -const int Rectangle::N_COLUMNS; -template -const int Rectangle::N_ROW_BYTES; -template const int Rectangle::N_ROWS_ALLOCATED; template diff --git a/OT/TripleMachine.h b/OT/TripleMachine.h index 2b95ff4d..22c30639 100644 --- a/OT/TripleMachine.h +++ b/OT/TripleMachine.h @@ -20,12 +20,13 @@ class TripleMachine : public OfflineMachineBase, public MascotParams int nConnections; gf2n mac_key2; - gfp1 mac_keyp; + gfpvar1 mac_keyp; Z2<128> mac_keyz; + bigint prime; + public: int nloops; - bool primeField; bool bonding; int z2k, z2s; diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 1eea739a..7a0385b9 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -84,6 +84,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) inpf.get(); getline(inpf, compiler); getline(inpf, domain); + getline(inpf, relevant_opts); inpf.close(); } diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 9b80a001..0e08549e 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -26,6 +26,7 @@ protected: string compiler; string domain; + string relevant_opts; void print_timers(); diff --git a/Processor/Instruction.h b/Processor/Instruction.h index cfcac3b3..573429b0 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -93,6 +93,7 @@ enum LEGENDREC = 0x38, DIGESTC = 0x39, INV2M = 0x3a, + FLOORDIVC = 0x3b, // Open OPEN = 0xA5, MULS = 0xA6, @@ -123,6 +124,7 @@ enum INPUTMIXED = 0xF2, INPUTMIXEDREG = 0xF3, RAWINPUT = 0xF4, + INPUTPERSONAL = 0xF5, STARTINPUT = 0x61, 
STOPINPUT = 0x62, READSOCKETC = 0x63, @@ -193,6 +195,8 @@ enum CONDPRINTSTR = 0xBF, PRINTFLOATPREC = 0xE0, CONDPRINTPLAIN = 0xE1, + INTOUTPUT = 0xE6, + FLOATOUTPUT = 0xE7, // GF(2^n) versions diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 8949f28e..1936edd9 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -62,6 +62,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case MULM: case DIVC: case MODC: + case FLOORDIVC: case TRIPLE: case ANDC: case XORC: @@ -287,6 +288,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case INPUTMIXEDREG: case RAWINPUT: case GRAWINPUT: + case INPUTPERSONAL: case TRUNC_PR: case RUN_TAPE: num_var_args = get_int(s); @@ -446,9 +448,14 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) get_vector(get_int(s), start, s); break; case PRINTREGSIGNED: + case INTOUTPUT: n = get_int(s); get_ints(r, s, 1); break; + case FLOATOUTPUT: + n = get_int(s); + get_vector(4, start, s); + break; case TRANS: num_var_args = get_int(s) - 1; n = get_int(s); @@ -537,6 +544,7 @@ int BaseInstruction::get_reg_type() const case PLAYERID: case CONVCBIT: case CONVCBITVEC: + case INTOUTPUT: return INT; case PREP: case GPREP: @@ -567,6 +575,7 @@ int BaseInstruction::get_reg_type() const case LEGENDREC: case DIGESTC: case INV2M: + case FLOORDIVC: case OPEN: case ANDC: case XORC: @@ -581,6 +590,7 @@ int BaseInstruction::get_reg_type() const case SHRCI: case CONVINT: case PUBINPUT: + case FLOATOUTPUT: return CINT; default: if (is_gf2n_instruction()) @@ -718,6 +728,11 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const size = DIV_CEIL(this->size, 64); skip = 1; break; + case INPUTPERSONAL: + size_offset = -2; + offset = 2; + skip = 4; + break; } if (skip > 0) @@ -848,7 +863,16 @@ inline void Instruction::execute(Processor& Proc) const throw Processor_Error("Division by zero from register"); Proc.write_C2(r[0], Proc.read_C2(r[1]) / 
Proc.read_C2(r[2])); break; + case FLOORDIVC: + if (Proc.read_Cp(r[2]).is_zero()) + throw Processor_Error("Division by zero from register"); + Proc.temp.aa.from_signed(Proc.read_Cp(r[1])); + Proc.temp.aa2.from_signed(Proc.read_Cp(r[2])); + Proc.write_Cp(r[0], bigint(Proc.temp.aa / Proc.temp.aa2)); + break; case MODC: + if (Proc.read_Cp(r[2]).is_zero()) + throw Processor_Error("Modulo by zero from register"); to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); to_bigint(Proc.temp.aa2, Proc.read_Cp(r[2])); mpz_fdiv_r(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t(), Proc.temp.aa2.get_mpz_t()); @@ -887,6 +911,8 @@ inline void Instruction::execute(Processor& Proc) const Proc.write_Cp(r[0], Proc.get_inverse2(n)); break; case MODCI: + if (n == 0) + throw Processor_Error("Modulo by immediate zero"); to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); to_gfp(Proc.temp.ansp, Proc.temp.aa2 = mpz_fdiv_ui(Proc.temp.aa.get_mpz_t(), n)); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -944,6 +970,9 @@ inline void Instruction::execute(Processor& Proc) const case GRAWINPUT: Proc.Proc2.input.raw_input(Proc.Proc2, start, size); return; + case INPUTPERSONAL: + Proc.Procp.input_personal(start); + return; // Note: Fp version has different semantics for NOTC than GNOTC case NOTC: to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); @@ -1164,6 +1193,19 @@ inline void Instruction::execute(Processor& Proc) const case RAWOUTPUT: Proc.read_Cp(r[0]).output(Proc.public_output, false); break; + case INTOUTPUT: + if (n == -1 or n == Proc.P.my_num()) + Integer(Proc.read_Ci(r[0])).output(Proc.binary_output, false); + break; + case FLOATOUTPUT: + if (n == -1 or n == Proc.P.my_num()) + { + double tmp = bigint::get_float(Proc.read_Cp(start[0] + i), + Proc.read_Cp(start[1] + i), Proc.read_Cp(start[2] + i), + Proc.read_Cp(start[3] + i)).get_d(); + Proc.binary_output.write((char*) &tmp, sizeof(double)); + } + break; case STARTPRIVATEOUTPUT: Proc.privateOutputp.start(n,r[0],r[1]); break; diff --git a/Processor/Machine.h 
b/Processor/Machine.h index 2cd0c664..5ce56eeb 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -49,6 +49,8 @@ class Machine : public BaseMachine void load_program(const string& threadname, const string& filename); + void suggest_optimizations(); + public: vector progs; @@ -71,6 +73,7 @@ class Machine : public BaseMachine OnlineOptions opts; atomic data_sent; + NamedCommStats comm_stats; ExecutionStats stats; Machine(int my_number, Names& playerNames, const string& progname, diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index d7a1fb10..64bfae69 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -79,6 +79,11 @@ Machine::Machine(int my_number, Names& playerNames, inpf.open(memory_filename(), ios::in | ios::binary); if (inpf.fail()) { throw file_error(memory_filename()); } inpf >> M2 >> Mp >> Mi; + if (inpf.get() != 'M') + { + cerr << "Invalid memory file. Run with '-m empty'." << endl; + exit(1); + } inpf.close(); } else if (!(memtype.compare("empty")==0)) @@ -319,7 +324,12 @@ void Machine::run() #endif print_timers(); - cerr << "Data sent = " << data_sent / 1e6 << " MB" << endl; + + size_t rounds = 0; + for (auto& x : comm_stats) + rounds += x.second.rounds; + cerr << "Data sent = " << data_sent / 1e6 << " MB in ~" << rounds + << " rounds (party " << my_number << ")" << endl; auto& P = *this->P; Bundle bundle(P); @@ -328,7 +338,7 @@ void Machine::run() size_t global = 0; for (auto& os : bundle) global += os.get_int(8); - cerr << "Global data sent = " << global / 1e6 << " MB" << endl; + cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl; #ifdef VERBOSE_OPTIONS if (opening_sum < N.num_players() && !direct) @@ -352,6 +362,7 @@ void Machine::run() // Write out the memory to use next time ofstream outf(memory_filename(), ios::out | ios::binary); outf << M2 << Mp << Mi; + outf << 'M'; outf.close(); bit_memories.write_memory(N.my_num()); @@ -396,6 +407,8 @@ void Machine::run() sint::LivePrep::teardown(); 
sgf2n::LivePrep::teardown(); + suggest_optimizations(); + #ifdef VERBOSE cerr << "End of prog" << endl; #endif @@ -420,4 +433,21 @@ void Machine::reqbl(int n) sint::clear::reqbl(n); } +template +void Machine::suggest_optimizations() +{ + string optimizations; + if (relevant_opts.find("trunc_pr") != string::npos and sint::has_trunc_pr) + optimizations.append("\tprogram.use_trunc_pr = True\n"); + if (relevant_opts.find("split") != string::npos and sint::has_split) + optimizations.append( + "\tprogram.use_split(" + to_string(N.num_players()) + ")\n"); + if (relevant_opts.find("edabit") != string::npos and not sint::has_split) + optimizations.append("\tprogram.use_edabit(True)\n"); + if (not optimizations.empty()) + cerr << "This program might benefit from some protocol options." << endl + << "Consider adding the following at the beginning of '" << progname + << ".mpc':" << endl << optimizations; +} + #endif diff --git a/Processor/Memory.h b/Processor/Memory.h index c628e8e5..3c509395 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -33,16 +33,33 @@ class Memory unsigned size_c() { return MC.size(); } + template + static void check_index(const vector& M, size_t i) + { + if (i >= M.size()) + throw overflow("memory", i, M.size()); + } + const typename T::clear& read_C(int i) const - { return MC[i]; } + { + check_index(MC, i); + return MC[i]; + } const T& read_S(int i) const - { return MS[i]; } + { + check_index(MS, i); + return MS[i]; + } void write_C(unsigned int i,const typename T::clear& x) - { MC[i]=x; + { + check_index(MC, i); + MC[i]=x; } void write_S(unsigned int i,const T& x) - { MS[i]=x; + { + check_index(MS, i); + MS[i]=x; } void minimum_size(RegType secret_type, RegType clear_type, diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 74fc0a2d..7b9a711a 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -312,6 +312,7 @@ void thread_info::Sub_Main_Func() delete processor; machine.data_sent += P.sent + 
prep_sent; + machine.comm_stats += P.comm_stats; queues->finished(actual_usage); delete MC2; diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 049388e8..1f1b4e53 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -103,7 +103,9 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op "--external-server" // Flag token. ); + opt.parse(argc, argv); opt.get("--lg2")->getInt(lg2); + opt.resetArgs(); } inline diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index e71185e5..c980b51c 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -23,7 +23,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) memtype = "empty"; bits_from_squares = false; direct = false; - bucket_size = 3; + bucket_size = 4; cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; } @@ -163,11 +163,11 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "--direct" // Flag token. ); opt.add( - "3", // Default. + "4", // Default. 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Batch size for sacrifice (3-5, default: 3)", // Help description. + "Batch size for sacrifice (3-5, default: 4)", // Help description. "-B", // Flag token. "--bucket-size" // Flag token. 
); diff --git a/Processor/Processor.h b/Processor/Processor.h index fa0a5fe1..5183c61b 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -70,6 +70,8 @@ public: int b); void conv2ds(const Instruction& instruction); + void input_personal(const vector& args); + CheckVector& get_S() { return S; @@ -109,6 +111,7 @@ public: ifstream public_input; ofstream public_output; ofstream private_output; + ofstream binary_output; int sent, rounds; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index e88fe879..cddb6505 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -85,11 +85,14 @@ Processor::Processor(int thread_num,Player& P, private_input.open(private_input_filename.c_str()); public_output.open(get_filename(PREP_DIR "Public-Output-",true).c_str(), ios_base::out); private_output.open(get_filename(PREP_DIR "Private-Output-",true).c_str(), ios_base::out); + binary_output.open( + get_parameterized_filename(P.my_num(), thread_num, + PREP_DIR "Binary-Output"), ios_base::out); open_input_file(P.my_num(), thread_num, machine.opts.cmd_private_input_file); secure_prng.ReSeed(); - shared_prng.SeedGlobally(P); + shared_prng.SeedGlobally(P, false); // only output on party 0 if not interactive bool output = P.my_num() == 0 or machine.opts.interactive; @@ -217,8 +220,8 @@ void Processor::convcintvec(const Instruction& instruction) int n_cols = min(n_bits - j * unit, unit); for (int k = 0; k < n_rows; k++) square.rows[k] = - Integer(Procp.C[instruction.get_r(0) + i * unit + k] - >> (j * unit)).get(); + Integer::convert_unsigned( + Procp.C[instruction.get_r(0) + i * unit + k] >> (j * unit)).get(); square.transpose(n_rows, n_cols); for (int k = 0; k < n_cols; k++) Procb.C[instruction.get_start()[k + j * unit] + i] = square.rows[k]; @@ -646,6 +649,24 @@ void SubProcessor::conv2ds(const Instruction& instruction) } } +template +void SubProcessor::input_personal(const vector& args) +{ + input.reset_all(P); + for (size_t i = 0; i < args.size(); i 
+= 4) + for (int j = 0; j < args[i]; j++) + { + if (args[i + 1] == P.my_num()) + input.add_mine(C[args[i + 3] + j]); + else + input.add_other(args[i + 1]); + } + input.exchange(); + for (size_t i = 0; i < args.size(); i += 4) + for (int j = 0; j < args[i]; j++) + S[args[i + 2] + j] = input.finalize(args[i + 1]); +} + template typename sint::clear Processor::get_inverse2(unsigned m) { diff --git a/Programs/Source/benchmark_net.mpc b/Programs/Source/benchmark_net.mpc index ff9c94ff..18396df5 100644 --- a/Programs/Source/benchmark_net.mpc +++ b/Programs/Source/benchmark_net.mpc @@ -1,6 +1,13 @@ import ml import util import math +import sys + +if len(program.args) < 2: + print('Usage: %s ' % program.args[0], + file=sys.stderr) + print(' refers to the letter naming in SecureNN.', file=sys.stderr) + exit(1) program.options_from_args() program.options.cisc = True @@ -29,10 +36,10 @@ if program.args[1] == 'A': ] elif program.args[1] == 'B': layers = [ - ml.FixConv2d([1, 28, 28, 1], (16, 5, 5, 1), (16,), [1, 24, 24, 16], (1, 1)), + ml.FixConv2d([1, 28, 28, 1], (16, 5, 5, 1), (16,), [1, 24, 24, 16], (1, 1), 'VALID'), ml.MaxPool([1, 24, 24, 16]), ml.Relu([1, 12, 12, 16]), - ml.FixConv2d([1, 12, 12, 16], (16, 5, 5, 16), (16,), [1, 8, 8, 16], (1, 1)), + ml.FixConv2d([1, 12, 12, 16], (16, 5, 5, 16), (16,), [1, 8, 8, 16], (1, 1), 'VALID'), ml.MaxPool([1, 8, 8, 16]), ml.Relu([1, 4, 4, 16]), ml.Dense(1, 256, 100), @@ -42,10 +49,10 @@ elif program.args[1] == 'B': ] elif program.args[1] == 'C': layers = [ - ml.FixConv2d([1, 28, 28, 1], (20, 5, 5, 1), (20,), [1, 24, 24, 20], (1, 1)), + ml.FixConv2d([1, 28, 28, 1], (20, 5, 5, 1), (20,), [1, 24, 24, 20], (1, 1), 'VALID'), ml.MaxPool([1, 24, 24, 20]), ml.Relu([1, 12, 12, 20]), - ml.FixConv2d([1, 12, 12, 20], (50, 5, 5, 20), (50,), [1, 8, 8, 50], (1, 1)), + ml.FixConv2d([1, 12, 12, 20], (50, 5, 5, 20), (50,), [1, 8, 8, 50], (1, 1), 'VALID'), ml.MaxPool([1, 8, 8, 50]), ml.Relu([1, 4, 4, 50]), ml.Dense(1, 800, 500), diff --git 
a/Programs/Source/idash_train.mpc b/Programs/Source/idash_train.mpc index ccef647f..b57034d9 100644 --- a/Programs/Source/idash_train.mpc +++ b/Programs/Source/idash_train.mpc @@ -1,6 +1,14 @@ import ml import random import re +import sys + +if len(program.args) < 4: + print('Usage: %s ' % program.args[0], + file=sys.stderr) + print('Refer to https://github.com/mkskeller/idash-submission for ' + 'scripts to run this benchmark.', file=sys.stderr) + exit(1) program.use_trunc_pr = True @@ -43,6 +51,8 @@ if 'mini' in program.args: else: batch_size = N +ml.Layer.back_batch_size = batch_size + X_normal = sfix.Matrix(n_normal, n_features) X_pos = sfix.Matrix(n_pos, n_features) diff --git a/Programs/Source/keras_mnist_dense.mpc b/Programs/Source/keras_mnist_dense.mpc new file mode 100644 index 00000000..a525c065 --- /dev/null +++ b/Programs/Source/keras_mnist_dense.mpc @@ -0,0 +1,48 @@ +# this trains a dense neural network on MNIST +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +training_samples = sfix.Tensor([60000, 28, 28]) +training_labels = sint.Tensor([60000, 10]) + +test_samples = sfix.Tensor([10000, 28, 28]) +test_labels = sint.Tensor([10000, 10]) + +training_labels.input_from(0) +training_samples.input_from(0) + +test_labels.input_from(0) +test_samples.input_from(0) + +from Compiler import ml +tf = ml + +layers = [ + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +optim = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=0.01) + +model.compile(optimizer=optim) + +opt = model.fit( + training_samples, + training_labels, + epochs=1, + batch_size=128, + validation_data=(test_samples, test_labels) +) + +guesses = model.predict(test_samples) + +print_ln('guess %s', guesses.reveal_nested()[:3]) +print_ln('truth %s', 
test_labels.reveal_nested()[:3]) + +for var in model.trainable_variables: + var.write_to_file() diff --git a/Programs/Source/keras_mnist_dense_predict.mpc b/Programs/Source/keras_mnist_dense_predict.mpc new file mode 100644 index 00000000..84d3cdea --- /dev/null +++ b/Programs/Source/keras_mnist_dense_predict.mpc @@ -0,0 +1,39 @@ +# this runs prediction with a dense neural network on MNIST using parameters read from file +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +training_samples = sfix.Tensor([60000, 28, 28]) +training_labels = sint.Tensor([60000, 10]) + +test_samples = sfix.Tensor([10000, 28, 28]) +test_labels = sint.Tensor([10000, 10]) + +training_labels.input_from(0) +training_samples.input_from(0) + +test_labels.input_from(0) +test_samples.input_from(0) + +from Compiler import ml +tf = ml + +layers = [ + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +model.build(test_samples.sizes) + +start = 0 +for var in model.trainable_variables: + start = var.read_from_file(start) + +guesses = model.predict(test_samples) + +print_ln('guess %s', guesses.reveal_nested()[:3]) +print_ln('truth %s', test_labels.reveal_nested()[:3]) diff --git a/Programs/Source/keras_mnist_lenet.mpc b/Programs/Source/keras_mnist_lenet.mpc new file mode 100644 index 00000000..9fdac27f --- /dev/null +++ b/Programs/Source/keras_mnist_lenet.mpc @@ -0,0 +1,44 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +training_samples = MultiArray([60000, 28, 28], sfix) +training_labels = MultiArray([60000, 10], sint) + +test_samples = MultiArray([10000, 28, 28], sfix) +test_labels = MultiArray([10000, 10], sint) + +training_labels.input_from(0) +training_samples.input_from(0) + +test_labels.input_from(0) 
+test_samples.input_from(0) + +from Compiler import ml +tf = ml + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +optim = tf.keras.optimizers.Adam(amsgrad=True) + +model.compile(optimizer=optim) + +opt = model.fit( + training_samples, + training_labels, + epochs=10, + batch_size=128, + validation_data=(test_samples, test_labels) +) diff --git a/Programs/Source/mnist_49.mpc b/Programs/Source/mnist_49.mpc index 1ad2eadd..05218130 100644 --- a/Programs/Source/mnist_49.mpc +++ b/Programs/Source/mnist_49.mpc @@ -1,3 +1,6 @@ +# this trains network with dense layers in 4/9 distinction +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + import ml import math import re @@ -54,7 +57,7 @@ if not ('no_acc' in program.args and 'no_loss' in program.args): Y.input_from(0) X.input_from(0) -sgd = ml.SGD(layers, 1) +sgd = ml.Optimizer.from_args(program, layers) if 'no_out' in program.args: del sgd.layers[-1] diff --git a/Programs/Source/mnist_A.mpc b/Programs/Source/mnist_A.mpc index 0bd1c0a9..18d4f369 100644 --- a/Programs/Source/mnist_A.mpc +++ b/Programs/Source/mnist_A.mpc @@ -1,3 +1,6 @@ +# this trains network with dense layers in 0/1 distinction +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + import ml import math @@ -36,6 +39,7 @@ except: batch_size = N assert batch_size <= N +ml.Layer.back_batch_size = batch_size try: ml.set_n_threads(int(program.args[3])) diff --git a/Programs/Source/mnist_B.mpc b/Programs/Source/mnist_B.mpc deleted file mode 100644 index a62be674..00000000 --- a/Programs/Source/mnist_B.mpc +++ /dev/null @@ -1,73 +0,0 @@ -import ml -import math - 
-#ml.report_progress = True - -program.options_from_args() - -approx = 3 - -if 'profile' in program.args: - print('Compiling for profiling') - N = 1000 - n_test = 1000 -elif 'debug' in program.args: - N = 10 - n_test = 10 -elif 'debug20' in program.args: - N = 20 - n_test = 20 -elif 'debug100' in program.args: - N = 100 - n_test = 100 -elif 'gisette' in program.args: - print('Compiling for 4/9') - N = 11791 - n_test = 1991 -else: - N = 12665 - n_test = 2115 - -n_examples = N -n_features = 28 ** 2 - -try: - n_epochs = int(program.args[1]) -except: - n_epochs = 100 - -try: - batch_size = int(program.args[2]) -except: - batch_size = N - -assert batch_size <= N - -try: - ml.set_n_threads(int(program.args[3])) -except: - pass - -layers = [ - ml.FixConv2d([N, 28, 28, 1], (16, 5, 5, 1), (16,), [N, 24, 24, 16], (1, 1)), - ml.MaxPool([N, 24, 24, 16]), - ml.Relu([N, 12, 12, 16]), - ml.FixConv2d([N, 12, 12, 16], (16, 5, 5, 16), (16,), [N, 8, 8, 16], (1, 1)), - ml.MaxPool([N, 8, 8, 16]), - ml.Relu([N, 4, 4, 16]), - ml.Dense(N, 256, 100), - ml.Relu([N, 100]), - ml.Dense(N, 100, 1), - ml.Output(N) -] - -layers[-1].Y.input_from(0) -layers[0].X.input_from(0) - -Y = sint.Array(n_test) -X = sfix.Matrix(n_test, n_features) -Y.input_from(0) -X.input_from(0) - -sgd = ml.SGD(layers, 1, report_loss=True) -sgd.run_by_args(program, n_epochs, batch_size, X, Y) diff --git a/Programs/Source/mnist_D.mpc b/Programs/Source/mnist_D.mpc deleted file mode 100644 index 49f8f06f..00000000 --- a/Programs/Source/mnist_D.mpc +++ /dev/null @@ -1,60 +0,0 @@ -import ml -import math - -#ml.report_progress = True - -program.options_from_args() - -approx = 3 - -if 'profile' in program.args: - print('Compiling for profiling') - N = 1000 - n_test = 1000 -elif 'debug' in program.args: - N = 10 - n_test = 10 -elif 'gisette' in program.args: - print('Compiling for 4/9') - N = 11791 - n_test = 1991 -else: - N = 12665 - n_test = 2115 - -n_examples = N -n_features = 28 ** 2 - -try: - n_epochs = int(program.args[1]) 
-except: - n_epochs = 100 - -try: - batch_size = int(program.args[2]) -except: - batch_size = N - -assert batch_size <= N - -try: - ml.set_n_threads(int(program.args[3])) -except: - pass - -layers = [ - ml.FixConv2d([N, 28, 28, 1], (5, 5, 5, 1), (5,), [N, 14, 14, 5], (2, 2)), - ml.Relu([N, 14, 14, 5]), - ml.Dense(N, 980, 1), - ml.Output(N, approx=approx)] - -layers[-1].Y.input_from(0) -layers[0].X.input_from(0) - -Y = sint.Array(n_test) -X = sfix.Matrix(n_test, n_features) -Y.input_from(0) -X.input_from(0) - -sgd = ml.SGD(layers, 1, report_loss=True) -sgd.run_by_args(program, n_epochs, batch_size, X, Y) diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index 4a8065df..a1250b5d 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -1,3 +1,6 @@ +# this trains network A from SecureNN +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + import ml import math import re diff --git a/Programs/Source/mnist_full_B.mpc b/Programs/Source/mnist_full_B.mpc index 334072b4..84cfe615 100644 --- a/Programs/Source/mnist_full_B.mpc +++ b/Programs/Source/mnist_full_B.mpc @@ -1,3 +1,6 @@ +# this trains network B from SecureNN +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + import ml import math import re diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index fb6ad55f..e0388a3f 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -1,3 +1,6 @@ +# this trains network C (LeNet) from SecureNN +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + import ml import math import re @@ -58,9 +61,10 @@ layers = [ ml.Dense(N, 800, 500), ml.Relu([N, 500]), ml.Dense(N, 500, 10), - ml.MultiOutput(n_examples, 10) ] +layers += [ml.MultiOutput.from_args(program, n_examples, 10)] + if 'dropout' in program.args or 'dropout2' in program.args: layers.insert(8, ml.Dropout(N, 500)) elif 'dropout.25' in program.args: diff --git 
a/Programs/Source/mnist_full_D.mpc b/Programs/Source/mnist_full_D.mpc index f250de3a..7ebe1904 100644 --- a/Programs/Source/mnist_full_D.mpc +++ b/Programs/Source/mnist_full_D.mpc @@ -1,3 +1,6 @@ +# this trains network D from SecureNN +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + import ml import math import re diff --git a/Programs/Source/mnist_logreg.mpc b/Programs/Source/mnist_logreg.mpc index f7d77bd5..5fe18d9e 100644 --- a/Programs/Source/mnist_logreg.mpc +++ b/Programs/Source/mnist_logreg.mpc @@ -1,3 +1,6 @@ +# this trains logistic regression in 0/1 distinction +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + import ml program.options_from_args() diff --git a/Protocols/Atlas.h b/Protocols/Atlas.h new file mode 100644 index 00000000..1a6d66d4 --- /dev/null +++ b/Protocols/Atlas.h @@ -0,0 +1,71 @@ +/* + * Atlas.h + * + */ + +#ifndef PROTOCOLS_ATLAS_H_ +#define PROTOCOLS_ATLAS_H_ + +#include "Replicated.h" + +template +class Atlas : public ProtocolBase +{ + Shamir shamir, shamir2; + + Bundle oss, oss2; + PointerVector masks; + + vector> double_sharings; + + vector reconstruction; + + int next_king, base_king; + + ShamirInput resharing; + + typename T::open_type dotprod_share; + + array get_double_sharing(); + +public: + Player& P; + + Atlas(Player& P) : + shamir(P), shamir2(P, 2 * ShamirMachine::s().threshold), oss(P), + oss2(P), next_king(0), base_king(0), resharing(0, P), P(P) + { + } + + ~Atlas(); + + Atlas branch() + { + return P; + } + + int get_n_relevant_players() + { + return shamir.get_n_relevant_players(); + } + + void init_mul(Preprocessing&, typename T::MAC_Check&) + { + init_mul(); + } + + void init_mul(SubProcessor* proc = 0); + typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void prepare(const typename T::open_type& product); + void exchange(); + T finalize_mul(int n = -1); + + void init_dotprod(SubProcessor* proc); + void prepare_dotprod(const T& x, const T& y); + void 
next_dotprod(); + T finalize_dotprod(int length); + + T get_random(); +}; + +#endif /* PROTOCOLS_ATLAS_H_ */ diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp new file mode 100644 index 00000000..bb6f18bf --- /dev/null +++ b/Protocols/Atlas.hpp @@ -0,0 +1,132 @@ +/* + * Atlas.hpp + * + */ + +#ifndef PROTOCOLS_ATLAS_HPP_ +#define PROTOCOLS_ATLAS_HPP_ + +#include "Atlas.h" + +template +Atlas::~Atlas() +{ +#ifdef VERBOSE + if (not double_sharings.empty()) + cerr << double_sharings.size() << " double sharings left" << endl; +#endif +} + +template +array Atlas::get_double_sharing() +{ + if (double_sharings.empty()) + { + SeededPRNG G; + PRNG G2 = G; + auto random = shamir.get_randoms(G, 0); + auto random2 = shamir2.get_randoms(G2, 0); + assert(random.size() == random2.size()); + assert(random.size() % P.num_players() == 0); + for (size_t i = 0; i < random.size(); i++) + double_sharings.push_back({{random2.at(i), random.at(i)}}); + } + + auto res = double_sharings.back(); + double_sharings.pop_back(); + return res; +} + +template +void Atlas::init_mul(SubProcessor*) +{ + oss.reset(); + oss2.reset(); + masks.clear(); + base_king = next_king; +} + +template +typename T::clear Atlas::prepare_mul(const T& x, const T& y, int) +{ + prepare(x * y); + return {}; +} + +template +void Atlas::prepare(const typename T::open_type& product) +{ + auto r = get_double_sharing(); + (product + r[0]).pack(oss2[next_king]); + next_king = (next_king + 1) % P.num_players(); + masks.push_back(r[1]); +} + +template +void Atlas::exchange() +{ + P.send_receive_all(oss2, oss); + oss.mine = oss2.mine; + + int t = ShamirMachine::s().threshold; + if (reconstruction.empty()) + for (int i = 0; i < 2 * t + 1; i++) + reconstruction.push_back(Shamir::get_rec_factor(i, 2 * t + 1)); + resharing.reset_all(P); + + for (size_t j = P.get_player(-base_king); j < masks.size(); + j += P.num_players()) + { + typename T::open_type e; + for (int i = 0; i < 2 * t + 1; i++) + { + auto tmp = oss[i].template get(); + 
e += tmp * reconstruction.at(i); + } + resharing.add_mine(e); + } + + resharing.exchange(); +} + +template +T Atlas::finalize_mul(int) +{ + T res = resharing.finalize(base_king) - masks.next(); + base_king = (base_king + 1) % P.num_players(); + return res; +} + +template +void Atlas::init_dotprod(SubProcessor* proc) +{ + init_mul(proc); + dotprod_share = 0; +} + +template +void Atlas::prepare_dotprod(const T& x, const T& y) +{ + dotprod_share += x * y; +} + +template +void Atlas::next_dotprod() +{ + prepare(dotprod_share); + dotprod_share = 0; +} + +template +T Atlas::finalize_dotprod(int) +{ + return finalize_mul(); +} + +template +T Atlas::get_random() +{ + return shamir.get_random(); +} + +#endif /* PROTOCOLS_ATLAS_HPP_ */ diff --git a/Protocols/AtlasPrep.h b/Protocols/AtlasPrep.h new file mode 100644 index 00000000..489f535a --- /dev/null +++ b/Protocols/AtlasPrep.h @@ -0,0 +1,39 @@ +/* + * AtlasPrep.h + * + */ + +#ifndef PROTOCOLS_ATLASPREP_H_ +#define PROTOCOLS_ATLASPREP_H_ + +#include "ReplicatedPrep.h" + +template +class AtlasPrep : public ReplicatedPrep +{ +public: + AtlasPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage), + ReplicatedRingPrep(proc, usage), + RingPrep(proc, usage), + SemiHonestRingPrep(proc, usage), + ReplicatedPrep(proc, usage) + { + } + + void buffer_inputs(int player) + { + assert(this->protocol and this->proc); + int batch_size = OnlineOptions::singleton.batch_size; + typename T::MAC_Check MC; + vector shares; + for (int i = 0; i < batch_size; i++) + shares.push_back(this->protocol->get_random()); + vector opened; + this->proc->MC.POpen(opened, shares, this->proc->P); + for (int i = 0; i < batch_size; i++) + this->inputs.at(player).push_back({shares[i], opened[i]}); + } +}; + +#endif /* PROTOCOLS_ATLASPREP_H_ */ diff --git a/Protocols/AtlasShare.h b/Protocols/AtlasShare.h new file mode 100644 index 00000000..cb5528c0 --- /dev/null +++ b/Protocols/AtlasShare.h @@ -0,0 +1,46 @@ +/* + * AtlasShare.h 
+ * + */ + +#ifndef PROTOCOLS_ATLASSHARE_H_ +#define PROTOCOLS_ATLASSHARE_H_ + +#include "ShamirShare.h" + +template class Atlas; +template class AtlasPrep; + +namespace GC +{ +class AtlasSecret; +} + +template +class AtlasShare : public ShamirShare +{ + typedef AtlasShare This; + typedef ShamirShare super; + +public: + typedef Atlas Protocol; + typedef ::Input Input; + typedef IndirectShamirMC MAC_Check; + typedef ShamirMC Direct_MC; + typedef ::PrivateOutput PrivateOutput; + typedef AtlasPrep LivePrep; + + typedef GC::AtlasSecret bit_type; + + AtlasShare() + { + } + + template + AtlasShare(const U& other) : + super(other) + { + } +}; + +#endif /* PROTOCOLS_ATLASSHARE_H_ */ diff --git a/Protocols/CowGearOptions.cpp b/Protocols/CowGearOptions.cpp index c1fbcaab..a62fb92e 100644 --- a/Protocols/CowGearOptions.cpp +++ b/Protocols/CowGearOptions.cpp @@ -58,10 +58,19 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, 0, // Required? 0, // Number of args expected. 0, // Delimiter if expecting multiple args. - "Use TopGear", // Help description. + "Obsolete", // Help description. "-T", // Flag token. "--top-gear" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Don't use TopGear", // Help description. + "-J", // Flag token. + "--no-top-gear" // Flag token. + ); opt.parse(argc, argv); if (opt.isSet("-c")) opt.get("-c")->getInt(covert_security); @@ -77,6 +86,8 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, if (covert_security > (1LL << lowgear_security)) insecure(", LowGear security less than key generation security"); } - use_top_gear = opt.isSet("-T"); + use_top_gear = not opt.isSet("-J"); + if (opt.isSet("-T")) + cerr << "WARNING: Option -T/--top-gear is obsolete." 
<< endl; opt.resetArgs(); } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 0c5fe969..9e650a07 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -10,6 +10,8 @@ #include "Math/Z2k.h" #include "Processor/Instruction.h" +#include + template class FakeProtocol : public ProtocolBase { @@ -20,10 +22,15 @@ class FakeProtocol : public ProtocolBase T trunc_max; + int fails; + + vector trunc_stats; + public: Player& P; - FakeProtocol(Player& P) : P(P) + FakeProtocol(Player& P) : + fails(0), trunc_stats(T::MAX_N_BITS + 1), P(P) { } @@ -31,6 +38,15 @@ public: ~FakeProtocol() { output_trunc_max<0>(T::invertible); + double expected = 0; + for (int i = 0; i <= T::MAX_N_BITS; i++) + { + if (trunc_stats[i] != 0) + cerr << i << ": " << trunc_stats[i] << endl; + expected += trunc_stats[i] * exp2(i - T::MAX_N_BITS); + } + if (expected != 0) + cerr << "Expected truncation failures: " << expected << endl; } template @@ -119,24 +135,31 @@ public: template void trunc_pr(const vector& regs, int size, SubProcessor& proc, false_type) { + this->trunc_rounds++; + this->trunc_pr_counter += size * regs.size() / 4; for (size_t i = 0; i < regs.size(); i += 4) for (int l = 0; l < size; l++) { auto& res = proc.get_S_ref(regs[i] + l); auto& source = proc.get_S_ref(regs[i + 1] + l); - T tmp = source - (T(1) << regs[i + 2] - 1); + T tmp = source; tmp = tmp < T() ? (T() - tmp) : tmp; trunc_max = max(trunc_max, tmp); +#ifdef TRUNC_PR_EMULATION_STATS + trunc_stats.at(tmp == T() ? 
0 : tmp.bit_length())++; +#endif #ifdef CHECK_BOUNDS_IN_TRUNC_PR_EMULATION auto test = (source >> (regs[i + 2])); - if (test != 0) + if (test != 0 and test != T(-1) >> regs[i + 2]) { cerr << typename T::clear(source) << " has more than " << regs[i + 2] << " bits in " << regs[i + 3] << "-bit truncation (test value " << typename T::clear(test) << ")" << endl; - throw runtime_error("trunc_pr overflow"); + fails++; + if (fails > 1000) + throw runtime_error("trunc_pr overflow"); } #endif int n_shift = regs[i + 3]; diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index bfa3e25b..c301e736 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -33,6 +33,7 @@ public: SemiHonestRingPrep(proc, usage) { } + ~HemiPrep(); void buffer_triples(); }; diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index af5909f2..9aa93e0e 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -34,6 +34,14 @@ void HemiPrep::basic_setup(Player& P) T::clear::template init(); } + +template +HemiPrep::~HemiPrep() +{ + for (auto& x : multipliers) + delete x; +} + template void HemiPrep::buffer_triples() { diff --git a/Protocols/HighGearKeyGen.cpp b/Protocols/HighGearKeyGen.cpp index cb3af5e9..2618feba 100644 --- a/Protocols/HighGearKeyGen.cpp +++ b/Protocols/HighGearKeyGen.cpp @@ -12,26 +12,7 @@ template<> void PartSetup::key_and_mac_generation(Player& P, MachineBase& machine, int, false_type) { - auto& batch_size = OnlineOptions::singleton.batch_size; - auto backup = batch_size; - batch_size = 100; - bool done = false; - int n_limbs[2]; - for (int i = 0; i < 2; i++) - n_limbs[i] = params.FFTD()[i].get_prD().get_t(); -#define X(L, M) \ - if (n_limbs[0] == L and n_limbs[1] == M) \ - { \ - HighGearKeyGen(P, params).run(*this, machine); \ - done = true; \ - } - X(5, 3) X(4, 3) X(3, 2) - if (not done) - throw runtime_error( - "not compiled for choice of parameters, add X(" - + to_string(n_limbs[0]) + ", " + to_string(n_limbs[1]) - + ") at " + __FILE__ + ":" + 
to_string(__LINE__ - 5)); - batch_size = backup; + HighGearKeyGen<0, 0>(P, params).run(*this, machine); } template<> diff --git a/Protocols/HighGearKeyGen.h b/Protocols/HighGearKeyGen.h index 6f9a7b66..7b9fb78e 100644 --- a/Protocols/HighGearKeyGen.h +++ b/Protocols/HighGearKeyGen.h @@ -37,8 +37,8 @@ template class HighGearKeyGen { public: - typedef KeyGenProtocol<5, L> Proto0; - typedef KeyGenProtocol<7, M> Proto1; + typedef KeyGenProtocol<1, -1> Proto0; + typedef KeyGenProtocol<2, -1> Proto1; typedef typename Proto0::share_type share_type0; typedef typename Proto1::share_type share_type1; diff --git a/Protocols/HighGearKeyGen.hpp b/Protocols/HighGearKeyGen.hpp index 9645a331..41a45245 100644 --- a/Protocols/HighGearKeyGen.hpp +++ b/Protocols/HighGearKeyGen.hpp @@ -38,6 +38,7 @@ void HighGearKeyGen::buffer_mabits() bmc.Check(P); for (int i = 0; i < batch_size; i++) { + assert(open_diffs.at(i).get_bit(1) == 0); bits0.push_back(my_bits0[i]); bits1.push_back( my_bits1[i] @@ -45,6 +46,30 @@ void HighGearKeyGen::buffer_mabits() proto1.MC->get_alphai()) - my_bits1[i] * open_diffs.at(i) * 2); } + +#ifdef DEBUG_HIGHGEAR_KEYGEN + proto0.MC->init_open(P); + proto1.MC->init_open(P); + auto it0 = bits0.end() - batch_size; + auto it1 = bits1.end() - batch_size; + for (int i = 0; i < batch_size; i++) + { + proto0.MC->prepare_open(*it0); + proto1.MC->prepare_open(*it1); + it0++; + it1++; + } + proto0.MC->exchange(P); + proto1.MC->exchange(P); + for (int i = 0; i < batch_size; i++) + { + auto x0 = proto0.MC->finalize_open(); + auto x1 = proto1.MC->finalize_open(); + assert(x0.is_bit()); + assert(x1.is_bit()); + assert(x0.is_zero() == x1.is_zero()); + } +#endif } template diff --git a/Protocols/LowGearKeyGen.cpp b/Protocols/LowGearKeyGen.cpp index ffa42eb7..2b149bc0 100644 --- a/Protocols/LowGearKeyGen.cpp +++ b/Protocols/LowGearKeyGen.cpp @@ -12,18 +12,7 @@ template<> void PairwiseSetup::key_and_mac_generation(Player& P, PairwiseMachine& machine, int, false_type) { - int n_limbs 
= params.FFTD()[0].get_prD().get_t(); - switch (n_limbs) - { -#define X(L) case L: LowGearKeyGen(P, machine, params).run(*this); break; - X(3) X(4) X(5) X(6) -#undef X - default: - throw runtime_error( - "not compiled for choice of parameters, add X(" - + to_string(n_limbs) + ") at " + __FILE__ + ":" - + to_string(__LINE__ - 5)); - } + LowGearKeyGen<0>(P, machine, params).run(*this); } template<> diff --git a/Protocols/LowGearKeyGen.h b/Protocols/LowGearKeyGen.h index 930aa19a..534cec85 100644 --- a/Protocols/LowGearKeyGen.h +++ b/Protocols/LowGearKeyGen.h @@ -12,15 +12,18 @@ #include "Processor/Processor.h" #include "GC/TinierSecret.h" #include "Math/gfp.h" +#include "Math/gfpvar.h" template class KeyGenProtocol { public: - typedef Share> share_type; + typedef Share> share_type; typedef typename share_type::open_type open_type; typedef ShareVector vector_type; + int backup_batch_size; + protected: Player& P; const FHE_Params& params; @@ -53,9 +56,9 @@ public: }; template -class LowGearKeyGen : public KeyGenProtocol<5, L> +class LowGearKeyGen : public KeyGenProtocol<1, L> { - typedef KeyGenProtocol<5, L> super; + typedef KeyGenProtocol<1, L> super; typedef typename super::share_type share_type; typedef typename super::open_type open_type; diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index c1e7e825..d0dd6f34 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -14,7 +14,7 @@ template LowGearKeyGen::LowGearKeyGen(Player& P, PairwiseMachine& machine, FHE_Params& params) : - KeyGenProtocol<5, L>(P, params), P(P), machine(machine) + KeyGenProtocol<1, L>(P, params), P(P), machine(machine) { } @@ -26,6 +26,10 @@ KeyGenProtocol::KeyGenProtocol(Player& P, const FHE_Params& params, open_type::init_field(params.FFTD().at(level).get_prD().pr); typename share_type::mac_key_type alphai; + auto& batch_size = OnlineOptions::singleton.batch_size; + backup_batch_size = batch_size; + batch_size = 100; + if 
(OnlineOptions::singleton.live_prep) { prep = new MascotDabitOnlyPrep(0, usage); @@ -52,6 +56,8 @@ KeyGenProtocol::~KeyGenProtocol() delete proc; delete prep; delete MC; + + OnlineOptions::singleton.batch_size = backup_batch_size; } template diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 43b2292f..4be3fc63 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -191,6 +191,15 @@ void MaliciousBitOnlyRepPrep::buffer_bits() honest_prep.get_two(DATA_SQUARE, f, h); bits.push_back(a); check_squares.push_back({{f, h}}); +#ifdef DEBUG_BIT_SACRIFICE + typename T::MAC_Check MC; + if (not (MC.open(a, P).is_zero() or MC.open(a, P).is_one())) + { + cout << MC.open(a, P) << endl; + throw exception(); + } + assert(MC.open(f, P) * MC.open(f, P) == MC.open(h, P)); +#endif } auto t = Create_Random(P); for (int i = 0; i < buffer_size; i++) diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 3972b91a..abe82baf 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -56,6 +56,11 @@ template class MascotDabitOnlyPrep : public virtual MaliciousDabitOnlyPrep, public virtual MascotTriplePrep { + template + void buffer_bits(true_type); + template + void buffer_bits(false_type); + public: MascotDabitOnlyPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 317469f9..70ed1fac 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -71,6 +71,20 @@ void MascotTriplePrep::buffer_triples() template void MascotDabitOnlyPrep::buffer_bits() +{ + buffer_bits<0>(T::clear::prime_field); +} + +template +template +void MascotDabitOnlyPrep::buffer_bits(true_type) +{ + buffer_bits_from_squares(*this); +} + +template +template +void MascotDabitOnlyPrep::buffer_bits(false_type) { this->params.generateBits = true; auto& triple_generator = this->triple_generator; diff --git a/Protocols/Rep3Share2k.h 
b/Protocols/Rep3Share2k.h index d9225d5c..c7a49452 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -31,6 +31,9 @@ public: typedef GC::SemiHonestRepSecret bit_type; + static const bool has_trunc_pr = true; + static const bool has_split = true; + Rep3Share2() { } diff --git a/Protocols/Rep4Share2k.h b/Protocols/Rep4Share2k.h index d902cc76..58f15f23 100644 --- a/Protocols/Rep4Share2k.h +++ b/Protocols/Rep4Share2k.h @@ -31,6 +31,9 @@ public: typedef ::PrivateOutput PrivateOutput; typedef Rep4RingOnlyPrep LivePrep; + static const bool has_trunc_pr = true; + static const bool has_split = true; + Rep4Share2() { } diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 4f746c48..78e08712 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -73,6 +73,7 @@ public: virtual typename T::clear prepare_mul(const T& x, const T& y, int n = -1) = 0; virtual void exchange() = 0; virtual T finalize_mul(int n = -1) = 0; + virtual void finalize_mult(T& res, int n = -1); void init_dotprod(SubProcessor* proc) { init_mul(proc); } void prepare_dotprod(const T& x, const T& y) { prepare_mul(x, y); } diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index c078d08b..75dc785b 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -116,6 +116,12 @@ T ProtocolBase::mul(const T& x, const T& y) return finalize_mul(); } +template +void ProtocolBase::finalize_mult(T& res, int n) +{ + res = finalize_mul(n); +} + template T ProtocolBase::finalize_dotprod(int length) { diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 4c1b99b4..a8c8266a 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -323,6 +323,11 @@ template class ReplicatedPrep : public virtual ReplicatedRingPrep, public virtual SemiHonestRingPrep { + template + void buffer_bits(false_type); + template + void buffer_bits(true_type); + public: ReplicatedPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), 
BitPrep(proc, usage), diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 628e3573..27c96c19 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -321,16 +321,15 @@ void buffer_bits_from_squares(RingPrep& prep) throw runtime_error("squares were all zero"); } -template class T, int X, int L> -void buffer_bits_spec(ReplicatedPrep>>& prep, vector>>& bits, - typename T>::Protocol& prot) +template +template +void ReplicatedPrep::buffer_bits(true_type) { - (void) bits, (void) prot; - if (prot.get_n_relevant_players() > 10 + if (this->protocol->get_n_relevant_players() > 10 or OnlineOptions::singleton.bits_from_squares) - buffer_bits_from_squares(prep); + buffer_bits_from_squares(*this); else - prep.ReplicatedRingPrep>>::buffer_bits(); + ReplicatedRingPrep::buffer_bits(); } template @@ -535,7 +534,7 @@ void MaliciousRingPrep::buffer_personal_edabits(int n_bits, vector& wholes template void buffer_bits_from_players(vector>& player_bits, - vector& G, SubProcessor& proc, int base_player, + PRNG& G, SubProcessor& proc, int base_player, int buffer_size, int n_bits) { auto& protocol = proc.protocol; @@ -553,7 +552,7 @@ void buffer_bits_from_players(vector>& player_bits, { typename T::clear tmp; for (int j = 0; j < n_bits; j++) - tmp += typename T::clear(G[j % G.size()].get_bit()) << j; + tmp += typename T::clear(G.get_bit()) << j; input.add_mine(tmp, n_bits); } } @@ -565,17 +564,17 @@ void buffer_bits_from_players(vector>& player_bits, for (int i = 0; i < n_relevant_players; i++) for (auto& x : player_bits[i]) x = input.finalize((base_player + i) % P.num_players(), n_bits); -} - -template -void buffer_bits_from_players(vector>& player_bits, PRNG& G, - SubProcessor& proc, int base_player, int buffer_size, - int n_bits = -1) -{ - vector Gs = {G}; - buffer_bits_from_players(player_bits, Gs, proc, base_player, buffer_size, - n_bits); - G = Gs[0]; +#if !defined(__clang__) && (__GNUC__ == 6) + // mitigate compiler bug + Bundle 
bundle(P); + P.unchecked_broadcast(bundle); +#endif +#ifdef DEBUG_BIT_SACRIFICE + typename T::MAC_Check MC; + for (int i = 0; i < n_relevant_players; i++) + for (auto& x : player_bits[i]) + assert((MC.open(x, P) == 0) or (MC.open(x, P) == 1)); +#endif } template @@ -927,35 +926,18 @@ void bits_from_random(vector& bits, typename T::Protocol& protocol) } } -template class T> -void buffer_bits_spec(ReplicatedPrep>& prep, vector>& bits, - typename T::Protocol& prot) +template +template +void ReplicatedPrep::buffer_bits(false_type) { - (void) bits, (void) prot; - prep.ReplicatedRingPrep>::buffer_bits(); -} - -template class T> -void buffer_bits_spec(ReplicatedPrep>& prep, vector>& bits, - typename T::Protocol& prot) -{ - (void) bits, (void) prot; - prep.ReplicatedRingPrep>::buffer_bits(); -} - -template class T, int K> -void buffer_bits_spec(ReplicatedPrep>>& prep, vector>>& bits, - typename T>::Protocol& prot) -{ - (void) bits, (void) prot; - prep.ReplicatedRingPrep>>::buffer_bits(); + ReplicatedRingPrep::buffer_bits(); } template void ReplicatedPrep::buffer_bits() { assert(this->protocol != 0); - buffer_bits_spec(*this, this->bits, *this->protocol); + buffer_bits<0>(T::clear::prime_field); } template diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index 20ecd268..a9df48b4 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -37,6 +37,8 @@ public: typedef GC::SemiSecret bit_type; + static const bool has_split = true; + Semi2kShare() { } diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index e3a6e288..e9336c75 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -15,6 +15,7 @@ template class SubProcessor; template class ShamirMC; template class ShamirShare; template class ShamirInput; +template class IndirectShamirMC; class Player; @@ -31,7 +32,7 @@ class Shamir : public ProtocolBase SeededPRNG secure_prng; - vector> hyper; + map>> hypers; typename T::open_type dotprod_share; @@ -48,7 +49,7 @@ public: static U 
get_rec_factor(int i, int n); static U get_rec_factor(int i, int n_total, int start, int threshold); - Shamir(Player& P); + Shamir(Player& P, int threshold = 0); ~Shamir(); Shamir branch(); @@ -85,6 +86,12 @@ public: void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); + + vector get_randoms(PRNG& G, int t); + + vector>& get_hyper(int t); + static void get_hyper(vector>& hyper, int t, int n); + static string hyper_filename(int t, int n); }; #endif /* PROTOCOLS_SHAMIR_H_ */ diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 0127be50..10010746 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -8,6 +8,7 @@ #include "Shamir.h" #include "ShamirInput.h" +#include "ShamirShare.h" #include "Machines/ShamirMachine.h" #include "Tools/benchmarking.h" @@ -32,12 +33,15 @@ typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n_total, } template -Shamir::Shamir(Player& P) : +Shamir::Shamir(Player& P, int t) : resharing(0), random_input(0), P(P) { if (not P.is_encrypted()) insecure("unencrypted communication"); - threshold = ShamirMachine::s().threshold; + if (t > 0) + threshold = t; + else + threshold = ShamirMachine::s().threshold; n_mul_players = 2 * threshold + 1; } @@ -59,7 +63,7 @@ Shamir Shamir::branch() template int Shamir::get_n_relevant_players() { - return ShamirMachine::s().threshold + 1; + return threshold + 1; } template @@ -182,32 +186,92 @@ T Shamir::finalize_dotprod(int) template void Shamir::buffer_random() { - if (hyper.empty()) + this->random = get_randoms(secure_prng, threshold); +} + +template +vector>& Shamir::get_hyper(int t) +{ + auto& hyper = hypers[t]; + if (int(hyper.size()) != P.num_players() - t) { - int n = P.num_players(); - for (int i = 0; i < n - threshold; i++) - { - hyper.push_back({}); - for (int j = 0; j < n; j++) - { - hyper.back().push_back({1}); - for (int k = 0; k < n; k++) - if (k != j) - hyper.back().back() *= U(n + i - k) / U(j - k); - } - } + 
get_hyper(hyper, t, P.num_players()); + } + return hyper; +} + +template +string Shamir::hyper_filename(int t, int n) +{ + return PREP_DIR "/Hyper-" + to_string(t) + "-" + to_string(n) + "-" + + to_string(T::clear::pr()); +} + +template +void Shamir::get_hyper(vector >& hyper, + int t, int n) +{ + assert(hyper.empty()); + + try + { + octetStream os; + string filename = hyper_filename(t, n); + ifstream in(filename); +#ifdef VERBOSE + cerr << "Trying to load hyper-invertable matrix from " << filename << endl; +#endif + os.input(in); + os.get(hyper); + if (int(hyper.size()) != n - t) + throw exception(); +#ifdef VERBOSE + cerr << "Loaded hyper-invertable matrix from " << filename << endl; +#endif + return; + } + catch (...) + { +#ifdef VERBOSE + cerr << "Failed to load hyper-invertable" << endl; +#endif } + map inverses, dividends; + for (int i = -n; i < n; i++) + if (i != 0) + inverses[i] = U(i).invert(); + for (int i = 0; i < 2 * n; i++) + dividends[i] = i; + for (int i = 0; i < n - t; i++) + { + hyper.push_back({}); + for (int j = 0; j < n; j++) + { + hyper.back().push_back({1}); + for (int k = 0; k < n; k++) + if (k != j) + hyper.back().back() *= dividends.at(n + i - k) + * inverses.at(j - k); + } + } +} + +template +vector Shamir::get_randoms(PRNG& G, int t) +{ + auto& hyper = get_hyper(t); if (random_input == 0) - random_input = new ShamirInput(0, P); + random_input = new ShamirInput(0, P, threshold); auto& input = *random_input; input.reset_all(P); int buffer_size = OnlineOptions::singleton.batch_size; for (int i = 0; i < buffer_size; i += hyper.size()) - input.add_mine(secure_prng.get()); + input.add_mine(G.get()); input.exchange(); vector inputs; - auto& random = this->random; + vector random; + random.reserve(buffer_size + hyper.size()); for (int i = 0; i < buffer_size; i += hyper.size()) { inputs.clear(); @@ -220,6 +284,7 @@ void Shamir::buffer_random() random.back() += hyper[j][k] * inputs[k]; } } + return random; } #endif diff --git 
a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 4567ae11..5958efc6 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -9,6 +9,7 @@ #include "Processor/Input.h" #include "Shamir.h" #include "ReplicatedInput.h" +#include "Machines/ShamirMachine.h" template class IndividualInput : public PrepLessInput @@ -40,29 +41,36 @@ class ShamirInput : public IndividualInput { friend class Shamir; - static vector> vandermonde; + vector> vandermonde; SeededPRNG secure_prng; vector randomness; + int threshold; + public: - static const vector>& get_vandermonde(size_t t, + static vector> get_vandermonde(size_t t, size_t n); - ShamirInput(SubProcessor& proc, ShamirMC& MC) : - IndividualInput(proc) + ShamirInput(SubProcessor& proc, typename T::MAC_Check& MC) : + ShamirInput(&proc, proc.P) { (void) MC; } - ShamirInput(SubProcessor* proc, Player& P) : + ShamirInput(SubProcessor* proc, Player& P, int t = 0) : IndividualInput(proc, P) { + if (t > 0) + threshold = t; + else + threshold = ShamirMachine::s().threshold; + } ShamirInput(ShamirMC&, Preprocessing&, Player& P) : - IndividualInput(0, P) + ShamirInput(0, P) { } diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index b4421b19..d84b09a6 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -11,9 +11,6 @@ #include "Protocols/ReplicatedInput.hpp" -template -vector> ShamirInput::vandermonde; - template void IndividualInput::reset(int player) { @@ -26,11 +23,10 @@ void IndividualInput::reset(int player) } template -const vector>& ShamirInput::get_vandermonde( +vector> ShamirInput::get_vandermonde( size_t t, size_t n) { - if (vandermonde.size() < n) - vandermonde.resize(n); + vector> vandermonde(n); for (int i = 0; i < int(n); i++) if (vandermonde[i].size() < t) @@ -53,8 +49,10 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) (void) n_bits; auto& P = this->P; int n = P.num_players(); - int t = ShamirMachine::s().threshold; - const auto& 
vandermonde = get_vandermonde(t, n); + int t = threshold; + + if (vandermonde.empty()) + vandermonde = get_vandermonde(t, n); randomness.resize(t); for (auto& x : randomness) diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 8ac858a1..ccd370f2 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -12,8 +12,22 @@ #include "Tools/Bundle.h" template -class ShamirMC : public MAC_Check_Base +class IndirectShamirMC : public MAC_Check_Base { + vector oss; + octetStream os; + +public: + IndirectShamirMC(typename T::mac_key_type = {}, int = 0, int = 0) {} + ~IndirectShamirMC() {} + + virtual void exchange(const Player& P); +}; + +template +class ShamirMC : public IndirectShamirMC +{ + typedef typename T::open_type::Scalar rec_type; vector reconstruction; void finalize(vector& values, const vector& S); @@ -26,7 +40,7 @@ protected: void prepare(const vector& S, const Player& P); public: - ShamirMC() : os(0), player(0), threshold(ShamirMachine::s().threshold) {} + ShamirMC(int threshold = 0); // emulate MAC_Check ShamirMC(const typename T::mac_key_type& _, int __ = 0, int ___ = 0) : ShamirMC() @@ -49,6 +63,8 @@ public: virtual typename T::open_type finalize_open(); void Check(const Player& P) { (void)P; } + + vector get_reconstruction(const Player& P); }; #endif /* PROTOCOLS_SHAMIRMC_H_ */ diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 4e1d241d..6d6af913 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -8,6 +8,16 @@ #include "ShamirMC.h" +template +ShamirMC::ShamirMC(int t) : + os(0), player(0), threshold() +{ + if (t > 0) + threshold = t; + else + threshold = ShamirMachine::s().threshold; +} + template ShamirMC::~ShamirMC() { @@ -24,16 +34,24 @@ void ShamirMC::POpen_Begin(vector& values, P.send_all(os->mine); } +template +vector ShamirMC::get_reconstruction( + const Player& P) +{ + int n_relevant_players = threshold + 1; + vector reconstruction(n_relevant_players); + for (int i = 0; i < n_relevant_players; i++) + 
reconstruction[i] = Shamir::get_rec_factor(P.get_player(i), + P.num_players(), P.my_num(), n_relevant_players); + return reconstruction; +} + template void ShamirMC::init_open(const Player& P, int n) { - int n_relevant_players = ShamirMachine::s().threshold + 1; if (reconstruction.empty()) { - reconstruction.resize(n_relevant_players); - for (int i = 0; i < n_relevant_players; i++) - reconstruction[i] = Shamir::get_rec_factor(P.get_player(i), - P.num_players(), P.my_num(), n_relevant_players); + reconstruction = get_reconstruction(P); } if (not os) @@ -112,4 +130,42 @@ typename T::open_type ShamirMC::finalize_open() return res; } +template +void IndirectShamirMC::exchange(const Player& P) +{ + oss.resize(P.num_players()); + int threshold = ShamirMachine::s().threshold; + if (P.my_num() <= threshold) + { + oss[0].reset_write_head(); + auto rec_factor = Shamir::get_rec_factor(P.my_num(), threshold + 1); + for (auto& x : this->secrets) + (x * rec_factor).pack(oss[0]); + vector> channels(P.num_players(), + vector(P.num_players())); + for (int i = 0; i <= threshold; i++) + channels[i][0] = true; + P.send_receive_all(channels, oss, oss); + } + + if (P.my_num() == 0) + { + os.reset_write_head(); + while (oss[0].left()) + { + T sum; + for (int i = 0; i <= threshold; i++) + sum += oss[i].template get(); + sum.pack(os); + } + P.send_all(os); + } + + if (P.my_num() != 0) + P.receive_player(0, os); + + while (os.left()) + this->values.push_back(os.get()); +} + #endif diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index 921e4888..6e818c39 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -31,8 +31,8 @@ public: typedef GC::NoShare mac_share_type; typedef Shamir Protocol; - typedef ShamirMC MAC_Check; - typedef MAC_Check Direct_MC; + typedef IndirectShamirMC MAC_Check; + typedef ShamirMC Direct_MC; typedef ShamirInput Input; typedef ::PrivateOutput PrivateOutput; typedef ReplicatedPrep LivePrep; @@ -40,7 +40,7 @@ public: typedef ShamirShare 
Honest; #ifndef NO_MIXED_CIRCUITS - typedef GC::CcdSecret bit_type; + typedef GC::CcdSecret> bit_type; #endif const static bool needs_ot = false; diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index d4f3adaa..f9e2d5f2 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -32,6 +32,9 @@ public: static const bool expensive = false; static const bool expensive_triples = false; + static const bool has_trunc_pr = false; + static const bool has_split = false; + static const int default_length = 1; static string type_short() { return "undef"; } diff --git a/Protocols/ShareVector.hpp b/Protocols/ShareVector.hpp index b9a4a9e2..74eaf538 100644 --- a/Protocols/ShareVector.hpp +++ b/Protocols/ShareVector.hpp @@ -12,8 +12,8 @@ void ShareVector::fft(const FFT_Data& fftd) array, 2> data; for (auto& share : *this) { - data[0].push_back(share.get_share()); - data[1].push_back(share.get_mac()); + data[0].push_back({share.get_share(), fftd.get_prD()}); + data[1].push_back({share.get_mac(), fftd.get_prD()}); } for (auto& x : data) @@ -26,6 +26,7 @@ void ShareVector::fft(const FFT_Data& fftd) for (int i = 0; i < fftd.phi_m(); i++) { - (*this)[i] = {data[0][i], data[1][i]}; + typedef typename U::clear clear; + (*this)[i] = {clear(data[0][i], fftd.get_prD()), clear(data[1][i], fftd.get_prD())}; } } diff --git a/Protocols/SpdzWiseRingShare.h b/Protocols/SpdzWiseRingShare.h index bda30f86..476b47fb 100644 --- a/Protocols/SpdzWiseRingShare.h +++ b/Protocols/SpdzWiseRingShare.h @@ -38,6 +38,8 @@ public: static const int LENGTH = K; static const int SECURITY = S; + static const bool has_split = true; + SpdzWiseRingShare() { } diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 76f28f36..89b029b7 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -174,12 +174,24 @@ void make_share(FixedVec* Sa, const V& a, int N, const U& key, PRNG& G) } } +template +class VanderStore +{ +public: + static vector> vandermonde; 
+}; + +template +vector> VanderStore::vandermonde; + template void make_share(ShamirShare* Sa, const V& a, int N, const typename ShamirShare::mac_type&, PRNG& G) { insecure("share generation", false); - const auto& vandermonde = ShamirInput>::get_vandermonde(N / 2, N); + auto& vandermonde = VanderStore::vandermonde; + if (vandermonde.empty()) + vandermonde = ShamirInput>::get_vandermonde(N / 2, N); vector randomness(N / 2); for (auto& x : randomness) x.randomize(G); @@ -188,7 +200,7 @@ void make_share(ShamirShare* Sa, const V& a, int N, auto& share = Sa[i]; share = a; for (int j = 0; j < ShamirOptions::singleton.threshold; j++) - share += vandermonde[i][j] * randomness[j]; + share += vandermonde.at(i).at(j) * randomness[j]; } } diff --git a/README.md b/README.md index abd2590e..e67304a5 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,10 @@ sharing (with an honest majority). [Filing an issue on GitHub](../../issues) is the preferred way of contacting us, but you can also write an email to mp-spdz@googlegroups.com -([archive](https://groups.google.com/forum/#!forum/mp-spdz)). +([archive](https://groups.google.com/forum/#!forum/mp-spdz)). Before +reporting a problem, please check against the list of [known +issues and possible +solutions](https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html). #### Frequently Asked Questions @@ -84,7 +87,7 @@ The following table lists all protocols that are fully supported. 
| Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | -| Semi-honest, honest majority | [Shamir / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +| Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | See [this paper](https://eprint.iacr.org/2020/300) for an explanation of the various security models and high-level introduction to @@ -165,12 +168,14 @@ The design of MP-SPDZ is described in [this paper](https://eprint.iacr.org/2020/521). If you use it for an academic project, please cite: ``` -@misc{mp-spdz, +@inproceedings{mp-spdz, author = {Marcel Keller}, title = {{MP-SPDZ}: A Versatile Framework for Multi-Party Computation}, - howpublished = {Cryptology ePrint Archive, Report 2020/521}, + booktitle = {Proceedings of the 2020 ACM SIGSAC Conference on + Computer and Communications Security}, year = {2020}, - note = {\url{https://eprint.iacr.org/2020/521}}, + doi = {10.1145/3372297.3417872}, + url = {https://doi.org/10.1145/3372297.3417872}, } ``` @@ -256,7 +261,7 @@ compute the preprocessing time for a particular computation. add `AVX_OT = 0` in addition. - To benchmark online-only protocols or Overdrive offline phases, add the following line at the top: `MY_CFLAGS = -DINSECURE` - `PREP_DIR` should point to a local, unversioned directory to store preprocessing data (the default is `Player-Data` in the current directory). - - For homomorphic encryption, set `USE_NTL = 1`. 
+ - For homomorphic encryption with GF(2^40), set `USE_NTL = 1`. 2) Run `make` to compile all the software (use the flag `-j` for faster compilation using multiple threads). See below on how to compile specific @@ -381,15 +386,15 @@ variant by Mohassel and Rindal (available in Rep3 only). ##### Finding the most efficient variant Where available, local share conversion is likely the most efficient -variant. Protocols based on Shamir secret sharing are unlikely to -benefit from mixed-circuit computation because they use an extension -field for binary computation. Otherwise, edaBits likely offer an -asymptotic benefit. However, malicious protocols by default generate -large batches of edaBits (more than one million at once), which is -only worthwhile for accordingly large computation. For smaller -computation, try running the virtual machines with `-B 4` or `-B 5`, -which reduces the batch size to ~10,000 and ~1,000, respectively, at a -higher asymptotic cost. +variant. Otherwise, edaBits likely offer an asymptotic benefit. When +using edaBits with malicious protocols, there is a trade-off between +cost per item and batch size. The lowest cost per item requires large +batches of edaBits (more than one million at once), which is only +worthwhile for accordingly large computation. This setting can be +selected by running the virtual machine with `-B 3`. For smaller +computation, try `-B 4` or `-B 5`, which set the batch size to ~10,000 +and ~1,000, respectively, at a higher asymptotic cost. `-B 4` is the +default. 
#### Bristol Fashion circuits @@ -622,6 +627,7 @@ The following table shows all programs for honest-majority computation: | `ps-rep-field-party.x` | Replicated | Mod prime | Y | 3 | `ps-rep-field.sh` | | `sy-rep-field-party.x` | SPDZ-wise replicated | Mod prime | Y | 3 | `sy-rep-field.sh` | | `malicious-rep-field-party.x` | Replicated | Mod prime | Y | 3 | `mal-rep-field.sh` | +| `atlas-party.x` | [ATLAS](https://eprint.iacr.org/2021/833) | Mod prime | N | 3 or more | `atlas.sh` | | `shamir-party.x` | Shamir | Mod prime | N | 3 or more | `shamir.sh` | | `malicious-shamir-party.x` | Shamir | Mod prime | Y | 3 or more | `mal-shamir.sh` | | `sy-shamir-party.x` | SPDZ-wise Shamir | Mod prime | Y | 3 or more | `sy-shamir.sh` | diff --git a/Scripts/atlas.sh b/Scripts/atlas.sh new file mode 100755 index 00000000..7ec9b483 --- /dev/null +++ b/Scripts/atlas.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +export PLAYERS=${PLAYERS:-3} + +if test "$THRESHOLD"; then + t="-T $THRESHOLD" +fi + +. 
$HERE/run-common.sh + +run_player atlas-party.x $* $t || exit 1 diff --git a/Scripts/build.sh b/Scripts/build.sh index b8c32447..4ade5c7f 100755 --- a/Scripts/build.sh +++ b/Scripts/build.sh @@ -15,5 +15,8 @@ function build strip $dest/* } +echo AVX_OT = 0 >> CONFIG.mine build '-maes -mpclmul -DCHECK_AES -DCHECK_PCLMUL -DCHECK_AVX' amd64 + +echo AVX_OT = 1 >> CONFIG.mine build '-msse4.1 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx -DCHECK_ADX' avx2 diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 95e11499..e74bad5c 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -24,7 +24,7 @@ lldb_screen() } run_player() { - port=$((RANDOM%10000+10000)) + port=${PORT:-$((RANDOM%10000+10000))} bin=$1 shift prog=$1 diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 7eb9acc0..10fe575f 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -40,6 +40,8 @@ function test_vm # big buckets for smallest batches run_opts="$run_opts -B 5" +export PORT=$((RANDOM%10000+10000)) + for dabit in ${dabit:-0 1 2}; do if [[ $dabit = 1 ]]; then compile_opts="$compile_opts -X" @@ -57,7 +59,7 @@ for dabit in ${dabit:-0 1 2}; do ./compile.py $compile_opts tutorial for i in rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ - mal-shamir sy-shamir hemi semi \ + atlas mal-shamir sy-shamir hemi semi \ soho mascot; do test_vm $i $run_opts done @@ -81,7 +83,7 @@ fi ./compile.py tutorial for i in cowgear chaigear; do - test_vm $i $run_opts -l 3 -c 2 -T + test_vm $i $run_opts -l 3 -c 2 -J done if test $skip_binary; then diff --git a/Tools/Buffer.h b/Tools/Buffer.h index 4bc6d7d6..84fbdaca 100644 --- a/Tools/Buffer.h +++ b/Tools/Buffer.h @@ -8,6 +8,7 @@ #include #include +#include using namespace std; #include "Math/field_types.h" @@ -73,6 +74,12 @@ public: { } + BufferOwner(const BufferOwner& other) : + file(0) + { + assert(other.file == 0); + } + ~BufferOwner() { close(); diff --git a/Tools/Bundle.h b/Tools/Bundle.h index 
e4eff1ac..ed4b982e 100644 --- a/Tools/Bundle.h +++ b/Tools/Bundle.h @@ -38,6 +38,12 @@ public: if (os != mine) throw mismatch_among_parties(); } + + void reset() + { + for (auto& x : *this) + x.reset_write_head(); + } }; #endif /* TOOLS_BUNDLE_H_ */ diff --git a/Tools/MMO.h b/Tools/MMO.h index c2dc9e3f..1c2e0a7a 100644 --- a/Tools/MMO.h +++ b/Tools/MMO.h @@ -42,8 +42,12 @@ public: void hashBlocks(void* output, const void* input); template void hashEightBlocks(T* output, const void* input); + template + void hashEightBlocks(T* output, const void* input); template void hashEightBlocks(gfp_* output, const void* input); + template + void hashEightBlocks(gfpvar_* output, const void* input); template void outputOneBlock(octet* output); Key hash(const Key& input); diff --git a/Tools/MMO.hpp b/Tools/MMO.hpp index d0f43e45..4309e1fe 100644 --- a/Tools/MMO.hpp +++ b/Tools/MMO.hpp @@ -5,6 +5,7 @@ */ #include "MMO.h" +#include "Math/gfp.hpp" #include @@ -70,9 +71,15 @@ void MMO::hashBlocks(void* output, const void* input) template void MMO::hashEightBlocks(gfp_* output, const void* input) { - gfp_* out = (gfp_*)output; + hashEightBlocks, gfp_::N_BYTES>(output, input); +} + +template +void MMO::hashEightBlocks(T* output, const void* input) +{ + T* out = output; const int block_size = sizeof(__m128i); - const int n_blocks = (gfp_::N_BYTES + block_size - 1) / block_size; + const int n_blocks = (N_BYTES + block_size - 1) / block_size; __m128i tmp[8][n_blocks]; hashBlocks<8, n_blocks * block_size>(tmp, input, n_blocks * block_size); int left = 8; @@ -82,10 +89,10 @@ void MMO::hashEightBlocks(gfp_* output, const void* input) int now_left = 0; for (int j = 0; j < left; j++) { - memcpy(out[indices[j]].get_ptr(), &tmp[indices[j]], gfp_::N_BYTES); + memcpy(out[indices[j]].get_ptr(), &tmp[indices[j]], N_BYTES); out[indices[j]].zero_overhang(); if (mpn_cmp((mp_limb_t*) out[indices[j]].get_ptr(), - gfp_::get_ZpD().get_prA(), gfp_::t()) >= 0) + T::get_ZpD().get_prA(), T::get_ZpD().get_t()) 
>= 0) { indices[now_left] = indices[j]; now_left++; @@ -116,3 +123,19 @@ void MMO::hashEightBlocks(__m128i* output, const void* input) { hashBlocks<8, 16>(output, input, 16); } + +template +void MMO::hashEightBlocks(gfpvar_* output, const void* input) +{ +#define X(N_LIMBS) \ + case N_LIMBS: \ + hashEightBlocks, N_LIMBS * 8>(output, input); \ + break; + switch(gfpvar_::get_ZpD().get_t()) + { + X(2) X(3) X(4) X(5) X(6) X(7) X(8) X(9) X(10) X(11) X(12) + default: + throw runtime_error("MMO not implemented"); + } +#undef X +} diff --git a/Tools/random.cpp b/Tools/random.cpp index bcd3ba4e..7a0cd1da 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -40,6 +40,28 @@ void PRNG::SeedGlobally(const PlayerBase& P) SetSeed(seed); } +void PRNG::SeedGlobally(const Player& P, bool secure) +{ + if (secure) + SeedGlobally(static_cast(P)); + else + { + octetStream os; + if (P.my_num() == 0) + { + randombytes_buf(seed, SEED_SIZE); + os.append(seed, SEED_SIZE); + P.send_all(os); + } + else + { + P.receive_player(0, os); + os.consume(seed, SEED_SIZE); + } + InitSeed(); + } +} + void PRNG::SetSeed(const octet* inp) { memcpy(seed,inp,SEED_SIZE*sizeof(octet)); diff --git a/Tools/random.h b/Tools/random.h index 81697fea..d22be6e8 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -77,6 +77,7 @@ class PRNG // Agree securely on seed void SeedGlobally(const PlayerBase& P); + void SeedGlobally(const Player& P, bool secure = true); // Set seed from array void SetSeed(const unsigned char*); @@ -215,4 +216,16 @@ inline void PRNG::randomBnd(mp_limb_t* res, const mp_limb_t* B, mp_limb_t mask) while (mpn_cmp(res, B, n_limbs) >= 0); } +template<> +inline octet PRNG::get() +{ + return get_uchar(); +} + +template<> +inline word PRNG::get() +{ + return get_word(); +} + #endif diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index d0181425..7446c75f 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -583,6 +583,16 @@ int main(int argc, const char** argv) "-T", // Flag 
token. "--threshold" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Deactivate Montgomery representation" + "(default: activated)", // Help description. + "-n", // Flag token. + "--nontgomery" // Flag token. + ); opt.parse(argc, argv); int lgp; @@ -711,13 +721,13 @@ int FakeParams::generate() { string p; opt.get("--prime")->getString(p); - T::clear::init_field(p); + T::clear::init_field(p, not opt.isSet("--nontgomery")); T::clear::template write_setup(nplayers); } else { T::clear::template generate_setup(prep_data_prefix, nplayers, lgp); - T::clear::init_default(lgp); + T::clear::init_default(lgp, not opt.isSet("--nontgomery")); } /* Find number players and MAC keys etc*/ @@ -817,6 +827,7 @@ void FakeParams::generate_field(true_type) if (nplayers > 2) { + ShamirShare::bit_type::clear::init_field(); make_basic>({}, nplayers, default_num, zero); make_basic>({}, nplayers, default_num, zero); make_with_mac_key>>(nplayers, diff --git a/Utils/hyper.cpp b/Utils/hyper.cpp new file mode 100644 index 00000000..80511f37 --- /dev/null +++ b/Utils/hyper.cpp @@ -0,0 +1,22 @@ +/* + * hyper.cpp + * + */ + +#include "Math/gfpvar.h" + +#include "Protocols/Shamir.hpp" + +int main(int argc, char** argv) +{ + assert(argc > 2); + gfpvar::init_field(argv[3]); + vector> hyper; + int t = atoi(argv[1]); + int n = atoi(argv[2]); + Shamir>::get_hyper(hyper, t, n); + octetStream os; + os.store(hyper); + ofstream out(Shamir>::hyper_filename(t, n)); + os.output(out); +} diff --git a/Yao/YaoWire.h b/Yao/YaoWire.h index d726e080..ddaf3b9c 100644 --- a/Yao/YaoWire.h +++ b/Yao/YaoWire.h @@ -15,6 +15,8 @@ protected: Key key_; public: + static int threshold(int) { return 1; } + template static void xors(GC::Processor& processor, const vector& args); template diff --git a/doc/Compiler.rst b/doc/Compiler.rst index f6cdc7a8..78e87176 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -10,18 +10,16 @@ 
Compiler.types module .. automodule:: Compiler.types :members: - :special-members: - :private-members: :no-undoc-members: - :no-inherited-members: - :show-inheritance: + :inherited-members: :exclude-members: intbitint, sgf2nfloat, sgf2nint, sgf2nint32, sgf2nuint, t, unreduced_sfix, sgf2nuint32, MemFix, MemFloat, PreOp, ClientMessageType, __weakref__, __repr__, reg_type, int_type, clear_type, float_type, basic_type, default_type, unreduced_type, bit_type, dynamic_array, squant, mov, - write_share_to_socket, + write_share_to_socket, add, mul, sintbit, from_sint +.. autoclass:: sintbit Compiler.GC.types module ------------------------ @@ -70,7 +68,8 @@ Compiler.ml module .. automodule:: Compiler.ml :members: :no-undoc-members: - :exclude-members: Adam, Tensor + :exclude-members: Tensor + :show-inheritance: .. autofunction:: approx_sigmoid Compiler.circuit module diff --git a/doc/index.rst b/doc/index.rst index e54c1c3b..6d2c510f 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -9,7 +9,7 @@ implemented protocols etc. see https://github.com/data61/MP-SPDZ. Compilation process ------------------- -The easiest way of using MP-SPDZ is using the ``compile.py`` as +The easiest way of using MP-SPDZ is using ``compile.py`` as described below. If you would like to run compilation directly from Python, see ``Scripts/direct_compilation_example.py``. It contains all the necessary setup steps. @@ -155,6 +155,7 @@ Reference Compiler instructions low-level + machine-learning networking io non-linear diff --git a/doc/io.rst b/doc/io.rst index 5d952c28..1dc36418 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -50,6 +50,16 @@ conjunction with :py:func:`~Compiler.library.print_ln_to` in order to output it. +Binary Output +~~~~~~~~~~~~~ + +Most types returned by :py:func:`reveal` or :py:func:`reveal_to` +feature a :py:func:`binary_output` method, which writes to +``Player-Data/Binary-Output-P<playerno>-<threadno>``.
The format is +either signed 64-bit integer or double-precision floating-point in +machine byte order (usually little endian). + + Clients (Non-computing Parties) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -65,7 +75,17 @@ Secret Shares :py:func:`Compiler.types.sint.read_from_file` and :py:func:`Compiler.types.sint.write_to_file` allow reading and writing -secret shares to and from files. +secret shares to and from files. These instructions use +``Persistence/Transactions-P<playerno>.data``. The format depends on +the protocol with the following principles. + +- One share follows the other without metadata. +- If there is a MAC, it comes after the share. +- Numbers are stored in little-endian format. +- Numbers modulo a power of two are stored with the minimal number of + bytes. +- Numbers modulo a prime are stored in Montgomery representation in + blocks of eight bytes. Another possibility for persistence between program runs is to use the fact that the memory is stored in diff --git a/doc/machine-learning.rst b/doc/machine-learning.rst new file mode 100644 index 00000000..f4c03342 --- /dev/null +++ b/doc/machine-learning.rst @@ -0,0 +1,79 @@ +Machine Learning +---------------- + +MP-SPDZ supports a limited subset of the Keras interface for machine +learning. This includes the SGD and Adam optimizers and the following +layer types: dense, 2D convolution, 2D max-pooling, and dropout. + +In the following we will walk through the example code in +``keras_mnist_dense.mpc``, which trains a dense neural network for +MNIST.
It starts by defining tensors to hold data:: + + training_samples = sfix.Tensor([60000, 28, 28]) + training_labels = sint.Tensor([60000, 10]) + + test_samples = sfix.Tensor([10000, 28, 28]) + test_labels = sint.Tensor([10000, 10]) + +The tensors are then filled with inputs from party 0 in the order that +is used by `the preparation script +`_:: + + training_labels.input_from(0) + training_samples.input_from(0) + + test_labels.input_from(0) + test_samples.input_from(0) + +This is followed by Keras-like code setting up the model and training +it:: + + from Compiler import ml + tf = ml + + layers = [ + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') + ] + + model = tf.keras.models.Sequential(layers) + + optim = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=0.01) + + model.compile(optimizer=optim) + + opt = model.fit( + training_samples, + training_labels, + epochs=1, + batch_size=128, + validation_data=(test_samples, test_labels) + ) + +Lastly, the model is stored on disk in secret-shared form:: + + for var in model.trainable_variables: + var.write_to_file() + + +Prediction +~~~~~~~~~~ + +The example code in ``keras_mnist_dense_predict.mpc`` uses the model +stored above for prediction. Much of the setup is the same, but +instead of training it reads the model from disk:: + + model.build(test_samples.sizes) + + start = 0 + for var in model.trainable_variables: + start = var.read_from_file(start) + +Then it runs the prediction:: + + guesses = model.predict(test_samples) + +Using ``var.input_from(player)`` instead the model would be input +privately by a party. diff --git a/doc/networking.rst b/doc/networking.rst index 6062d360..df5f7f28 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -19,7 +19,7 @@ individually setting ports: can specify a party's listening port using ``--my-port``. 2. 
The parties read the information from a local file, which needs to - be same everywhere. The file can be specified using + be the same everywhere. The file can be specified using ``--ip-file-name`` and has the following format:: <host0>[:<port0>] @@ -29,6 +29,16 @@ individually setting ports: The hosts can be both hostnames and IP addresses. If not given, the ports default to base plus party number. +Whether or not encrypted connections are used depends on the security +model of the protocol. Honest-majority protocols default to encrypted +whereas dishonest-majority protocols default to unencrypted. You can +change this by either using ``--encrypted/-e`` or +``--unencrypted/-u``. + +If using encryption, the certificates (``Player-Data/*.pem``) must be +the same on all hosts, and you have to run ``c_rehash Player-Data`` on +all of them. + Internal Infrastructure ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 10b5fb4d..2034dd3a 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -9,7 +9,7 @@ Crash without error message or ``bad_alloc`` Some protocols require several gigabytes of memory, and the virtual machine will crash if there is not enough RAM. You can reduce the -memory usage for some malicious protocols with ``-B 4``. The memory +memory usage for some malicious protocols with ``-B 5``. The memory usage for malicious protocols based on homomorphic encryption can also be reduced by using ``-T``. Finally, every computation thread requires separate resources, so consider reducing the number of threads with @@ -27,7 +27,7 @@ lists only exists at compile time. Consider using ``compile.py`` takes too long ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If you Python loops (``for``), the are unrolled at compile-time, +If you use Python loops (``for``), they are unrolled at compile-time, resulting in potentially too much virtual machine code. Consider using :py:func:`~Compiler.library.for_range` or similar.
@@ -61,16 +61,35 @@ Handshake failures If you run on different hosts, the certificates (``Player-Data/*.pem``) must be the same on all of them. Also make sure to run ``c_rehash Player-Data`` on all hosts. Finally, the -certificate generated by ``Scripts/setup-ssl.sh`` expire after a +certificates generated by ``Scripts/setup-ssl.sh`` expire after a month, so you might to regenerate them. +Connection failures +~~~~~~~~~~~~~~~~~~~ + +MP-SPDZ requires at least one TCP port per party to be open to other +parties. In the default setting, it's 4999 and 5000 on party 0, and +5001 on party 1 etc. You can change the base port (5000) using +``--portnumbase`` and individual ports for parties using +``--my-port``. + + +Internally called tape has unknown offline data usage +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Certain computations are not compatible with reading preprocessing +from disk. You can compile the binaries with ``MY_CFLAGS += +-DINSECURE`` in ``CONFIG.mine`` in order to execute the computation in +a way that reuses preprocessing. + + Not compiled for choice of parameters ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ HighGear and LowGear only support a limited choice of parameters -because they to be chosen when compiling the binaries. You can follow -the instructions in error message and recompile the binaries in order +because they need to be chosen when compiling the binaries. You can follow +the instructions in the error message and recompile the binaries in order to fix this. @@ -84,6 +103,15 @@ processor without AVX (produced before 2011), you need to set ``AVX_OT = 0`` to run dishonest-majority protocols. +Invalid instruction +~~~~~~~~~~~~~~~~~~~ + +The compiler code and the virtual machine binary have to be from the +same version because most versions slightly change the bytecode. This +means you can only use the precompiled binaries with the Python code in +the same release.
+ + Computation used more preprocessing than expected ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~