diff --git a/CHANGELOG.md b/CHANGELOG.md index 9395eb12..3c1b62bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.4.2 (Dec 24, 2025) + +- Expected communication cost in compiler +- Semi-honest option of Rep4 +- Reduced communication for preprocessing in Dealer protocol +- Option of choosing SoftSpoken parameter at run-time +- BERT functionality (@hiddely) +- Recommended reading list in documentation + ## 0.4.1 (May 30, 2025) - Add protocols with function-dependent preprocessing (https://eprint.iacr.org/2025/919) diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index ecc0e8f9..64fea47d 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -618,6 +618,10 @@ class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable): code = opcodes['REVEAL'] arg_format = tools.cycle(['int','cbw','sb']) + def add_usage(self, req_node): + req_node.increment(('bit', 'open'), sum( + int(math.ceil(x / 64)) * 8 for x in self.args[0::3])) + class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): """ Copy private input to secret bit register vectors. 
The input is read as floating-point number, multiplied by a power of two, and then diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 5a3018b4..9d91ad6e 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -24,9 +24,14 @@ from functools import reduce class _binary: def __or__(self, other): return self ^ other ^ (self & other) + __ror__ = __or__ def reveal_to(self, *args, **kwargs): raise CompilerError( '%s does not support revealing to individual players' % type(self)) + @staticmethod + def direct_matrix_mul(*args, **kwargs): + raise AttributeError('direct matrix multiplication only supported ' + 'in arithmetic circuits') class bits(Tape.Register, _structure, _bit, _binary): n = 40 @@ -432,7 +437,7 @@ class cbits(bits): inst.convcbitvec(self.n, res, self) return res -class sbits(bits): +class sbits(bits, Tape._no_secret_truth): """ Secret bits register. This type supports basic bit-wise operations:: @@ -697,7 +702,7 @@ class sbits(bits): def output(self): inst.print_reg_plainsb(self) -class sbitvec(_vec, _bit, _binary): +class sbitvec(Tape._no_secret_truth, _vec, _bit, _binary): """ Vector of registers of secret bits, effectively a matrix of secret bits. This facilitates parallel arithmetic operations in binary circuits. Container types are not supported, use :py:obj:`sbitvec.get_type` for that. 
@@ -907,35 +912,27 @@ class sbitvec(_vec, _bit, _binary): def __init__(self, elements=None, length=None, input_length=None): if length: assert isinstance(elements, sint) - if Program.prog.use_split(): - x = elements.split_to_two_summands(length) - v = sbitint.bit_adder(x[0], x[1]) - else: - prog = Program.prog - if not prog.options.ring: - # force the use of edaBits - backup = prog.use_edabit() - prog.use_edabit(True) - self.v = prog.non_linear.bit_dec( - elements, max(length, input_length or prog.bit_length), - length, maybe_mixed=True) - assert isinstance(self.v[0], sbits) - prog.use_edabit(backup) - return - comparison.require_ring_size(length, 'A2B conversion') - l = int(Program.prog.options.ring) - r, r_bits = sint.get_edabit(length, size=elements.size) - c = ((elements - r) << (l - length)).reveal() - c >>= l - length - cb = [(c >> i) for i in range(length)] - x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb) - v = x.v - self.v = v[:length] + prog = Program.prog + backup = prog.use_edabit() + if not prog.have_a2b(): + # force the use of edaBits + prog.use_edabit(True) + self.v = prog.non_linear.bit_dec( + elements, max(length, input_length or prog.bit_length), + length, maybe_mixed=True) + assert isinstance(self.v[0], sbits) + prog.use_edabit(backup) elif isinstance(elements, sbitvec): self.v = elements.v + elif isinstance(elements, (list, tuple)) and \ + isinstance(elements[0], sbitvec): + self.v = sbitvec(sum((x.elements() for x in elements), [])).v elif elements is not None and not (util.is_constant(elements) and \ elements == 0): self.v = sbits.trans(elements) + def __str__(self): + return 'sbitvec(%s/%s)' % (len(self.v), self.size) + __repr__ = __str__ def popcnt(self): """ Population count / Hamming weight. 
@@ -961,7 +958,7 @@ class sbitvec(_vec, _bit, _binary): return self.from_vec(x ^ y for x, y in zip(*self.expand(other))) def __and__(self, other): return self.from_vec(x & y for x, y in zip(*self.expand(other))) - __rxor__ = __xor__ + __add__ = __radd__ = __sub__ = __rsub__ =__rxor__ = __xor__ __rand__ = __and__ def __invert__(self): return self.from_vec(~x for x in self.v) @@ -969,10 +966,6 @@ class sbitvec(_vec, _bit, _binary): return util.if_else(self.v[0], x, y) def __iter__(self): return iter(self.elements()) - def __len__(self): - return len(self.v) - def __getitem__(self, index): - return self.v[index] @classmethod def conv(cls, other): if isinstance(other, cls): @@ -999,8 +992,6 @@ class sbitvec(_vec, _bit, _binary): return util.untuplify([x.reveal() for x in self.elements()]) def long_one(self): return [x.long_one() for x in self.v] - def __rsub__(self, other): - return self.from_vec(y - x for x, y in zip(self.v, other)) def half_adder(self, other): other = self.coerce(other) res = zip(*(x.half_adder(y) for x, y in zip(self.v, other))) @@ -1014,7 +1005,7 @@ class sbitvec(_vec, _bit, _binary): elif len(self.v) == 1: self, other = other, self.v[0] else: - raise CompilerError('no operand of lenght 1: %d/%d', + raise CompilerError('no operand of length 1: %d/%d', (len(self.v), len(other.v))) if not isinstance(other, sbits): return NotImplemented @@ -1036,8 +1027,6 @@ class sbitvec(_vec, _bit, _binary): i += 1 return sbitvec.from_vec(res) __rmul__ = __mul__ - def __add__(self, other): - return self.from_vec(x + y for x, y in zip(self.v, other)) def bit_and(self, other): return self & other def bit_xor(self, other): @@ -1060,7 +1049,12 @@ class sbitvec(_vec, _bit, _binary): @classmethod def comp_result(cls, x): return cls.get_type(1).from_vec([x]) - def expand(self, other, expand=True): + @staticmethod + def reverse_type(other): + return isinstance(other, sbitfixvec) + equal = __eq__ = _bitint.__eq__ + eqz = staticmethod(_bitint.eqz) + def expand(self, other, 
expand=True, copy=False): assert not isinstance(other, sbitfixvec) m = 1 for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []): @@ -1076,7 +1070,10 @@ class sbitvec(_vec, _bit, _binary): res.append([x * sbits.get_type(m)().long_one() for x in util.bit_decompose(y, len(self.v))]) else: - v = [type(x)(x) if isinstance(x, bits) else x for x in y.v] + if copy: + v = [type(x)(x) if isinstance(x, bits) else x for x in y.v] + else: + v = y.v res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in v]) return res @@ -1364,7 +1361,21 @@ class sbitint(_bitint, _number, sbits, _sbitintbase): class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ - Vector of signed integers for parallel binary computation. + Values and vectors of signed integers for parallel binary computation:: + + si32 = sbitintvec.get_type(32) + print_ln('add: %s', (si32(5) + si32(3)).reveal()) + print_ln('sub: %s', (si32(5) - si32(3)).reveal()) + print_ln('mul: %s', (si32(5) * si32(3)).reveal()) + print_ln('lt: %s', (si32(5) < si32(3)).reveal()) + + This should output:: + + add: 8 + sub: 2 + mul: 15 + lt: 0 + The following example uses vectors of size two:: sb32 = sbits.get_type(32) @@ -1389,7 +1400,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ bit_extend = staticmethod(_complement_two_extend) - mul_functions = {} + functions = {} @classmethod def popcnt_bits(cls, bits): return sbitvec.from_vec(bits).popcnt() @@ -1406,8 +1417,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): if len(a) == 1: res = _bitint.bit_adder(a, b, get_carry=True) return self.get_type(32).from_vec(res, signed=False) - v = sbitint.bit_adder(a, b) - return self.get_type(len(v)).from_vec(v) + return self.maybe_function(self.binary_add, a, b) __radd__ = __add__ __sub__ = _bitint.__sub__ def __rsub__(self, other): @@ -1424,9 +1434,12 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): elif isinstance(other, sbitfixvec): return NotImplemented 
try: - my_bits, other_bits = self.expand(other, False) + my_bits, other_bits = self.expand(other, False, copy=True) except: return NotImplemented + return self.maybe_function(self.binary_mul, my_bits, other_bits) + @classmethod + def maybe_function(cls, call, my_bits, other_bits, result_length=None): m = float('inf') uniform = True for x in itertools.chain(my_bits, other_bits): @@ -1437,21 +1450,26 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): pass if uniform and Program.prog.options.cisc: bl = len(my_bits) - key = bl, len(other_bits) - if key not in self.mul_functions: + ol = result_length or bl + key = call.__name__, ol, bl, len(other_bits) + if key not in cls.functions: def instruction(*args): - res = self.binary_mul(args[bl:2 * bl], args[2 * bl:], - args[0].n) + res = call(args[ol:ol + bl], args[ol + bl:], args[0].n) for x, y in zip(sbitvec.from_vec(res).v, args): x.mov(y, x) - instruction.__name__ = 'binary_mul%sx%s' % (bl, len(other_bits)) - self.mul_functions[key] = instructions_base.cisc(instruction, - bl) - res = [sbits.get_type(m)() for i in range(bl)] - self.mul_functions[key](*(res + my_bits + other_bits)) - return self.from_vec(res) + instruction.__name__ = '%s%sx%s' % (call.__name__, bl, len(other_bits)) + cls.functions[key] = instructions_base.cisc(instruction, ol) + res = [sbits.get_type(m)() for i in range(ol)] + cls.functions[key](*(res + my_bits + other_bits)) + if result_length: + return res + else: + return cls.from_vec(res) else: - return self.binary_mul(my_bits, other_bits, m) + return call(my_bits, other_bits, m) + @classmethod + def binary_add(cls, a, b, m): + return cls.from_vec(sbitint.bit_adder(a, b)) @classmethod def binary_mul(cls, my_bits, other_bits, m): matrix = [] @@ -1468,21 +1486,21 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): def TruncMul(self, other, k, m, kappa=None, nearest=False): if nearest: raise CompilerError('round to nearest not implemented') - if not isinstance(other, sbitintvec): - 
other = sbitintvec(other) + if isinstance(other, int): + b = other + else: + if not isinstance(other, sbitintvec): + other = sbitintvec(other) + b = self.get_type(k).from_vec(_complement_two_extend(other.v, k)) a = self.get_type(k).from_vec(_complement_two_extend(self.v, k)) - b = self.get_type(k).from_vec(_complement_two_extend(other.v, k)) tmp = a * b assert len(tmp.v) == k - return self.get_type(k - m).from_vec(tmp[m:]) + return self.get_type(k - m).from_vec(tmp.v[m:]) def pow2(self, k): """ Computer integer power of two. :param k: bit length of input """ return _sbitintbase.pow2(self, k) - @staticmethod - def reverse_type(other): - return isinstance(other, sbitfixvec) sbits.vec = sbitvec sbitint.vec = sbitintvec @@ -1511,6 +1529,14 @@ class cbitfix(object): v = self.v inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), cbits(0), cbits(0)) + def __iter__(self): + return iter([self]) + def error(*args, **kwargs): + raise CompilerError( + 'Support for revealed fixed-point values in binary circuits ' + 'is currently limited to simple outputs. ' + 'Please file a feature request if you need this for an application.') + __add__ = __mul__ = __sub__ = error class sbitfix(_fix, _binary): """ Secret signed fixed-point number in one binary register. @@ -1581,14 +1607,29 @@ class sbitfix(_fix, _binary): return cls._new(cls.int_type(other), k, f) class sbitfixvec(_fix, _vec, _binary): - """ Vector of fixed-point numbers for parallel binary computation. + """ + Values and vectors of fixed-point numbers for parallel binary computation:: - Use :py:obj:`set_precision()` to change the precision. 
+ print_ln('add: %s', (sbitfixvec(0.5) + sbitfixvec(0.3)).reveal()) + print_ln('mul: %s', (sbitfixvec(0.5) * sbitfixvec(0.3)).reveal()) + print_ln('sub: %s', (sbitfixvec(0.5) - sbitfixvec(0.3)).reveal()) + print_ln('lt: %s', (sbitfixvec(0.5) < sbitfixvec(0.3)).reveal()) - Example:: + will output roughly:: - a = sbitfixvec([sbitfix(0.3), sbitfix(0.5)]) - b = sbitfixvec([sbitfix(0.4), sbitfix(0.6)]) + add: 0.800003 + mul: 0.149994 + sub: 0.199997 + lt: 0 + + Note that the default precision (16 bits after the dot, 31 bits in + total) only allows numbers up to :math:`2^{31-16-1} \\approx + 16000`. You can increase this using :py:func:`set_precision`. + + Refer to the following example for the vector functionality:: + + a = sbitfixvec([sbitfixvec(0.3), sbitfixvec(0.5)]) + b = sbitfixvec([sbitfixvec(0.4), sbitfixvec(0.6)]) c = (a + b).elements() print_ln('add: %s, %s', c[0].reveal(), c[1].reveal()) c = (a * b).elements() @@ -1606,13 +1647,12 @@ class sbitfixvec(_fix, _vec, _binary): lt: 1, 1 """ - int_type = sbitintvec.get_type(sbitfix.k) float_type = type(None) clear_type = cbitfix rep_type = staticmethod(lambda x: x) @property def bit_type(self): - return type(self.v[0]) + return type(self.v.v[0]) @classmethod def set_precision(cls, f, k=None): super(sbitfixvec, cls).set_precision(f=f, k=k) @@ -1637,7 +1677,8 @@ class sbitfixvec(_fix, _vec, _binary): value = self.int_type(value) super(sbitfixvec, self).__init__(value, *args, **kwargs) def elements(self): - return [sbitfix._new(x, f=self.f, k=self.k) for x in self.v.elements()] + return [sbitfixvec._new(x, f=self.f, k=self.k) + for x in self.v.elements()] def mul(self, other): if isinstance(other, sbits): return self._new(self.v * other) diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 4dde1be0..a2900de6 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -712,7 +712,8 @@ class Merger: elif isinstance(instr, StackInstruction): keep_order(instr, n, StackInstruction) elif isinstance(instr, 
applyshuffle): - shuffles[instr.args[3]].add(n) + for handle in instr.handles(): + shuffles[handle].add(n) elif isinstance(instr, delshuffle): for i_inst in shuffles[instr.args[0]]: add_edge(i_inst, n) diff --git a/Compiler/circuit.py b/Compiler/circuit.py index 076536b3..ae45d8e7 100644 --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -61,14 +61,15 @@ class Circuit: return self.run(*inputs) def run(self, *inputs): - n = inputs[0][0].n, get_tape() + inputs = [sbitvec.from_vec(x) for x in inputs] + n = inputs[0].v[0].n, get_tape() if n not in self.functions: if get_program().force_cisc_tape: f = function_call_tape else: f = function_block self.functions[n] = f(lambda *args: self.compile(*args)) - self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n) + self.functions[n].name = '%s(%d)' % (self.name, inputs[0].v[0].n) flat_res = self.functions[n](*itertools.chain(*( sbitvec.from_vec(x).v for x in inputs))) res = [] @@ -208,7 +209,7 @@ def sha3_256(x): for x in range(5): for i in range(w): j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8 - res[x][y][i] = S_flat[1600 - 1 -j] + res[x][y][i] = S_flat.v[1600 - 1 -j] return res w = 64 @@ -313,7 +314,7 @@ class ieee_float: for i in range(2): for j in range(10): - values.append(sbitint.get_type(64).get_input_from(i)) + values.append(sbitintvec.get_type(64).get_input_from(i)) fvalues = [ieee_float(x) for x in values] diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 0adae68a..7c1a6591 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -67,16 +67,21 @@ def ld2i(c, n): def maybe_mulm(res, x, y): # overwrite instruction for function-dependent preprocessing protocols from Compiler import types - res.link(x * y) + program.curr_block.replace_last_reg(res, x * y) -def require_ring_size(k, op, suffix=''): +def require_ring_size(k, op, suffix='', slack=0): if not program.options.ring: return + diff = slack * (not program.allow_tight_parameters) + k += diff if int(program.options.ring) < k: 
msg = 'ring size too small for %s, compile ' \ 'with \'-R %d\' or more' % (op, k) if k > 64 and k < 128: msg += ' (maybe \'-R 128\' as it is supported by default)' + if int(program.options.ring) >= k - diff: + msg += ", alternatively set " \ + "'program.allow_tight_parameters=True' in the program" raise CompilerError(msg + suffix) program.curr_tape.require_bit_length(k) @@ -97,13 +102,13 @@ def LtzRingRaw(a, k): from .types import sint, _bitint from .GC.types import sbitvec if program.use_split(): - program.reading('comparison', 'ABY3') + program.reading('comparison', 'Keller25', 'Section 6') summands = a.split_to_two_summands(k) carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] return msb else: - program.reading('comparison', 'DEK20-pre') + program.reading('comparison', 'DEK20-pre', 'Paragraph III.D.8') from . import floatingpoint require_ring_size(k, 'comparison') m = k - 1 @@ -195,7 +200,7 @@ def TruncLeakyInRing(a, k, m, signed): if k == m: return 0 assert k > m - program.reading('truncation', 'DEK20-pre') + program.reading('truncation', 'DEK20-pre', 'Paragraph III.D.4') require_ring_size(k, 'leaky truncation') from .types import sint, intbitint, cint, cgf2n n_bits = k - m @@ -239,7 +244,7 @@ def Mod2m(a_prime, a, k, m, signed): movs(a_prime, program.non_linear.mod2m(a, k, m, signed)) def Mod2mRing(a_prime, a, k, m, signed): - program.reading('modulo', 'DEK20-pre') + program.reading('modulo', 'DEK20-pre', 'Paragraph III.D.3') require_ring_size(k, 'modulo power of two') from Compiler.types import sint, intbitint, cint shift = int(program.options.ring) - m @@ -254,7 +259,7 @@ def Mod2mRing(a_prime, a, k, m, signed): return res def Mod2mField(a_prime, a, k, m, signed): - program.reading('modulo', 'CdH10') + program.reading('modulo', 'CdH10', 'Protocol 3.2') from .types import sint r_dprime = program.curr_block.new_reg('s') r_prime = program.curr_block.new_reg('s') @@ -349,6 +354,8 @@ def BitLTC1(u, a, 
b): a: array of clear bits b: array of secret bits (same length as a) """ + program.reading('constant-round bit-wise public-private comparison', + 'CdH10', 'Protocol 4.5') k = len(b) p = [program.curr_block.new_reg('s') for i in range(k)] from . import floatingpoint @@ -489,6 +496,8 @@ def BitLTL(res, a, b): a: clear integer register b: array of secret bits (same length as a) """ + program.reading('logarithmic-round bit-wise public-private comparison', + 'CdH10', 'Protocol 4.1') k = len(b) a_bits = b[0].bit_decompose_clear(a, k) from .types import sint @@ -655,7 +664,7 @@ def Mod2(a_0, a, k, signed): if k <= 1: movs(a_0, a) return - program.reading('modulo', 'CdH10') + program.reading('modulo', 'CdH10', 'Protocol 3.4') r_dprime = program.curr_block.new_reg('s') r_prime = program.curr_block.new_reg('s') r_0 = program.curr_block.new_reg('s') diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 867480f7..18a2a965 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -587,6 +587,12 @@ class Compiler: print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost()) print("Memory size:", dict(self.prog.allocated_mem)) + comm = self.prog.expected_communication() + if sum(comm): + print( + "Expected communication is %g MB online and %g MB offline." 
% \ + (comm[0] / 1e6, comm[1] / 1e6)) + return self.prog match = { @@ -608,6 +614,13 @@ class Compiler: else: return protocol + "-party.x" + @classmethod + def short_protocol_name(cls, protocol): + for x in cls.match.items(): + if protocol == x[1]: + return x[0] + return re.sub('^malicious-', 'mal-', protocol) + def local_execution(self, args=None): if args is None: args = self.runtime_args @@ -651,6 +664,7 @@ class Compiler: destinations.append('.') connections = [Connection(hostname) for hostname in hostnames] print("Setting up players...") + lockfile = ".transfer.lock" def run(i): dest = destinations[i] @@ -658,6 +672,16 @@ class Compiler: connection.run( "mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \ dest) + dest_lockfile = "%s/%s" % (dest, lockfile) + try: + connection.run("test -e %s && exit 1; touch %s" % ( + (dest_lockfile,) * 2)) + except: + raise Exception( + "Problem with %s on %s. You cannot use the same directory " + "for several instances (including the control instance). " + "Remove %s on %s if this has been left behind from an " + "aborted exection." % ((dest_lockfile, hostnames[i]) * 2)) # executable connection.put("%s/static/%s" % (self.root, vm), dest) # program @@ -676,10 +700,12 @@ class Compiler: dest + "Player-Data") for filename in glob.glob("Player-Data/*.0"): connection.put(filename, dest + "Player-Data") + connection.run("rm %s" % dest_lockfile) def run_with_error(i): try: run(i) + copied[i] = True except IOError: print('IO error when copying files, does %s have enough space?' 
% hostnames[i]) @@ -693,13 +719,19 @@ class Compiler: out = fn(i) outputs[i] = out + open(lockfile, "w") threads = [] + copied = [False] * len(hosts) for i in range(len(hosts)): threads.append(threading.Thread(target=run_with_error, args=(i,))) for thread in threads: thread.start() for thread in threads: thread.join() + os.remove(lockfile) + if False in copied: + print("Error in remote copying, see above") + sys.exit(1) # execution threads = [] diff --git a/Compiler/cost.py b/Compiler/cost.py new file mode 100644 index 00000000..9e27a0f4 --- /dev/null +++ b/Compiler/cost.py @@ -0,0 +1,485 @@ +import re +import math +import os +import itertools + +class Comm: + def __init__(self, comm=0, offline=0): + try: + comm = comm() + except: + pass + try: + self.online, self.offline = comm + assert not offline + except: + self.online = comm or 0 + assert isinstance(self.online, (int, float)) + self.offline = offline + + def __getitem__(self, index): + return self.offline if index else self.online + + def __iter__(self): + return iter((self.online, self.offline)) + + def __add__(self, other): + return Comm(x + y for x, y in zip(self, other)) + + def __sub__(self, other): + return self + -1 * other + + def __mul__(self, other): + return Comm(x * other for x in self) + __rmul__ = __mul__ + + def __repr__(self): + return 'Comm(%d, %d)' % tuple(self) + + def __bool__(self): + return bool(sum(self)) + + def sanitize(self): + try: + return tuple(int(x) for x in self) + except: + return (0, 0) + +dishonest_majority = { + 'emi', + 'mascot', + 'spdz', + 'soho', + 'gear', +} + +semihonest = { + 'emi|soho', + 'atlas|^shamir', + 'dealer', +} + +ring = { + 'ring', + '2k', +} + +fixed = { + '^(ring|rep-field)': 3, + 'rep4': 6, + 'mal-rep-field': (6, 9), + 'mal-rep-ring': (lambda l: (6 * l, (l + 5) * 9)), + 'sy-rep-field': 6, + 'sy-rep-ring': lambda l: (6 * (l + 5), 0), + 'ps-rep-field': 9, + 'ps-rep-ring': lambda l: 9 * (l + 5), + 'brain': lambda l: (3 * 2 * l, 3 * (2 * (l + 5) + 3 * (2 * l 
+ 15))), +} + +ot_cost = 64 +spdz2k_sec = 64 + +def lowgear_cipher_length(l): + res = (30 + 2 * l) // 8 + return res + +def highgear_cipher_lengths(l): + res = 71 + 16 * l, 57 + 8 * l + return res + +def highgear_cipher_limbs(l): + res = sum(int(math.ceil(x / 64)) for x in highgear_cipher_lengths(l)) + return res + +def highgear_decrypt_length(l): + return highgear_cipher_lengths(l)[0] / 8 + 1 + +def hemi_cipher_length(l): + res = 16 * l + 77 + return res + +def hemi_cipher_limbs(l): + res = int(math.ceil(hemi_cipher_length(l) / 64)) + return res + +variable = { + '^shamir': lambda N: N * (N - 1) // 2, + 'atlas': lambda N: N // 2 * 4, + 'dealer': lambda N: (2 * (N - 1), 1), + 'semi': lambda N: lambda l: ( + 4 * (N - 1) * l, N * (N - 1) * (l * (ot_cost + 8 * l))), + 'mascot': lambda N: lambda l: ( + 4 * (N - 1) * l, N * (N - 1) * (l * (3 * ot_cost + 64 * l))), + 'spdz2k': lambda N: lambda l: ( + 4 * (N - 1) * l, + N * (N - 1) * (ot_cost * (2 * l + 4 * spdz2k_sec // 8) + \ + (l + spdz2k_sec // 8) * (4 * spdz2k_sec + 2 * l * 8) + \ + (5 * (l + 2 * spdz2k_sec // 8) * spdz2k_sec))), + 'hemi': lambda N: lambda l: ( + 4 * (N - 1) * l, N * (N - 1) * hemi_cipher_limbs(l) * 8 * 2 * 2), + 'temi': lambda N: lambda l: ( + 4 * (N - 1) * l, (N - 1) * (hemi_cipher_limbs(l) * 8 * 2 * 2 + + hemi_cipher_length(l) / 8 + 1) * 2), + 'soho': lambda N: lambda l: ( + 4 * (N - 1) * l, + (N - 1) * (N * highgear_cipher_limbs(l) * 8 * 2 + + highgear_decrypt_length(l)) * 2), + 'owgear': lambda N: lambda l: ( + 4 * (N - 1) * l, + N * ((N - 1) * (lowgear_cipher_length(l) * (128 + 48) + 64) + 2 * l)), + '.*i.*gear': lambda N: lambda l: ( + 4 * (N - 1) * l, + (N - 1) * (highgear_cipher_limbs(l) * 96 * 3 + + highgear_decrypt_length(l) * 16 + N * 192 + 6 * l)), + 'sy-shamir': lambda N: 2 * variable['^shamir'](N) + variable_random['^shamir|atlas'](N) +} + +variable_square = { + 'soho': lambda N: lambda l: ( + 0, (N - 1) * (N * highgear_cipher_limbs(l) * 8 + 46) * 2), + 'i.*gear': lambda N: lambda l: ( 
+ 0, (N - 1) * (highgear_cipher_limbs(l) * 64 * 3 + + highgear_decrypt_length(l) * 12 + N * 128 + 4 * l)), + 'ps-rep-ring': lambda N: lambda l: fixed['ps-rep-ring'](l), + 'sy-shamir': lambda N: ( + 0, variable['sy-shamir'](N) + variable_random['sy-shamir'](N)) +} + +matrix_triples = { + 'dealer': lambda N: (N - 1, 1), +} + +diag_matrix = { + 'hemi': lambda N, l, dims: N * (N - 1) * hemi_cipher_limbs(l) * 8 * 2 * \ + (dims[0] * dims[1] + dims[0] * dims[2]), + 'temi': lambda N, l, dims: (N - 1) * ( + hemi_cipher_limbs(l) * 8 * 2 * 2 * ( + dims[0] * dims[1] + dims[0] * dims[2]) + + (hemi_cipher_length(l) / 8 + 1) * 2 * (dims[0] * dims[2])), +} + +fixed_bit = { + 'mal-rep-field': (0, 11), + 'rep4': (0, 8), +} + +fixed_square = { + 'mal-rep-ring': lambda l: fixed['mal-rep-ring'](l)[1], +} + +variable_bit = { + 'dealer': lambda N: (0, 1), + # missing OT cost + 'emi': lambda N: lambda l: (0, l + ot_cost / 8) if N == 2 else None, + 'mal-shamir': lambda N: ( + 0, variable_random['^shamir|atlas'](N) + \ + math.ceil(N / 2) * variable_input['^shamir|atlas'](N) + \ + (math.ceil(N / 2) - 0) * variable['^shamir'](N) + \ + 2 * reveal_variable['(mal|sy)-shamir'](N)), +} + +fixed_and = { + '(mal|sy|ps)-rep': lambda bucket_size=4: (6, 3 * (3 * bucket_size - 2)), +} + +variable_and = { + 'emi': lambda N: (4 * (N - 1), N * (N - 1) * ot_cost) +} + +trunc_pr = { + '^ring': 4, + 'rep-field': 1, + 'rep4': 12, +} + +bit2a = { + '^(ring|rep-field)': 3, +} + +dabit_from_bit = { + 'ring', + '-rep-ring', + 'semi2k', +} + +bits_from_squares = { + 'atlas': lambda N: N > 4, + 'sy-shamir': lambda N: True, + 'soho': lambda N: True, + 'gear': lambda N: True, + 'ps-rep-ring': lambda N: True, + 'spdz2k': lambda N: True, + 'mascot': lambda N: True, + 'mal-rep-ring': lambda N: True, + 'emi$': lambda N: True, +} + +reveal = { + '((^|rep.*)ring|rep-field|brain)': 3, + 'rep4': 4, +} + +reveal_variable = { + '^shamir|atlas': lambda N: 3 * (N - 1) // 2, + '(mal|sy)-shamir': lambda N: (N - 1) // 2 * 2 * N, + 
'dealer': lambda N: 2 * (N - 2), + 'spdz2k': lambda N: N * variable_input['mascot|spdz2k'](N), +} + +fixed_input = { + '(^|ps-|mal-)(ring|rep-)': 1, + 'sy-rep-ring': lambda l: 4 * (l + 5), + 'sy-rep-field': 4, + 'rep4': 2, +} + +variable_input = { + '^shamir|atlas': lambda N: N // 2, + 'mal-shamir': lambda N: N // 2, + 'sy-shamir': lambda N: \ + N // 2 + variable['^shamir'](N) + variable_random['^shamir|atlas'](N), + 'mascot|spdz2k': lambda N: (N - 1) * Comm(1, ot_cost * 2), + 'owgear': lambda N: lambda l: ( + (N - 1) * l, (N - 1) * lowgear_cipher_length(l) * 16), + 'i.*gear': lambda N: lambda l: ( + (N - 1) * l, (N - 1) * (highgear_cipher_limbs(l) * 24 + 32 + + highgear_decrypt_length(l) * 4)), +} + +variable_random = { + '^shamir|atlas': lambda N: N * (N // 2) / ((N + 2) // 2), + 'mal-shamir': lambda N: N // 2 * N, + 'sy-shamir': lambda N: \ + 2 * variable_random['^shamir|atlas'](N) + variable['^shamir'](N), +} + +# cut random values +fixed_randoms = { + 'sy-rep-ring': lambda l: 3 * (l + 5), +} + +cheap_dot_product = { + '^(ring|rep-field)', + 'sy-*', + '^shamir', + 'rep4', + 'atlas', +} + +shuffle_application = { + '^(ring|rep-field)': 6, + 'sy-rep-field': 12, + 'sy-rep-ring': lambda l: 12 * (l + 5) +} + +variable_edabit = { + 'dealer': lambda N: lambda n_bits: lambda l: l + n_bits / 8 +} + +def find_match(data, protocol): + for x in data: + if re.search(x, protocol): + return x + +def get_match(data, protocol): + x = find_match(data, protocol) + try: + return data.get(x) + except: + return bool(x) + +def get_match_variable(data, protocol, n_parties): + f = get_match(data, protocol) + if f: + return f(n_parties) + +def apply_length(unit, length): + try: + return Comm(unit(length)) + except: + return Comm(unit) * length + +def get_cost(fixed, variable, protocol, n_parties): + return get_match(fixed, protocol) or \ + get_match_variable(variable, protocol, n_parties) + +def get_mul_cost(protocol, n_parties): + return get_cost(fixed, variable, protocol, n_parties) + 
+def get_and_cost(protocol, n_parties): + return get_cost(fixed_and, variable_and, protocol, n_parties) + +def expected_communication(protocol, req_num, length, n_parties=None, + force_triple_use=False): + from Compiler.instructions import shuffle_base + from Compiler.program import Tape + get_int = lambda x: req_num.get(('modp', x), 0) + get_bit = lambda x: req_num.get(('bit', x), 0) + res = Comm() + if not protocol: + return res + if not n_parties: + try: + if get_match(fixed, protocol): + raise TypeError() + n_parties = int(os.getenv('PLAYERS')) + except TypeError: + if find_match(dishonest_majority, protocol): + n_parties = 2 + else: + n_parties = 3 + if find_match(dishonest_majority, protocol): + threshold = n_parties - 1 + elif re.match('rep4', protocol): + n_parties = 4 + threshold = 1 + elif re.match('dealer', protocol): + threshold = 0 + else: + threshold = n_parties // 2 + malicious = not find_match(semihonest, protocol) + x = find_match(fixed, protocol) + y = get_mul_cost(protocol, n_parties) + unit = apply_length(y, length) + n_mults = get_int('simple multiplication') + matrix_cost = apply_length( + get_match_variable(matrix_triples, protocol, n_parties), length) + use_diag_matrix = get_match(diag_matrix, protocol) + use_triple_number = False + if find_match(cheap_dot_product, protocol): + n_mults += get_int('dot product') + elif (not matrix_cost and not use_diag_matrix) or force_triple_use: + use_triple_number = True + n_mults = get_int('triple') + and_cost = get_and_cost(protocol, n_parties) + if and_cost: + res += Comm(and_cost) * math.ceil(get_bit('triple') / 8) + else: + n_mults += get_bit('triple') / (length * 8) + bit_cost = Comm(apply_length( + bit2a.get(x) or get_match(fixed_bit, protocol) or + get_match_variable(variable_bit, protocol, n_parties), + length)) + input_cost = apply_length( + get_match(fixed_input, protocol) or \ + get_match_variable(variable_input, protocol, n_parties), length) + output_cost = Comm( + get_match(reveal, protocol) 
or \ + get_match_variable(reveal_variable, protocol, n_parties) or \ + (n_parties - 1) * 2) + random_cost = apply_length( + get_match_variable(variable_random, protocol, n_parties), length) + if not random_cost: + random_cost = n_parties * input_cost + square_unit = get_cost(fixed_square, variable_square, protocol, n_parties) + if not square_unit: + def square_unit(l): + unit = apply_length(y, l) + return Comm(0, (unit[1] or unit[0]) + sum( + apply_length(output_cost, l))) + square_cost = apply_length(square_unit, length) + res += square_cost * get_int('square') + if bit_cost: + res += bit_cost * get_int('bit') + elif get_match_variable(bits_from_squares, protocol, n_parties): + if square_cost: + if get_match(ring, protocol): + sb_cost = apply_length(square_unit, length + 1) + else: + sb_cost = square_cost + bit_cost = Comm(0, offline=sum(sb_cost + output_cost * length)) + else: + bit_cost = Comm(0, offline=sum( + unit + random_cost + length * output_cost)) + res += bit_cost * get_int('bit') + else: + bit_cost = Comm(0, offline=sum( + threshold * unit + (threshold + 1) * input_cost)) + res += bit_cost * get_int('bit') + res += unit * n_mults + if not unit: + sh_protocol = re.sub('mal-', '', protocol) + sh_unit = get_mul_cost(sh_protocol, n_parties) + sh_random_unit = get_match_variable( + variable_random, sh_protocol, n_parties) + if sh_unit: + res += length * Comm( + sum(2 * output_cost * n_mults), + int(n_mults * (3 * sh_random_unit + 2 * sh_unit + \ + 2 * sum(output_cost)))) + res += Comm(get_match(trunc_pr, protocol)) * length * \ + get_int('probabilistic truncation') + res += Comm(bit2a.get(x)) * length * get_int('bit2A') + res += output_cost * length * get_int('open') + res += output_cost * get_bit('open') + res += get_match(dabit_from_bit, protocol) * bit_cost * get_int('dabit') + res += random_cost * get_int('random') + res += get_int('cut random') * apply_length( + get_match(fixed_randoms, protocol), length) + shuffle_correction = not 
find_match(shuffle_application, protocol) + def get_node(): + req_node = Tape.ReqNode("") + req_node.aggregate() + return req_node + for x in req_num: + if len(x) >= 3 and x[0] == 'modp': + if x[1] == 'input': + res += input_cost * req_num[x] + elif x[1] == 'shuffle application': + shuffle_cost = apply_length( + get_match(shuffle_application, protocol), length) + if shuffle_cost: + res += Comm(shuffle_cost) * req_num[x] * x[2] + elif find_match(cheap_dot_product, protocol) or \ + 'dealer' in protocol: + res += shuffle_base.n_swaps(x[2]) * (threshold + 1) * \ + req_num[x] * unit * (x[3] + malicious) + elif shuffle_correction: + node = get_node() + shuffle_base.add_apply_usage( + node, x[2], x[3], add_shuffles=False) + node.num = -node.num + if not use_triple_number: + node.num['modp', 'triple'] = 0 + shuffle_base.add_apply_usage( + node, x[2], x[3], add_shuffles=False, + malicious=malicious, n_relevant_parties=threshold + 1) + res += req_num[x] * \ + expected_communication(protocol, node.num, length, + force_triple_use=True) + elif x[1] == 'shuffle generation': + if 'dealer' in protocol: + res += Comm( + req_num[x] * shuffle_base.n_swaps(x[2]) * length) + else: + req_node = get_node() + shuffle_base.add_gen_usage( + req_node, x[2], add_shuffles=False) + req_node.num = -req_node.num + if shuffle_correction: + shuffle_base.add_gen_usage( + req_node, x[2], add_shuffles=False, + malicious=malicious, + n_relevant_parties=threshold + 1) + res += req_num[x] * \ + expected_communication(protocol, req_node.num, length) + elif x[0] == 'matmul': + mm_unit = Comm() + if use_diag_matrix: + dims = list(x[1]) + if dims[0] > dims[2]: + dims[0::2] = dims[2::-2] + mm_unit += Comm( + offline=use_diag_matrix(n_parties, length, dims)) + matrix_cost = Comm(unit.online / 2) + for idx in ((0, 1), (1, 2)): + mm_unit += Comm(matrix_cost.online) * \ + x[1][idx[0]] * x[1][idx[1]] + mm_unit += Comm(offline=matrix_cost.offline) * x[1][0] * x[1][2] + res += mm_unit * req_num[x] + elif 
re.search('edabit', x[0]): + edabit = get_match_variable(variable_edabit, protocol, n_parties) + if edabit: + res += Comm(offline=edabit(x[1])(length)) * req_num[x] + res.n_parties = n_parties + return res diff --git a/Compiler/decision_tree.py b/Compiler/decision_tree.py index 87f2789d..d5e7ea84 100644 --- a/Compiler/decision_tree.py +++ b/Compiler/decision_tree.py @@ -636,7 +636,8 @@ def preprocess_pandas(data): elif pandas.api.types.is_object_dtype(t): values = list(filter(lambda x: isinstance(x, str), list(data.iloc[:,i].unique()))) - print('converting the following to unary:', values) + print('converting the following to unary from %d: %s' % + (len(res), values)) if len(values) == 2: res.append(data.iloc[:,i].to_numpy() == values[1]) types.append('b') diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py index f4a3c46f..e7660c22 100644 --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -98,7 +98,7 @@ class HeapQ(object): self.size = MemValue(int_type(0)) self.int_type = int_type self.basic_type = basic_type - prog.reading('heap queue', 'KS14') + prog.reading('heap queue', 'KS14', 'Section 5.1') print('heap: %d levels, depth %d, size %d, index size %d' % \ (self.levels, self.depth, self.heap.oram.size, self.value_index.size)) def update(self, value, prio, for_real=True): @@ -243,7 +243,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None, :param int_type: secret integer type (default: sint) """ - prog.reading("Dijkstra's algorithm", "KS14") + prog.reading("Dijkstra's algorithm", "KS14", "Section 5.2") vert_loops = n_loops * e_index.size // edges.size \ if n_loops else -1 dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \ diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index ba065990..360dd98c 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -59,7 +59,7 @@ def EQZ(a, k): v = sbitvec(a, k).v bit = util.tree_reduce(operator.and_, (~b for b in v)) return 
types.sintbit.conv(bit) - prog.reading('equality', 'ABZS13') + prog.reading('equality', 'CdH10', 'Protocol 3.7') return prog.non_linear.eqz(a, k) def bits(a,m): @@ -313,9 +313,10 @@ def BitDecRingRaw(a, k, m): return bits[:m] else: if program.Program.prog.use_edabit(): - r, r_bits = types.sint.get_edabit(m, strict=False) + r, r_bits = types.sint.get_edabit(m, strict=False, size=a.size) elif program.Program.prog.use_dabit: - r, r_bits = zip(*(types.sint.get_dabit() for i in range(m))) + r, r_bits = zip(*(types.sint.get_dabit(size=a.size) + for i in range(m))) r = types.sint.bit_compose(r) else: r_bits = [types.sint.get_random_bit() for i in range(m)] @@ -334,7 +335,8 @@ def BitDecRing(a, k, m): return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1] def BitDecFieldRaw(a, k, m, bits_to_compute=None): - comparison.program.reading('bit decomposition', 'ABZS13') + comparison.program.reading('bit decomposition', 'CdH10-fixed', + 'Protocol 3.7') instructions_base.set_global_vector_size(a.size) r_dprime = types.sint() r_prime = types.sint() @@ -362,7 +364,7 @@ def Pow2(a, l): return Pow2_from_bits(t) def Pow2_from_bits(bits): - comparison.program.reading('power of two', 'ABZS13') + comparison.program.reading('power of two', 'ABZS13', 'Section 3') m = len(bits) t = list(bits) pow2k = [None for i in range(m)] @@ -419,7 +421,7 @@ def Trunc(a, l, m, compute_modulo=False, signed=False): return TruncInRing(a, l, Pow2(m, l)) else: kappa = program.Program.prog.security - prog.reading('secret truncation', 'ABZS13') + prog.reading('secret truncation', 'ABZS13', 'Section 3') r = [types.sint() for i in range(l)] r_dprime = types.sint(0) r_prime = types.sint(0) @@ -460,7 +462,7 @@ def Trunc(a, l, m, compute_modulo=False, signed=False): @instructions_base.ret_cisc def TruncInRing(to_shift, l, pow2m): - comparison.program.reading('secret truncation', 'DEK20') + comparison.program.reading('secret truncation', 'DEK20', 'Section 3.2.3') n_shift = 
int(program.Program.prog.options.ring) - l bits = util.bit_decompose(to_shift, l) rev = types.sint.bit_compose(reversed(bits)) @@ -564,7 +566,8 @@ def TruncPrRing(a, k, m, signed=True): res = sint() trunc_pr(res, a, k, m) else: - prog.reading('probabilistic truncation', 'CdH10-fixed') + prog.reading('probabilistic truncation', 'CdH10-fixed', + 'Protocol 3.1') # extra bit to mask overflow prog.curr_tape.require_bit_length(1) if prog.use_edabit() or prog.use_split() > 2: @@ -594,7 +597,7 @@ def TruncPrField(a, k, m): program.Program.prog.trunc_pr_warning() prog = program.Program.prog - prog.reading('probabilistic truncation', 'CdH10-fixed') + prog.reading('probabilistic truncation', 'CdH10-fixed', 'Protocol 3.1') b = two_power(k-1) + a r_prime, r_dprime = types.sint(), types.sint() comparison.PRandM(r_dprime, r_prime, [types.sint() for i in range(m)], @@ -632,7 +635,7 @@ def SDiv(a, b, l, round_nearest=False): x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True) y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, nearest=round_nearest, signed=False) - y = y.round(2 * l + 1, l + 1, nearest=round_nearest) + y = y.round(2 * l + 1, l + 1, nearest=round_nearest, signed=False) return y def SDiv_mono(a, b, l): @@ -684,7 +687,7 @@ def BITLT(a, b, bit_length): def BitDecFull(a, n_bits=None, maybe_mixed=False): from .library import get_program, do_while, if_, break_point from .types import sint, regint, longint, cint - get_program().reading('full bit decomposition', 'NO07') + get_program().reading('full bit decomposition', 'NO07', 'Figure 2') p = get_program().prime assert p bit_length = p.bit_length() @@ -731,6 +734,7 @@ def BitDecFull(a, n_bits=None, maybe_mixed=False): r = sint.get_edabit(bit_length, True) bs[j].link(r[0]) tbits[j].link(sbitvec.from_vec(r[1])) + tbits[j] = tbits[j].v else: for i in range(bit_length): tbits[j][i].link(sint.get_random_bit()) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 9b7ee16f..76636436 100644 --- 
a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -1348,6 +1348,9 @@ class randoms(base.Instruction): arg_format = ['sw','int'] field_type = 'modp' + def add_usage(self, req_node): + req_node.increment((self.field_type, 'cut random'), self.get_size()) + @base.vectorize class randomfulls(base.DataInstruction): """ Store share(s) of a fresh secret random element in secret @@ -1365,6 +1368,12 @@ class randomfulls(base.DataInstruction): return len(self.args) class unsplit(base.VectorInstruction, base.Ciscable): + """ Bit injection (conversion from binary to arithmetic). + + :param: destination (sint) + :param: source (sbits) + + """ __slots__ = [] code = base.opcodes['UNSPLIT'] arg_format = tools.chain(['sb'], itertools.repeat('sw')) @@ -2568,6 +2577,12 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction, for reg in self.args[i + 2:i + self.args[i]]: yield reg + def add_usage(self, req_num): + base.DataInstruction.add_usage(self, req_num) + req_num.increment( + (self.field_type, 'dot product'), + self.get_size() * len(list(self.bases(iter(self.args))))) + class matmul_base(base.DataInstruction): data_type = 'triple' is_vec = lambda self: True @@ -2718,6 +2733,10 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) + def add_usage(self, req_node): + req_node.increment(('modp', 'probabilistic truncation'), + self.get_size() * len(self.args) // 4) + class shuffle_base(base.DataInstruction): n_relevant_parties = 2 @@ -2725,12 +2744,12 @@ class shuffle_base(base.DataInstruction): super(shuffle_base, self).__init__(*args, **kwargs) prog = base.program if re.match('ring|rep-field|sy-rep.*', prog.options.execute or ''): - ref = 'AHIK+22' + ref = 'AHIK+22', 'Protocol 3.2' elif prog.options.execute: - ref = 'KS14' + ref = 'KS14', 'Section 4.3' else: - ref = ('AHIK+22', 'KS14') - base.program.reading('secure shuffling', ref) + ref = ('AHIK+22', 'KS14'), None + 
base.program.reading('secure shuffling', *ref) @staticmethod def logn(n): @@ -2741,26 +2760,39 @@ class shuffle_base(base.DataInstruction): logn = cls.logn(n) return logn * 2 ** logn - 2 ** logn + 1 - def add_gen_usage(self, req_node, n): + @classmethod + def add_gen_usage(self, req_node, n, add_shuffles=True, malicious=True, + n_relevant_parties=None): # hack for unknown usage req_node.increment(('bit', 'inverse'), float('inf')) # minimal usage with two relevant parties logn = self.logn(n) n_switches = self.n_swaps(n) - for i in range(self.n_relevant_parties): + n_relevant_parties = n_relevant_parties or self.n_relevant_parties + for i in range(n_relevant_parties): req_node.increment((self.field_type, 'input', i), n_switches) - # multiplications for bit check - req_node.increment((self.field_type, 'triple'), - n_switches * self.n_relevant_parties) + if malicious: + # multiplications for bit check + req_node.increment((self.field_type, 'triple'), + n_switches * n_relevant_parties) + if add_shuffles: + req_node.increment((self.field_type, 'shuffle generation', n)) - def add_apply_usage(self, req_node, n, record_size): + @classmethod + def add_apply_usage(self, req_node, n, record_size, add_shuffles=True, + malicious=True, n_relevant_parties=None): req_node.increment(('bit', 'inverse'), float('inf')) logn = self.logn(n) - n_switches = self.n_swaps(n) * self.n_relevant_parties - if n != 2 ** logn: + n_switches = self.n_swaps(n) * \ + (n_relevant_parties or self.n_relevant_parties) + real_record_size = record_size + if n != 2 ** logn and malicious: record_size += 1 req_node.increment((self.field_type, 'triple'), n_switches * record_size) + if add_shuffles: + req_node.increment( + (self.field_type, 'shuffle application', n, real_record_size)) @base.gf2n class secshuffle(base.VectorInstruction, shuffle_base): @@ -2824,6 +2856,9 @@ class applyshuffle(shuffle_base, base.Mergeable): for i in range(0, len(self.args), 6): self.add_apply_usage(req_node, self.args[i], 
self.args[i + 3]) + def handles(self): + return self.args[::4] + class delshuffle(base.Instruction): """ Delete secure shuffle. @@ -2874,7 +2909,7 @@ class sqrs(base.CISC): arg_format = ['sw', 's'] def expand(self): - s = [program.curr_block.new_reg('s') for i in range(6)] + s = [type(self.args[0])() for i in range(6)] c = [self.args[0].clear_type() for i in range(2)] square(s[0], s[1]) subs(s[2], self.args[1], s[0]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 0a09387e..9156d720 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -1112,6 +1112,8 @@ class Instruction(object): new_args.append(arg.copy()) subs[arg] = new_args[-1] else: + if isinstance(arg, program.curr_tape.Register) and arg.caller: + print(util.format_trace(arg.caller), file=sys.stderr) new_args.append(arg) return new_args diff --git a/Compiler/library.py b/Compiler/library.py index 73f54662..2d45d120 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -109,7 +109,7 @@ def print_str(s, *args, print_secrets=False): if print_secrets: val.output() else: - secret_error() + secret_error(args[i]) elif isinstance(val, cfloat): val.print_float_plain() elif isinstance(val, (list, tuple)): @@ -831,7 +831,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads= def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32, n_threads=None, key_indices=None): - get_program().reading('sorting', 'KSS13') + get_program().reading('sorting', 'KSS13', 'Section 6.1') a_in = a if isinstance(a_in, list): a = Array.create_from(a) @@ -1592,8 +1592,13 @@ def _run_and_link(function, g=None, lock_lists=True, allow_return=False): pre = copy.copy(g) res = function() if res is not None and not allow_return: + if get_program().options.flow_optimization: + suffix = ' and avoid -l/--flow-optimization to keep ' \ + 'compile-time branching' + else: + suffix = '' raise CompilerError('Conditional blocks cannot return values. 
' - 'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else') + 'Use if_else instead: https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.if_else' + suffix) _link(pre, g) return res @@ -2052,7 +2057,8 @@ def FPDiv(a, b, k, f, simplex_flag=False, nearest=False): """ Goldschmidt method as presented in Catrina10, """ - get_program().reading('fixed-point division', 'CdH10-fixed') + get_program().reading('fixed-point division', 'CdH10-fixed', + 'Protocol 3.3') prime = get_program().prime if 2 * k == int(get_program().options.ring) or \ (prime and 2 * k <= (prime.bit_length() - 1)): diff --git a/Compiler/ml.py b/Compiler/ml.py index 30059d5b..df2f79d8 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -63,7 +63,7 @@ import re from Compiler import mpc_math, util from Compiler.types import * -from Compiler.types import _unreduced_squant +from Compiler.types import _unreduced_squant, _single from Compiler.library import * from Compiler.util import is_zero, tree_reduce from Compiler.comparison import CarryOutRawLE @@ -927,6 +927,11 @@ class Dense(DenseBase): progress('f input') def _forward(self, batch=None): + if not issubclass(self.W.value_type, _single) \ + or not issubclass(self.X.value_type, _single): + raise CompilerError( + 'dense inputs have to be sfix in arithmetic circuits') + if batch is None: batch = regint.Array(self.N) batch.assign(regint.inc(self.N)) @@ -2160,6 +2165,11 @@ class Conv2d(ConvBase): return weights_h * weights_w * n_channels_in def _forward(self, batch): + if not issubclass(self.weights.value_type, _single) \ + or not issubclass(self.X.value_type, _single): + raise CompilerError( + 'convolution inputs have to be sfix in arithmetic circuits') + if self.tf_weight_format: assert(self.weight_shape[3] == self.output_shape[-1]) weights_h, weights_w, _, _ = self.weight_shape @@ -4058,7 +4068,8 @@ class SGD(Optimizer): # divide by len(batch) by truncation # increased rate if 
len(batch) is not a power of two diff = red_old - nabla_vector - pre_trunc = diff.v * rate.v + # assuming rate is already synchronized + pre_trunc = diff.v.mul(rate.v, sync=False) momentum_value.assign_vector(diff, base) k = max(nabla_vector.k, rate.k) + rate.f m = rate.f + int(log_batch_size) diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 580aa994..9213c866 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -131,7 +131,7 @@ def p_eval(p_c, x): local_aggregation = 0 # Evaluation of the Polynomial for i, pre_mult in zip(p_c[1:], pre_mults): - local_aggregation += pre_mult.mul_no_reduce(x.coerce(i)) + local_aggregation += pre_mult.mul_no_reduce(i) return local_aggregation.reduce_after_mul() + p_c[0] @@ -148,7 +148,8 @@ def p_eval(p_c, x): # @return b2: \{0,1\} value. Returns one when reduction to # \pi is greater than \pi/2. def sTrigSub(x): - library.get_program().reading('trigonometric functions', 'AS19') + library.get_program().reading('trigonometric functions', 'AS19', + 'Section 4') # reduction to 2* \pi f = x * (1.0 / (2 * pi)) f = trunc(f) @@ -267,7 +268,7 @@ def exp2_fx(a, zero_output=False, as19=False): :return: :math:`2^a` if it is within the range. 
Undefined otherwise """ - library.get_program().reading('exponential', 'AS19') + library.get_program().reading('exponential', 'AS19', 'Protocol 6') def exp_from_parts(whole_exp, frac): class my_fix(type(a)): pass @@ -316,7 +317,7 @@ def exp2_fx(a, zero_output=False, as19=False): s = sint.conv(bits[-1]) lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f) else: - bits = sbitvec(a.v, a.k) + bits = sbitvec(a.v, a.k).v s = sint.conv(bits[-1]) lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f]) higher_bits = bits[a.f:n_bits] @@ -437,7 +438,7 @@ def log2_fx(x, use_division=True): :return: (sfix) the value of :math:`\log_2(x)` """ - library.get_program().reading('logarithm', 'AS19') + library.get_program().reading('logarithm', 'AS19', 'Section 5') if isinstance(x, types._fix): # transforms sfix to f*2^n, where f is [o.5,1] bounded # obtain number bounded by [0,5 and 1] by transforming input to sfloat @@ -815,7 +816,7 @@ def sqrt(x, k=None, f=None): :return: square root of :py:obj:`x` (sfix). """ - library.get_program().reading('square root', 'AS19') + library.get_program().reading('square root', 'AS19', 'Section 3') if k is None: k = x.k if f is None: @@ -837,7 +838,8 @@ def atan(x): :return: arctan of :py:obj:`x` (sfix). 
""" - library.get_program().reading('inverse trigonometric functions', 'AS19') + library.get_program().reading('inverse trigonometric functions', 'AS19', + 'Protocol 5') # obtain absolute value of x s = x < 0 x_abs = s.if_else(-x, x) diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 7cdc7d27..1610769c 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -26,7 +26,7 @@ class NonLinear: if prog.use_trunc_pr and m and ( not prog.options.ring or \ prog.use_trunc_pr <= (int(prog.options.ring) - k)): - prog.reading('probabilistic truncation', 'DEK20') + prog.reading('probabilistic truncation', 'DEK20', 'Section 3.2.2') if prog.options.ring: comparison.require_ring_size(k, 'truncation') else: @@ -92,8 +92,9 @@ class Prime(Masking): def kor(self, d): return KOR(d) - def require_bit_length(self, bit_length, op): + def require_bit_length(self, bit_length, op, slack=0): prog = program.Program.prog + bit_length += slack * (not prog.allow_tight_parameters) if bit_length > 32: prog.curr_tape.require_bit_length(bit_length - 1, reason=op) @@ -146,7 +147,7 @@ class KnownPrime(NonLinear): else: return super(KnownPrime, self).ltz(a, k) - def require_bit_length(self, bit_length, op): + def require_bit_length(self, *args, **kwargs): pass class Ring(Masking): @@ -189,5 +190,5 @@ class Ring(Masking): def ltz(self, a, k): return LtzRing(a, k) - def require_bit_length(self, bit_length, op): - comparison.require_ring_size(bit_length, op) + def require_bit_length(self, *args, **kwargs): + comparison.require_ring_size(*args, **kwargs) diff --git a/Compiler/papers.py b/Compiler/papers.py index e13b9d0b..9e47af7e 100644 --- a/Compiler/papers.py +++ b/Compiler/papers.py @@ -7,7 +7,7 @@ papers = { 'AN17': 'https://eprint.iacr.org/2017/816', 'AS19': 'https://eprint.iacr.org/2019/354', 'AHIK+22': 'https://eprint.iacr.org/2022/1595', - 'CdH10': 'https://www.researchgate.net/publication/225092133_Improved_Primitives_for_Secure_Multiparty_Integer_Computation, 
https://doi.org/10.1007/978-3-642-15317-4_13 (paywall)', + 'CdH10': 'https://www.researchgate.net/publication/225092133, https://doi.org/10.1007/978-3-642-15317-4_13 (paywall)', 'CdH10-fixed': 'https://www.ifca.ai/pub/fc10/31_47.pdf', 'CCD88': 'https://doi.org/10.1145/62212.62214', 'DDNNT15': 'https://eprint.iacr.org/2015/1006', diff --git a/Compiler/program.py b/Compiler/program.py index 13f8810c..54daa9ba 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -25,6 +25,7 @@ from Compiler.instructions_base import RegType from . import allocator as al from . import util from .papers import * +from .cost import expected_communication data_types = dict( triple=0, @@ -131,7 +132,8 @@ class Program(object): assert self.rabbit_gap() print(", for example, %d." % self.prime) self.prime = bad_prime - except ImportError: + except (ImportError, AssertionError): + self.prime = bad_prime print(".") if options.execute: print("Use '-- --prime ' to specify the prime for " @@ -251,6 +253,15 @@ class Program(object): else: print("Use '--execute ' to see recommended reading " "on the basic protocol.") + if self.options.garbled: + if not self.options.binary: + raise CompilerError( + "You have to specify a default bit length using '--binary' " + "for garbled circuits.") + self.optimize_for_gc() + self.allow_tight_parameters = True + self.warned_about_tightness = False + self.warned_about_a2b = False Program.prog = self from . 
import comparison, instructions, instructions_base, types @@ -439,6 +450,9 @@ class Program(object): else: self.req_num += tape.req_num + def required_bit_length(self, t): + return max(x.req_bit_length[t] for x in self.tapes) + def write_bytes(self): """Write all non-empty threads and schedule to files.""" @@ -455,7 +469,7 @@ class Program(object): sch_file.write("1 0\n") sch_file.write("0\n") sch_file.write(" ".join(sys.argv) + "\n") - req = max(x.req_bit_length["p"] for x in self.tapes) + req = self.required_bit_length("p") if self.options.ring: sch_file.write("R:%s" % self.options.ring) elif self.options.prime: @@ -470,6 +484,14 @@ class Program(object): assert len(req2) <= 2 if req2: sch_file.write("lg2:%s" % max(req2)) + sch_file.write("\n") + exp = self.expected_communication() + if exp: + sch_file.write( + "online:%d offline:%d n_parties:%d\n" % ( + exp.sanitize() + (exp.n_parties,))) + else: + sch_file.write('no expections\n') sch_file.close() h = hashlib.sha256() for tape in self.tapes: @@ -590,6 +612,10 @@ class Program(object): # communicate protocol compability Compiler.instructions.active(self._always_active) + # communicate mulm usage to VM + if self.use_mulm != 1: + self.relevant_opts.add("no_mulm") + self.write_bytes() if self.options.asmoutfile: @@ -743,6 +769,15 @@ class Program(object): def used_splits(self): return self._split + def have_a2b(self): + if self.use_split() or self.use_edabit() or self.use_dabit: + return True + if not self.warned_about_a2b: + print( + 'WARNING: No option selected for A2B conversion, defaulting ' + 'to edaBits. Use -X/-Y/-Z to get rid of this warning.') + self.warned_about_a2b = True + def use_square(self, change=None): """Setting whether to use preprocessed square tuples (default: false). 
@@ -869,15 +904,30 @@ class Program(object): bl = inst.args[0] return (abs(bl.i) + 63) // 64 * 8 - def reading(self, concept, reference): - key = concept, reference + def reading(self, concept, reference, part=None): + key = concept, reference, part if self.options.papers and key not in self.recommended: if isinstance(reference, tuple): + assert part is None reference = ', '.join(papers.get(x) or x for x in reference) - print('Recommended reading on %s: %s' % ( - concept, papers.get(reference) or reference)) + suffix = ' (%s)' % part or '' + print('Recommended reading on %s: %s%s' % ( + concept, papers.get(reference) or reference, suffix)) self.recommended.add(key) + def expected_communication(self): + if self.options.ring: + bit_length = int(self.options.ring) + elif self.options.prime: + bit_length = self.prime.bit_length() + else: + # check against OnlineOptions.cpp + bit_length = max(self.required_bit_length("p"), 128) + bit_length = int(math.ceil(bit_length / 64) * 64) + length = int(math.ceil(bit_length / 8)) + return expected_communication( + self.options.execute, self.req_num or Tape.ReqNum(), length) + class Tape: """A tape contains a list of basic blocks, onto which instructions are added.""" @@ -1452,6 +1502,12 @@ class Tape: __rmul__ = __mul__ + def __neg__(self): + res = Tape.ReqNum() + for i, count in list(self.items()): + res[i] = -count + return res + def set_all(self, value): if Program.prog.options.verbose and \ value == float("inf") and self["all", "inv"] > 0: @@ -1644,6 +1700,19 @@ class Tape: __float__ = __int__ + def __eq__(self, other): + raise CompilerError("equality testing not implemented") + + __ne__ = __eq__ + + class _no_secret_truth(_no_truth): + def __bool__(self): + raise CompilerError( + "Cannot branch on secret values like %s. " + "See https://mp-spdz.readthedocs.io/en/latest/troubleshooting.html#cannot-branch-on-secret-values. " % \ + type(self).__name__ + ) + class Register(_no_truth): """ Class for creating new registers. 
The register's index is automatically assigned @@ -1755,6 +1824,17 @@ class Tape: return self.vector or [self] def __getitem__(self, index): + try: + if isinstance(index, slice): + for x in index.start, index.stop, index.step: + if x is not None: + int(x) + else: + int(index) + except: + raise CompilerError( + 'cannot address vectors with run-time indices, ' + 'use (Multi)Array instead') if self.size == 1 and index == 0: return self if not self.vector: diff --git a/Compiler/types.py b/Compiler/types.py index 49e3f74b..b41dc3b0 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -591,7 +591,7 @@ class _structure(Tape._no_truth): return cls.int_type.reg_type raise CompilerError('type not supported as argument: %s' % cls) -class _secret_structure(_structure): +class _secret_structure(Tape._no_secret_truth, _structure): @classmethod def input_tensor_from(cls, player, shape): """ Input tensor secretly from player. @@ -1105,10 +1105,9 @@ class cint(_clear, _int): @staticmethod def in_immediate_range(value, regint=False): if value and not regint: - # +1 for sign - bit_length = 1 + int(math.ceil(math.log(abs(value), 2))) + # slack for sign program.non_linear.require_bit_length( - bit_length, 'integer conversion') + value.bit_length(), 'integer conversion', slack=1) return value < 2**31 and value >= -2**31 @vectorize_init @@ -1321,7 +1320,7 @@ class cint(_clear, _int): :param other: cint/regint/int """ return self >> other - def round(self, k, m, nearest=None, signed=False): + def round(self, k, m, nearest=None, signed=True): if signed: self += 2 ** (k - 1) self += 2 ** (m - 1) @@ -2383,7 +2382,7 @@ class _secret(_arithmetic_register, _secret_structure): @set_instruction_type @read_mem_value @vectorize - def mul(self, other): + def mul(self, other, sync=True): """ Secret multiplication. Either both operands have the same size or one size 1 for a value-vector multiplication. 
@@ -2396,7 +2395,7 @@ class _secret(_arithmetic_register, _secret_structure): res = type(self)(size=x.size) mulrs(res, x, y) return res - if program.use_mulm == 1: + if program.use_mulm == 1 or not sync: mulm = instructions.mulm elif program.use_mulm == -1: mulm = lambda res, x, y: instructions.mulm(res, x, cint(regint(y))) @@ -2530,7 +2529,7 @@ class _secret(_arithmetic_register, _secret_structure): writesharestofile(regint.conv(position), *shares) class sint(_secret, _int): - """ + r""" Secret integer in the protocol-specific domain. It supports operations with :py:class:`sint`, :py:class:`cint`, :py:class:`regint`, and Python integers. Operations where one of @@ -2559,6 +2558,12 @@ class sint(_secret, _int): undefined and potentially insecure if the operands are longer than the bit length. + Instances of sint are understood to be signed. This means that, + for modulo :math:`N`, numbers in :math:`[0,N/2)` are understood as + positive numbers whereas numbers in :math:`[N/2,N)` are understood + to be negative, namely :math:`x-N`. This ensures expected + arithmetic such as :math:`-1 + 1 = (N-1) + 1 = N = 0 \mod N`. + See :ref:`nonlinear` for an overview of how non-linear computation is implemented. @@ -2826,7 +2831,7 @@ class sint(_secret, _int): self.load_other(val.v.round(val.k, val.f, nearest=val.round_nearest)) elif isinstance(val, sbitvec): - super(sint, self).__init__('s', val=val, size=val[0].n) + super(sint, self).__init__('s', val=val, size=val.v[0].n) else: super(sint, self).__init__('s', val=val, size=size) @@ -3001,13 +3006,19 @@ class sint(_secret, _int): maybe_mixed) def TruncMul(self, other, k, m, nearest=False): + if not nearest and not program.warned_about_tightness and \ + program.options.ring and int(program.options.ring) == k: + print('WARNING: Using tight parameters. 
' + 'Increase ring size or reduce fixed-point precision ' + 'for increased efficiency') + program.warned_about_tightness = True return (self * other).round(k, m, nearest, signed=True) def TruncPr(self, k, m, signed=True): return floatingpoint.TruncPr(self, k, m, signed=signed) @vectorize - def round(self, k, m, nearest=False, signed=False): + def round(self, k, m, nearest=False, signed=True): """ Truncate and maybe round secret :py:obj:`k`-bit integer by :py:obj:`m` bits. :py:obj:`m` can be secret if :py:obj:`nearest` is false, in which case the truncation will be @@ -3625,7 +3636,7 @@ class _bitint(Tape._no_truth): return s ^ carry, a ^ (s & (carry ^ a)) @staticmethod - def bit_comparator(a, b): + def bit_comparator(a, b, m=None): long_one = util.long_one(a + b) op = lambda y,x,*args: (util.if_else(x[1], x[0], y[0]), \ util.if_else(x[1], long_one, y[1])) @@ -3794,7 +3805,11 @@ class _bitint(Tape._no_truth): if const_rounds: return self.get_highest_different_bits(a, b, index) else: - return self.bit_comparator(a, b) + try: + return self.maybe_function( + self.bit_comparator, a, b, result_length=2) + except: + return self.bit_comparator(a, b) def __lt__(self, other): if self.reverse_type(other): @@ -3826,9 +3841,17 @@ class _bitint(Tape._no_truth): if self.reverse_type(other): return other == self diff = self ^ other - diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]] - return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y), - diff_bits)) + diff_bits = diff.bit_decompose()[:bit_length] + try: + res = self.maybe_function(self.eqz, diff_bits, [], 1) + except: + res = self.eqz(diff_bits) + return self.comp_result(res[0]) + + @staticmethod + def eqz(bits, other_bits=None, m=None): + diff_bits = [x.bit_not() for x in bits] + return [util.tree_reduce(lambda x, y: x.bit_and(y), diff_bits)] def __ne__(self, other): return (self == other).bit_not() @@ -4052,7 +4075,7 @@ class cfix(_number, _structure): :py:class:`cfix` if the other operand is 
public (cfix/regint/cint/int) or :py:class:`sfix` if the other operand is an sfix. It also support comparisons (``==, !=, <, <=, >, >=``), - returning either :py:class:`regint` or :py:class:`sbitint`. + returning either :py:class:`regint` or :py:class:`sintbit`. Similarly to :py:class:`Compiler.types.cint`, this type is restricted to arithmetic circuits due to the fact that only @@ -4806,7 +4829,7 @@ class _fix(_single): return self._new(self.v[index]) def __iter__(self): - return (self._new(x) for x in self.v) + return (self._new(x, k=self.k, f=self.f) for x in self.v) @vectorize def add(self, other): @@ -4839,7 +4862,8 @@ class _fix(_single): f -= 1 v //= 2 k = len(bin(abs(v))) - 1 - other = self.multipliable(v, k, f, self.size) + val = self.v.TruncMul(v, self.k + f, f, nearest=self.round_nearest) + return self._new(val, k=self.k, f=self.f) try: other = self.coerce(other, equal_precision=False) except: @@ -4983,19 +5007,29 @@ class sfix(_fix): It supports basic arithmetic (``+, -, *, /``), returning :py:class:`sfix`, and comparisons (``==, !=, <, <=, >, >=``), - returning :py:class:`sbitint`. The other operand can be any of + returning :py:class:`sintbit`. The other operand can be any of sfix/sint/cfix/regint/cint/int/float. It also supports ``abs()`` and ``**``. Note that the default precision (16 bits after the dot, 31 bits in total) only allows numbers up to :math:`2^{31-16-1} \\approx - 16000` with the smallest non-zero number being :math:`2^{-16}`. + 16000` with the smallest non-zero number being :math:`2^{-16} + \\approx 0.000015`. You can change this using :py:func:`set_precision`. Fixed-point multiplication is not linear in the sense of the computation domain. Therefore, techniques from :ref:`nonlinear` have to be used. + Many operations (including multiplication and division) use + probabilistic truncation by default. This means that the results + are not deterministic but random within a small range around the + deterministic result. 
You can switch to (more expensive) + deterministic computation by setting + ``sfix.round_nearest`` to true. See `Catrina and de Hoogh + `_ for an introduction to + probabilistic truncation. + :params _v: int/float/regint/cint/sint/sfloat """ int_type = sint @@ -5079,7 +5113,10 @@ class sfix(_fix): return self.v def mul_no_reduce(self, other, res_params=None): - if not isinstance(other, type(self)): + if util.is_constant_float(other): + return self.unreduced( + self.v * cfix.int_rep(other, k=self.k, f=self.f)) + elif not isinstance(other, type(self)): return self * other assert self.f == other.f assert self.k == other.k @@ -6040,7 +6077,11 @@ class Array(_vectorizable): @read_mem_value def get_address(self, index, size=None): if isinstance(index, (_secret, _single)): - raise CompilerError('need cleartext index') + raise CompilerError( + 'Need cleartext index to address Array. If you need to address ' + 'using secret numbers, you need to use ORAM: ' + 'https://mp-spdz.readthedocs.io/en/latest/Compiler.html#' + 'module-Compiler.oram') key = str(index), size or 1 index = self.check(index, self.length, self.length) if (program.curr_block, key) not in self.address_cache: @@ -6486,6 +6527,10 @@ class Array(_vectorizable): M = Matrix(1, len(self), self.value_type, address=self.address) return M.dot(other) + def sum(self): + """ Sum of elements. """ + return self[:].sum() + def shuffle(self): """ Insecure shuffle in place. 
""" self.assign_vector(self.get(regint.inc(len(self)).shuffle())) @@ -7105,11 +7150,6 @@ class SubMultiArray(_vectorizable): res_matrix = Matrix(self.sizes[0], other.sizes[1], t) try: try: - # force matmuls for smaller sizes - a, c = res_matrix.sizes - if a * c / (a + c) < 2 and \ - self.value_type == other.value_type: - raise AttributeError() self.value_type.direct_matrix_mul skip_reduce = set((sint, sfix)) == \ set((self.value_type, other.value_type)) diff --git a/Compiler/util.py b/Compiler/util.py index aa83f063..0620a110 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -1,6 +1,7 @@ import math import operator from functools import reduce +from Compiler.exceptions import * def format_trace(trace, prefix=' '): if trace is None: @@ -91,8 +92,9 @@ def if_else(cond, a, b): else: return cond.if_else(a, b) except: - print(cond, a, b) - raise + raise CompilerError( + 'incompatible types for ternary/if-else operator: %s' % '/'.join( + type(x).__name__ for x in (cond, a, b))) def cond_swap(cond, a, b): if isinstance(cond, (bool, int)): diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp index 226044f3..844d4db7 100644 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -18,7 +18,7 @@ int main() KeySetup> key; string prefix = PREP_DIR "ECDSA/"; mkdir_p(prefix.c_str()); - write_online_setup(prefix, P256Element::Scalar::pr()); + P256Element::Scalar::write_setup(prefix); PRNG G; G.ReSeed(); generate_mac_keys>(key, 2, prefix, G); diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 059861c7..f019b83e 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -60,7 +60,7 @@ void sub(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1) void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1, const FHE_PK& pk) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE if (c0.params!=c1.params) { throw params_mismatch(); } if (ans.params!=c1.params) { throw params_mismatch(); } diff --git a/FHE/Diagonalizer.cpp b/FHE/Diagonalizer.cpp index 
2b39284b..b1a5ff41 100644 --- a/FHE/Diagonalizer.cpp +++ b/FHE/Diagonalizer.cpp @@ -9,7 +9,7 @@ Diagonalizer::Diagonalizer(const MatrixVector& matrices, const FFT_Data& FTD, const FHE_PK& pk) : FTD(FTD) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE assert(not matrices.empty()); for (auto& matrix : matrices) diff --git a/FHE/FFT.cpp b/FHE/FFT.cpp index a3c1c8d9..a00dafd3 100644 --- a/FHE/FFT.cpp +++ b/FHE/FFT.cpp @@ -28,7 +28,7 @@ void NaiveFFT(vector& ans,vector& a,int N,const modp& theta,const Zp void FFT(vector& a,int N,const modp& theta,const Zp_Data& PrD) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE if (N==1) { return; } @@ -141,7 +141,7 @@ void FFT_Iter(vector& ioput, int n, const modp& root, const Zp_Data& PrD, void FFT_Iter(vector& ioput, int n, const vector& roots, const Zp_Data& PrD, bool start_with_one) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE assert(roots.size() > size_t(n)); diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index 4fb37ed4..234b9836 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -56,6 +56,7 @@ class FFT_Data const Zp_Data& get_prD() const { return prData; } const bigint& get_prime() const { return prData.pr; } + int phi_m() const { return R.phi_m(); } int m() const { return R.m(); } int num_slots() const { return R.phi_m(); } @@ -71,6 +72,8 @@ class FFT_Data const Ring& get_R() const { return R; } + void write_setup(const string& dir) const { prData.write_setup(dir); } + bool operator==(const FFT_Data& other) const { return not (*this != other); } bool operator!=(const FFT_Data& other) const; diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 209f022e..cfc29344 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -4,6 +4,7 @@ #include "P2Data.h" #include "FFT_Data.h" #include "Tools/CodeLocations.h" +#include "Processor/OnlineOptions.h" #include "Math/modp.hpp" @@ -66,7 +67,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) void FHE_PK::partial_key_gen(const Rq_Element& sk, const Rq_Element& a, PRNG& G, int 
noise_boost) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE FHE_PK& PK = *this; @@ -154,7 +155,7 @@ void FHE_PK::encrypt(Ciphertext& c, void FHE_PK::quasi_encrypt(Ciphertext& c, const Rq_Element& mess,const Random_Coins& rc) const { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE if (&c.get_params()!=params) { throw params_mismatch(); } if (&rc.get_params()!=params) { throw params_mismatch(); } @@ -216,7 +217,7 @@ void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE if (&c.get_params()!=params) { throw params_mismatch(); } @@ -284,6 +285,14 @@ void FHE_SK::dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_ PRNG G; G.ReSeed(); bigint mask; bigint two_Bd = 2 * Bd; + + bool verbose = OnlineOptions::singleton.has_option("verbose_dd"); + int max_bits = 0; + + if (verbose) + cerr << "Random bits in distributed decryption: " << two_Bd.numBits() + << endl; + for (int i=0; i<(*params).phi_m(); i++) { G.randomBnd(mask, two_Bd); @@ -292,7 +301,13 @@ void FHE_SK::dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_ vv[i] += mask; vv[i] %= mod; if (vv[i]<0) { vv[i]+=mod; } + + if (verbose) + max_bits = max(max_bits, vv[i].numBits()); } + + if (verbose) + cerr << "Maximum bits in distributed decryption: " << max_bits << endl; } diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 5431e7c1..3bf14a7e 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -52,7 +52,7 @@ template <> int generate_semi_setup(int plaintext_length, int sec, FHE_Params& params, FFT_Data& FTD, bool round_up, int n) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE int m = 1024; int lgp = plaintext_length; bigint p; @@ -95,7 +95,7 @@ template <> int generate_semi_setup(int plaintext_length, int sec, FHE_Params& params, P2Data& P2D, bool round_up, int n) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE if (params.n_mults() > 0) throw runtime_error("only implemented for 0-level BGV"); @@ -113,12 
+113,13 @@ int generate_semi_setup(int plaintext_length, int sec, int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up) { -#ifdef VERBOSE - cout << "Need ciphertext modulus of length " << lgp0; - if (params.n_mults() > 0) - cout << "+" << lgp1; - cout << " and " << phi_N(m) << " slots" << endl; -#endif + if (OnlineOptions::singleton.has_option("verbose_he_setup")) + { + cout << "Need ciphertext modulus of length " << lgp0; + if (params.n_mults() > 0) + cout << "+" << lgp1; + cout << " and " << phi_N(m) << " slots" << endl; + } int extra_slack = 0; if (round_up) @@ -160,13 +161,16 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, { (void) lg2pi, (void) n; -#ifdef VERBOSE - if (n >= 2 and n <= 10) - cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2] - << ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl; - cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl; - cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl; -#endif + bool verbose = OnlineOptions::singleton.has_option("verbose_he_setup"); + + if (verbose) + { + if (n >= 2 and n <= 10) + cerr << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2] + << ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl; + cerr << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl; + cerr << "p1 needs " << int(ceil(1. 
* lg2p1 / 64)) << " words" << endl; + } int extra_slack = 0; if (round_up) @@ -185,15 +189,16 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, extra_slack = 2 * i; lg2p0 += i; lg2p1 += i; -#ifdef VERBOSE - cout << "Rounding up to " << lg2p0 << "+" << lg2p1 - << ", giving extra slack of " << extra_slack << " bits" << endl; -#endif + + if (verbose) + cerr << "Rounding up to " << lg2p0 << "+" << lg2p1 + << ", giving extra slack of " << extra_slack << " bits" + << endl; } -#ifdef VERBOSE - cout << "Total length: " << lg2p0 + lg2p1 << endl; -#endif + if (verbose) + cerr << "Total length: " << lg2p0 + lg2p1 << " = " << lg2p0 << " + " + << lg2p1 << endl; return extra_slack; } @@ -305,7 +310,7 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, template <> void Parameters::SPDZ_Data_Setup(FHE_Params& params, FFT_Data& FTD) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE bigint p; int idx, m; @@ -678,7 +683,7 @@ void char_2_dimension(int& m, int& lg2) template <> void Parameters::SPDZ_Data_Setup(FHE_Params& params, P2Data& P2D) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE int n = n_parties; int lg2 = plaintext_length; diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index f4502317..66f68bfe 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -22,9 +22,9 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, sigma *= 1.4; params.set_R(params.get_R() * 1.4); } -#ifdef VERBOSE - cerr << "Standard deviation: " << this->sigma << endl; -#endif + + if (OnlineOptions::singleton.has_option("verbose_he_setup")) + cerr << "Standard deviation: " << this->sigma << endl; produce_epsilon_constants(); @@ -40,24 +40,27 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, B_clean = max(B_clean_not_top_gear, B_clean_top_gear); B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); int matrix_dim = params.get_matrix_dim(); -#ifdef NOISY - cout << "phi(m): " << phi_m << endl; - cout << "p * 
sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; - cout << "V_s: " << V_s << endl; - cout << "c1: " << c1 << endl; - cout << "c2: " << c2 << endl; - cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl; - cout << "log(slack): " << slack << endl; - cout << "B_clean: " << B_clean << endl; - cout << "B_scale: " << B_scale << endl; - cout << "matrix dimension: " << matrix_dim << endl; - cout << "drown sec: " << params.secp() << endl; - cout << "sec: " << sec << endl; -#endif assert(matrix_dim > 0); assert(params.secp() >= 0); drown = 1 + (p > 2 ? matrix_dim : 1) * n * (bigint(1) << params.secp()); + + if (OnlineOptions::singleton.has_option("verbose_he_setup")) + { + cerr << "phi(m): " << phi_m << endl; + cerr << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; + cerr << "V_s: " << V_s << endl; + cerr << "c1: " << c1 << endl; + cerr << "c2: " << c2 << endl; + cerr << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl; + cerr << "log(slack): " << slack << endl; + cerr << "B_clean bits: " << B_clean.numBits() << endl; + cerr << "B_scale bits: " << B_scale.numBits() << endl; + cerr << "matrix dimension: " << matrix_dim << endl; + cerr << "drown sec: " << params.secp() << endl; + cerr << "sec: " << sec << endl; + cerr << "drown bits: " << drown.numBits() << endl; + } } bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1) @@ -118,7 +121,7 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants() NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, const FHE_Params& params) : - SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, params) + SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, sec, params) { B_KS = p * c2 * this->sigma * phi_m / sqrt(12); #ifdef NOISY diff --git a/FHE/P2Data.cpp b/FHE/P2Data.cpp index 345f05ef..c730d397 100644 --- a/FHE/P2Data.cpp +++ b/FHE/P2Data.cpp @@ -9,7 +9,7 @@ void P2Data::forward(vector& ans,const vector& a) const { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE int 
n=gf2n_short::degree(); @@ -32,7 +32,7 @@ void P2Data::forward(vector& ans,const vector& a) const void P2Data::backward(vector& ans,const vector& a) const { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE int n=gf2n_short::degree(); BitVector bv(a.size()); diff --git a/FHE/P2Data.h b/FHE/P2Data.h index adb6f70c..2bd3b36b 100644 --- a/FHE/P2Data.h +++ b/FHE/P2Data.h @@ -32,6 +32,7 @@ class P2Data void backward(vector& ans,const vector& a) const; int get_prime() const { return 2; } + void write_setup(const string&) const {} bool operator!=(const P2Data& other) const; @@ -47,6 +48,7 @@ class P2Data void load_or_generate(const Ring& Rg); + friend void init(P2Data& P2D,const Ring& Rg); }; diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index a2b6d770..4ec89774 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -154,7 +154,7 @@ void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE assert(a.FFTD); if (a.rep!=b.rep) { throw rep_mismatch(); } @@ -299,7 +299,7 @@ Ring_Element& Ring_Element::operator *=(const modp& other) Ring_Element Ring_Element::mul_by_X_i(int j) const { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE assert(FFTD); Ring_Element ans; @@ -517,7 +517,9 @@ modp Ring_Element::get_constant() const void store(octetStream& o,const vector& v,const Zp_Data& ZpD) { ZpD.pack(o); - o.store(v); + o.store((int)v.size()); + for (auto& x : v) + x.pack(o, ZpD); } @@ -529,7 +531,16 @@ void get(octetStream& o,vector& v,const Zp_Data& ZpD) throw runtime_error( "mismatch: " + to_string(check_Zpd.pr_bit_length) + "/" + to_string(ZpD.pr_bit_length)); - o.get(v); + unsigned int length; + o.get(length); + v.clear(); + v.reserve(length); + modp tmp; + for (unsigned int i=0; i Rq_Element::to_vec_bigint() const // result mod p0 = a[0]; result mod p1 = a[1] void Rq_Element::to_vec_bigint(vector& v) const { - CODE_LOCATION + CODE_LOCATION_NO_SCOPE 
a[0].to_vec_bigint(v); if (n_mults() == 0) { @@ -208,7 +208,7 @@ void Rq_Element::Scale(const bigint& p) { if (lev==0) { return; } - CODE_LOCATION + CODE_LOCATION_NO_SCOPE if (n_mults() == 0) { //for some reason we scale but we have just one level @@ -312,7 +312,17 @@ void Rq_Element::pack(octetStream& o, int) const void Rq_Element::unpack(octetStream& o, int) { unsigned int ll; o.get(ll); lev=ll; - check_level(); + + try + { + check_level(); + } + catch (...) + { + lev = 0; + throw; + } + for (int i = 0; i <= lev; ++i) a[i].unpack(o); } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index 0dc8d967..836c075c 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -132,7 +132,7 @@ void PartSetup::output(Names& N) { // Write outputs to file string dir = get_prep_sub_dir>(N.num_players()); - write_online_setup(dir, FieldD.get_prime()); + FieldD.write_setup(dir); write_mac_key(dir, N.my_num(), N.num_players(), alphai); } diff --git a/FHEOffline/DistDecrypt.cpp b/FHEOffline/DistDecrypt.cpp index 57cae9be..ebc90d47 100644 --- a/FHEOffline/DistDecrypt.cpp +++ b/FHEOffline/DistDecrypt.cpp @@ -11,7 +11,7 @@ DistDecrypt::DistDecrypt(const Player& P, const FHE_SK& share, vv.resize(pk.get_params().phi_m()); vv1.resize(pk.get_params().phi_m()); // extra limb for operations - bigint limit = pk.get_params().Q() << 64; + bigint limit = pk.get_params().p0() << 64; vv.allocate_slots(limit); vv1.allocate_slots(limit); mf.allocate_slots(pk.p() << 64); @@ -19,6 +19,8 @@ DistDecrypt::DistDecrypt(const Player& P, const FHE_SK& share, class ModuloTreeSum : public TreeSum { + typedef TreeSum super; + bigint modulo; void post_add_process(vector& values) @@ -32,6 +34,12 @@ public: modulo(modulo) { } + + void run(vector& values, const Player& P) + { + lengths.resize(values.size(), numBytes(modulo)); + super::run(values, P); + } }; template @@ -48,13 +56,15 @@ Plaintext_& DistDecrypt::run(const Ciphertext& ctx, bool NewCiphertext) if ((int)vv.size() != 
params.phi_m()) throw length_error("wrong length of ring element"); + size_t length = numBytes(pk.get_params().p0()); + if (OnlineOptions::singleton.direct) { // Now pack into an octetStream for broadcasting vector os(P.num_players()); for (int i=0; i& DistDecrypt::run(const Ciphertext& ctx, bool NewCiphertext) { for (int j = 0; j < params.phi_m(); j++) { - os[i].get(vv1[j]); + os[i].get(vv1[j], length); } share.dist_decrypt_2(vv, vv1); } diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index dd3f8968..be8d86b5 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -38,7 +38,7 @@ void RealPairwiseMachine::init() gfp::init_field(p); ofstream outf; if (output) - write_online_setup(get_prep_dir(P), p); + gfp::write_setup(get_prep_dir(P)); } for (int i = 0; i < nthreads; i++) @@ -141,5 +141,10 @@ void PairwiseMachine::check(Player& P) const bundle.compare(P); } +int PairwiseMachine::comp_sec() +{ + return NonInteractiveProof::comp_sec(sec); +} + template void RealPairwiseMachine::setup_keys(); template void RealPairwiseMachine::setup_keys(); diff --git a/FHEOffline/PairwiseMachine.h b/FHEOffline/PairwiseMachine.h index a8a0c649..dae5e051 100644 --- a/FHEOffline/PairwiseMachine.h +++ b/FHEOffline/PairwiseMachine.h @@ -31,6 +31,8 @@ public: void unpack(octetStream& os); void check(Player& P) const; + + int comp_sec(); }; class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index fdacf267..655a9dc2 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -71,6 +71,7 @@ void secure_init(T& setup, Player& P, U& machine, string filename = PREP_DIR + T::name() + "-" + to_string(plaintext_length) + "-" + to_string(sec) + "-" + to_string(params.secp()) + "-" + + to_string(machine.comp_sec()) + "-" + to_string(params.get_matrix_dim()) + "-" + OnlineOptions::singleton.prime.get_str() + "-" + 
to_string(CowGearOptions::singleton.top_gear()) + "-P" @@ -121,7 +122,7 @@ void secure_init(T& setup, Player& P, U& machine, os.output(file); } - if (OnlineOptions::singleton.verbose) + if (OnlineOptions::singleton.has_option("verbose_he")) { cerr << "Ciphertext length: " << params.p0().numBits(); for (size_t i = 1; i < params.FFTD().size(); i++) @@ -131,6 +132,7 @@ void secure_init(T& setup, Player& P, U& machine, cerr << "+" << DIV_CEIL(params.FFTD()[i].get_prime().numBits(), 64); cerr << " limbs)"; cerr << endl; + cerr << "Number of slots: " << params.phi_m() << endl; } } diff --git a/FHEOffline/Proof.cpp b/FHEOffline/Proof.cpp index c6592720..7b1ac5d3 100644 --- a/FHEOffline/Proof.cpp +++ b/FHEOffline/Proof.cpp @@ -184,3 +184,11 @@ void Proof::Preimages::check_sizes() if (m.size() != r.size()) throw runtime_error("preimage sizes don't match"); } + +int NonInteractiveProof::comp_sec(int sec) +{ + if (sec > 0) + return OnlineOptions::singleton.comp_sec(); + else + return 0; +} diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h index 2eec0435..f854bf65 100644 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -9,6 +9,7 @@ using namespace std; #include "FHE/Ciphertext.h" #include "FHE/AddableVector.h" #include "Protocols/CowGearOptions.h" +#include "Processor/OnlineOptions.h" #include "config.h" @@ -90,9 +91,6 @@ class Proof { V = ceil((sec + 2) / log2(2 * phim + 1)); U = 2 * V; -#ifdef VERBOSE - cerr << "Using " << U << " ciphertexts per proof" << endl; -#endif } else { @@ -151,14 +149,24 @@ class Proof output += input.at(j); } } + + void debugging() + { + if (OnlineOptions::singleton.has_option("verbose_he")) + { + cerr << "Using " << U << " ciphertexts per proof" << endl; + cerr << "Plaintext bound check bit length: " << B_plain_length << endl; + cerr << "Randomness bound check bit length: " << B_rand_length << endl; + } + } }; class NonInteractiveProof : public Proof { - // sec = 0 used for protocols without proofs - static int comp_sec(int sec) { return 
sec > 0 ? max(COMP_SEC, sec) : 0; } - public: + // sec = 0 used for protocols without proofs + static int comp_sec(int sec); + bigint static slack(int sec, int phim) { sec = comp_sec(sec); return bigint(phim * sec * sec) << (sec / 2 + 8); } @@ -174,6 +182,7 @@ public: B_rand_length = numBits(B*3*phim*rho); plain_check = (bigint(1) << B_plain_length) - sec * tau; rand_check = (bigint(1) << B_rand_length) - sec * rho; + debugging(); } }; @@ -194,6 +203,7 @@ public: // leeway for completeness plain_check = (bigint(2) << B_plain_length); rand_check = (bigint(2) << B_rand_length); + debugging(); } }; diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp index c161f1d7..0423dd96 100644 --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -161,6 +161,7 @@ size_t NonInteractiveProofSimpleEncCommit::generate_proof(AddableVector void SimpleEncCommit::create_more() { + CODE_LOCATION cout << "Generating more ciphertexts in round " << this->n_rounds << endl; octetStream ciphertexts, cleartexts; size_t prover_memory = this->generate_proof(this->c, this->m, ciphertexts, cleartexts); @@ -181,6 +182,7 @@ template size_t NonInteractiveProofSimpleEncCommit::create_more(octetStream& ciphertexts, octetStream& cleartexts) { + CODE_LOCATION AddableVector others_ciphertexts; others_ciphertexts.resize(proof.U, pk.get_params()); for (int i = 1; i < P.num_players(); i++) @@ -244,6 +246,7 @@ SummingEncCommit::SummingEncCommit(const Player& P, const FHE_PK& pk, template void SummingEncCommit::create_more() { + CODE_LOCATION octetStream cleartexts; const Player& P = this->P; AddableVector commitments; @@ -267,10 +270,11 @@ void SummingEncCommit::create_more() this->c.unpack(ciphertexts, this->pk); commitments.unpack(ciphertexts, this->pk); -#ifdef VERBOSE_HE - cout << "Tree-wise sum of ciphertexts with " - << 1e-9 * ciphertexts.get_length() << " GB" << endl; -#endif + if (OnlineOptions::singleton.has_option("verbose_he")) + cerr << "Tree-wise sum 
of " << this->c.size() + << " ciphertexts with " << 1e-9 * ciphertexts.get_length() + << " GB" << endl; + this->timers["Exchanging ciphertexts"].start(); tree_sum.run(this->c, P); tree_sum.run(commitments, P); diff --git a/FHEOffline/SimpleMachine.h b/FHEOffline/SimpleMachine.h index 715b027c..218fec10 100644 --- a/FHEOffline/SimpleMachine.h +++ b/FHEOffline/SimpleMachine.h @@ -56,6 +56,9 @@ public: void unpack(octetStream&) {} void check(Player&) const {} + + // computational security doesn't matter in global proofs + int comp_sec() { return 0; } }; class MultiplicativeMachineParams : public MachineBase diff --git a/GC/CcdShare.h b/GC/CcdShare.h index 913b49ab..18894c4f 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -78,6 +78,9 @@ public: } }; +template +const int CcdShare::default_length; + } #endif /* GC_CCDSHARE_H_ */ diff --git a/GC/MaliciousCcdShare.h b/GC/MaliciousCcdShare.h index 6d3729a9..5544591e 100644 --- a/GC/MaliciousCcdShare.h +++ b/GC/MaliciousCcdShare.h @@ -88,6 +88,9 @@ public: } }; +template +const int MaliciousCcdShare::default_length; + } /* namespace GC */ #endif /* GC_MALICIOUSCCDSHARE_H_ */ diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 3ae8d051..bab2f3ab 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -286,7 +286,12 @@ void Processor::notcb(const ::BaseInstruction& instruction) template void Processor::movsb(const ::BaseInstruction& instruction) { - for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++) + int n_blocks; + if (instruction.get_n() < unsigned(T::default_length)) + n_blocks = 1; + else + n_blocks = DIV_CEIL(instruction.get_n(), T::default_length); + for (int i = 0; i < n_blocks; i++) S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i]; } @@ -407,12 +412,14 @@ void Processor::convcbitvec(const BaseInstruction& instruction, { auto proto = ShareThread::s().protocol; auto P = ShareThread::s().P; - if (proto) + // The default use case in the compiler doesn't require synchronization + // 
with function-dependent protocols, but testing does. + if (proto and OnlineOptions::singleton.has_option("convcbitvec_sync")) proto->sync(bits, *P); else - throw exception(); + throw no_singleton(); } - catch (exception&) + catch (no_singleton&) { if (P) ProtocolBase::sync(bits, *P); diff --git a/GC/SemiSecret.h b/GC/SemiSecret.h index 1ae7a0f8..9642c353 100644 --- a/GC/SemiSecret.h +++ b/GC/SemiSecret.h @@ -147,6 +147,9 @@ public: } }; +template +const int SemiSecretBase::default_length; + } /* namespace GC */ #endif /* GC_SEMISECRET_H_ */ diff --git a/GC/SemiSecret.hpp b/GC/SemiSecret.hpp index 5334ac16..5f3a8b1b 100644 --- a/GC/SemiSecret.hpp +++ b/GC/SemiSecret.hpp @@ -16,9 +16,6 @@ namespace GC { -template -const int SemiSecretBase::default_length; - inline SemiSecret::MC* SemiSecret::new_mc( typename super::mac_key_type) diff --git a/GC/ShareThread.h b/GC/ShareThread.h index c24f62ba..9bc093ca 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -67,7 +67,7 @@ inline ShareThread& ShareThread::s() if (singleton and T::is_real) return *singleton; else - throw runtime_error("no ShareThread singleton"); + throw no_singleton("no ShareThread singleton"); } } /* namespace GC */ diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index c57a1206..53fd2026 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -120,6 +120,11 @@ public: return {S, left, right, n_full_blocks()}; } + Range> full_block_left_range(StackedVector& S) + { + return {S, left, n_full_blocks()}; + } + DoubleIterator partial_block(StackedVector& S) { assert(n_blocks() != n_full_blocks()); @@ -127,6 +132,17 @@ public: S.iterator_for_size(right + n_full_blocks(), 1)}; } + typename CheckVector::iterator partial_left_block(StackedVector& S) + { + assert(n_blocks() != n_full_blocks()); + return S.iterator_for_size(left + n_full_blocks(), 1); + } + + T& get_right_base(StackedVector& S) + { + return S[right]; + } + Range> full_block_output_range(StackedVector& S) { return {S, dest, 
n_full_blocks()}; @@ -174,16 +190,17 @@ void ShareThread::and_(Processor& processor, for (auto info : infos) { int n = T::default_length; - for (auto x : info.full_block_input_range(S)) + auto& y = info.get_right_base(S); + for (auto x : info.full_block_left_range(S)) { - x.second.extend_bit(y_ext, n); - protocol->prepare_mult(x.first, y_ext, n, true); + y.extend_bit(y_ext, n); + protocol->prepare_mult(x, y_ext, n, true); } n = info.last_length(); if (n) { - info.partial_block(S).left->mask(x_ext, n); - info.partial_block(S).right->extend_bit(y_ext, n); + info.partial_left_block(S)->mask(x_ext, n); + y.extend_bit(y_ext, n); protocol->prepare_mult(x_ext, y_ext, n, true); } } @@ -193,7 +210,7 @@ void ShareThread::and_(Processor& processor, if (fast_mode) for (auto x : info.full_block_input_range(S)) protocol->prepare_mul_fast(x.first, x.second); - else + else if (info.n_full_blocks()) for (auto x : info.full_block_input_range(S)) protocol->prepare_mul(x.first, x.second); int n = info.last_length(); @@ -228,7 +245,7 @@ void ShareThread::and_(Processor& processor, if (fast_mode) for (auto& res : info.full_block_output_range(S)) res = protocol->finalize_mul_fast(); - else + else if (info.n_full_blocks()) for (auto& res : info.full_block_output_range(S)) res = protocol->finalize_mul(); diff --git a/GC/Thread.h b/GC/Thread.h index b510120e..c52d4b71 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -56,6 +56,8 @@ public: void join_tape(); void finish(); + + virtual NamedCommStats extra_comm() { return {}; } }; template diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index d657dc90..a3627199 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -115,6 +115,7 @@ void ThreadMaster::run_with_error() for (auto thread : threads) { stats += thread->P->total_comm(); + stats += thread->extra_comm(); exe_stats += thread->processor.stats; delete thread; } diff --git a/GC/TinierShare.h b/GC/TinierShare.h index 18e24bd6..56278177 100644 --- a/GC/TinierShare.h +++ 
b/GC/TinierShare.h @@ -145,6 +145,9 @@ public: } }; +template +const int TinierShare::default_length; + } /* namespace GC */ #endif /* GC_TINIERSHARE_H_ */ diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index bdc79350..cb5798bd 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -33,7 +33,7 @@ void TinierSharePrep::buffer_secret_triples() assert(triple_generator != 0); params.generateBits = false; vector> triples; - TripleShuffleSacrifice sacrifice(DATA_GF2); + TripleShuffleSacrifice sacrifice; size_t required; required = sacrifice.minimum_n_inputs_with_combining( BaseMachine::batch_size(DATA_TRIPLE)); diff --git a/Machines/mama-party.cpp b/Machines/mama-party.cpp index 1b206891..6d77d32d 100644 --- a/Machines/mama-party.cpp +++ b/Machines/mama-party.cpp @@ -38,7 +38,10 @@ int main(int argc, const char** argv) if (online_opts.prime_limbs() == 2) return run<2, 1>(machine); - cerr << "Not compiled for choice of parameters" << endl; - cerr << "Try using '-lgp 128'" << endl; + if (online_opts.prime_limbs() > 2) + cerr << "Use MASCOT with large primes" << endl; + else + cerr << "Not compiled for choice of parameters" << endl; + exit(1); } diff --git a/Machines/spdz2k-party.cpp b/Machines/spdz2k-party.cpp index b02ff73a..73d88ecd 100644 --- a/Machines/spdz2k-party.cpp +++ b/Machines/spdz2k-party.cpp @@ -55,7 +55,7 @@ int main(int argc, const char** argv) { if (s == SPDZ2K_DEFAULT_SECURITY) { - ring_domain_error(k); + ring_domain_error(k, 72); } else { diff --git a/Makefile b/Makefile index e7e1e70f..35e5a061 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ TOOLS = $(patsubst %.cpp,%.o,$(wildcard Tools/*.cpp)) NETWORK = $(patsubst %.cpp,%.o,$(wildcard Networking/*.cpp)) -PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) Protocols/ShamirOptions.o +PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp)) Protocols/ShamirOptions.o Protocols/ShareInterface.o FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) 
Protocols/CowGearOptions.o diff --git a/Math/Integer.h b/Math/Integer.h index 3d9e3c07..cac73672 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -103,7 +103,7 @@ public: void randomize(PRNG& G); void almost_randomize(PRNG& G) { randomize(G); } - void output(ostream& s,bool human) const; + void output(ostream& s, bool human, bool signed_ = true) const; void input(istream& s,bool human); void pack(octetStream& os) const { os.store_int(a, sizeof(a)); } diff --git a/Math/Integer.hpp b/Math/Integer.hpp index 12d42ae4..e13c8565 100644 --- a/Math/Integer.hpp +++ b/Math/Integer.hpp @@ -15,7 +15,7 @@ inline void IntBase::specification(octetStream& os) } template -void IntBase::output(ostream& s,bool human) const +void IntBase::output(ostream& s, bool human, bool) const { if (human) s << a; diff --git a/Math/Setup.cpp b/Math/Setup.cpp index 1a69156e..5f0f0df3 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -85,15 +85,17 @@ void generate_prime(bigint& p, int lgp, int m, bool force_degree) p = OnlineOptions::singleton.prime; if (!probPrime(p)) { - cerr << p << " is not a prime" << endl; - exit(1); + throw runtime_error(to_string(p) + " is not a prime"); } else if (m != 1 and p % m != 1) { - cerr << p - << " is not compatible with our encryption scheme, must be 1 modulo " - << m << endl; - exit(1); + throw runtime_error( + to_string(p) + + " is not compatible with our encryption scheme, must be " + "1 modulo " + to_string(m) + ". This is because " + "the implementation relies on number theoretic transform. 
" + "See https://eprint.iacr.org/2024/585.pdf for details, " + "in particular Theorem 13."); } else return; @@ -125,8 +127,10 @@ void generate_prime(bigint& p, int lgp, int m, bool force_degree) } -void write_online_setup(string dirname, const bigint& p) +void Zp_Data::write_setup(const string& dirname) const { + auto& p = pr; + if (p == 0) throw runtime_error("prime cannot be 0"); @@ -145,19 +149,23 @@ void write_online_setup(string dirname, const bigint& p) ofstream outf; outf.open(ss.str().c_str()); outf << p << endl; + outf << montgomery << endl; if (!outf.good()) throw file_error("cannot write to " + ss.str()); } -void check_setup(string dir, bigint pr) +void Zp_Data::check_setup(const string& dir) { bigint p; + bool mont = true; string filename = dir + "Params-Data"; - ifstream(filename) >> p; + ifstream(filename) >> p >> mont; if (p == 0) throw setup_error("no modulus in " + filename); if (p != pr) throw setup_error("wrong modulus in " + filename); + if (mont != montgomery) + throw setup_error("Montgomery different in " + filename); } string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, diff --git a/Math/Setup.h b/Math/Setup.h index 4004093e..dc1af45a 100644 --- a/Math/Setup.h +++ b/Math/Setup.h @@ -26,8 +26,6 @@ template void generate_prime_setup(string dir, int lgp); template void generate_online_setup(string dirname, bigint& p, int lgp); -void write_online_setup(string dirname, const bigint& p); -void check_setup(string dirname, bigint p); // Setup primes only // Chooses a p of at least lgp bits diff --git a/Math/Setup.hpp b/Math/Setup.hpp index 91cafaea..8d0b9354 100644 --- a/Math/Setup.hpp +++ b/Math/Setup.hpp @@ -13,14 +13,15 @@ void generate_online_setup(string dirname, bigint& p, int lgp) { int idx, m; SPDZ_Data_Setup_Primes(p, lgp, idx, m); - write_online_setup(dirname, p); T::init_field(p); + T::write_setup(dirname); } template void read_setup(const string& dir_prefix, int lgp = -1) { bigint p; + bool montgomery = true; string 
filename = dir_prefix + "Params-Data"; @@ -32,6 +33,8 @@ void read_setup(const string& dir_prefix, int lgp = -1) #endif ifstream inpf(filename.c_str()); inpf >> p; + inpf >> montgomery; + if (inpf.fail()) { if (lgp > 0) @@ -45,9 +48,12 @@ void read_setup(const string& dir_prefix, int lgp = -1) throw file_error(filename.c_str()); } else - T::init_field(p); + T::init_field(p, montgomery); inpf.close(); + + if (OnlineOptions::singleton.verbose) + cerr << "Using prime modulus " << T::pr() << endl; } #endif /* MATH_SETUP_HPP_ */ diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 73807c09..28f88b33 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -230,3 +230,8 @@ void Zp_Data::get_shanks_parameters(bigint& y, bigint& q_half, int& r) const q_half = shanks_q_half; r = shanks_r; } + +string Zp_Data::fake_opts() const +{ + return "-P " + to_string(pr) + (montgomery ? "" : " -n"); +} diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index ac2f52e4..bb4fce7d 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -21,7 +21,7 @@ using namespace std; #ifndef MAX_MOD_SZ #if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 11 - #define MAX_MOD_SZ GFP_MOD_SZ + #define MAX_MOD_SZ 2 * GFP_MOD_SZ #else #define MAX_MOD_SZ 11 #endif @@ -94,6 +94,11 @@ class Zp_Data void get_shanks_parameters(bigint& y, bigint& q_half, int& r) const; + void write_setup(const string& directory) const; + void check_setup(const string& directory); + + string fake_opts() const; + template friend void to_modp(modp_& ans,int x,const Zp_Data& ZpD); template friend void to_modp(modp_& ans,const mpz_class& x,const Zp_Data& ZpD); diff --git a/Math/bigint.cpp b/Math/bigint.cpp index f42a8e67..f6232375 100644 --- a/Math/bigint.cpp +++ b/Math/bigint.cpp @@ -45,7 +45,6 @@ int powerMod(int x,int e,int p) return ans; } - size_t bigint::report_size(ReportType type) const { size_t res = 0; @@ -98,6 +97,16 @@ bigint::bigint(const mp_limb_t* data, size_t n_limbs) mpz_import(get_mpz_t(), n_limbs, -1, 8, -1, 0, data); } +void 
bigint::pack(octetStream& os, int length) const +{ + os.store(*this, length); +} + +void bigint::unpack(octetStream& os, int length) +{ + os.get(*this, length); +} + string to_string(const bigint& x) { stringstream ss; diff --git a/Math/bigint.h b/Math/bigint.h index c2a7f8e3..b6fb8056 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -134,8 +134,8 @@ public: void generateUniform(PRNG& G, int n_bits, bool positive = false) { G.get(*this, n_bits, positive); } - void pack(octetStream& os, int = -1) const { os.store(*this); } - void unpack(octetStream& os, int = -1) { os.get(*this); }; + void pack(octetStream& os, int = -1) const; + void unpack(octetStream& os, int = -1); size_t report_size(ReportType type) const; }; diff --git a/Math/bigint.hpp b/Math/bigint.hpp index a35124a6..21c11a6e 100644 --- a/Math/bigint.hpp +++ b/Math/bigint.hpp @@ -31,6 +31,9 @@ mpf_class bigint::get_float(T v, T p, T z, T s) Integer exp = Integer(p, 31).get(); bigint tmp; tmp.from_signed(v); + if (abs(tmp) == 1) + BaseMachine::s().mini_warning = min(BaseMachine::s().mini_warning, + int(exp.get())); mpf_class res = tmp; if (exp > 0) mpf_mul_2exp(res.get_mpf_t(), res.get_mpf_t(), exp.get()); diff --git a/Math/fixint.h b/Math/fixint.h index 33a8d80b..2d11342f 100644 --- a/Math/fixint.h +++ b/Math/fixint.h @@ -15,6 +15,7 @@ class fixint : public SignedZ2<64 * (L + 1)> public: typedef SignedZ2<64 * (L + 1)> super; + typedef typename conditional>::type pack_type; fixint() { @@ -56,6 +57,19 @@ public: *this = bigint::tmp; } + void pack(octetStream& os) const + { + pack_type tmp = *this; + tmp.pack(os); + } + + void unpack(octetStream& os) + { + pack_type tmp; + tmp.unpack(os); + *this = tmp; + } + int get_min_alloc() const { return this->N_BYTES; diff --git a/Math/gfp.h b/Math/gfp.h index 7c946f39..fef46ccf 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -36,8 +36,8 @@ template void generate_prime_setup(string, int, int); #define GFP_MOD_SZ 2 #endif -#if GFP_MOD_SZ > MAX_MOD_SZ -#error GFP_MOD_SZ must 
be at most MAX_MOD_SZ +#if 2 * GFP_MOD_SZ > MAX_MOD_SZ +#error 2 * GFP_MOD_SZ must be at most MAX_MOD_SZ #endif /** @@ -105,9 +105,9 @@ class gfp_ : public ValueInterface static void write_setup(int nplayers) { write_setup(get_prep_sub_dir(nplayers)); } static void write_setup(string dir) - { write_online_setup(dir, pr()); } + { ZpD.write_setup(dir); } static void check_setup(string dir); - static string fake_opts() { return " -P " + to_string(pr()); } + static string fake_opts() { return ZpD.fake_opts(); } /** * Get the prime modulus diff --git a/Math/gfp.hpp b/Math/gfp.hpp index 8b7d002d..bed3bb8a 100644 --- a/Math/gfp.hpp +++ b/Math/gfp.hpp @@ -28,7 +28,7 @@ inline void gfp_::read_or_generate_setup(string dir, template void gfp_::check_setup(string dir) { - ::check_setup(dir, pr()); + ZpD.check_setup(dir); } template @@ -201,7 +201,7 @@ bool gfp_::allows(Dtype type) template void gfp_::specification(octetStream& os) { - os.store(pr()); + ZpD.pack(os); } template diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index 39216ef4..1cc60662 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -33,7 +33,7 @@ char gfpvar_::type_char() template void gfpvar_::specification(octetStream& os) { - os.store(pr()); + ZpD.pack(os); } template @@ -101,13 +101,13 @@ const bigint& gfpvar_::pr() template void gfpvar_::check_setup(string dir) { - ::check_setup(dir, pr()); + ZpD.check_setup(dir); } template void gfpvar_::write_setup(string dir) { - write_online_setup(dir, pr()); + ZpD.write_setup(dir); } template diff --git a/Math/modp.hpp b/Math/modp.hpp index 57d0747f..6ca4932a 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -332,7 +332,7 @@ void modp_::output(ostream& s, const Zp_Data& ZpD, bool human, bool signed_) if (human) { bigint te; to_bigint(te, ZpD); - if (te < ZpD.pr / 2 or not signed_) + if (te <= ZpD.pr_half or not signed_) s << te; else s << (te - ZpD.pr); diff --git a/Networking/AllButLastPlayer.h b/Networking/AllButLastPlayer.h index 3d6d1834..679a0d73 100644 --- 
a/Networking/AllButLastPlayer.h +++ b/Networking/AllButLastPlayer.h @@ -24,16 +24,26 @@ public: delete N; } - void send_to_no_stats(int player, const octetStream& o) const + void send_to(int player, const octetStream& o) const { P.send_to(player, o); } - void receive_player_no_stats(int i, octetStream& o) const + void receive_player(int i, octetStream& o) const { P.receive_player(i, o); } + void send_to_no_stats(int, const octetStream&) const + { + throw not_implemented(); + } + + void receive_player_no_stats(int, octetStream&) const + { + throw not_implemented(); + } + void send_receive_all_no_stats(const vector>& channels, const vector& to_send, vector& to_receive) const diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 575f60cf..1d17221a 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -367,10 +367,7 @@ long MultiPlayer::receive_long(int i) const void Player::send_to(int player,const octetStream& o) const { -#ifdef VERBOSE_COMM - cerr << "sending to " << player << endl; -#endif - TimeScope ts(comm_stats["Sending directly"].add(o)); + TimeScope ts(comm_stats["Sending directly"].add(o, player)); send_to_no_stats(player, o); sent += o.get_length(); } @@ -405,12 +402,9 @@ void Player::receive_all(vector& os) const void Player::receive_player(int i,octetStream& o) const { -#ifdef VERBOSE_COMM - cerr << "receiving from " << i << endl; -#endif TimeScope ts(timer); receive_player_no_stats(i, o); - comm_stats["Receiving directly"].add(o, ts); + comm_stats["Receiving directly"].add(o, ts, i); } template @@ -484,10 +478,7 @@ void MultiPlayer::exchange_no_stats(int other, const octetStream& o, octetStr void Player::exchange(int other, const octetStream& o, octetStream& to_receive) const { -#ifdef VERBOSE_COMM - cerr << "Exchanging with " << other << endl; -#endif - TimeScope ts(comm_stats["Exchanging"].add(o)); + TimeScope ts(comm_stats["Exchanging"].add(o, other)); exchange_no_stats(other, o, to_receive); sent += o.get_length(); } @@ -603,9 
+594,8 @@ void Player::send_receive_all(const vector>& channels, if (i != my_num() and channels.at(my_num()).at(i)) { data += to_send.at(i).get_length(); -#ifdef VERBOSE_COMM - cerr << "Send " << to_send.at(i).get_length() << " to " << i << endl; -#endif + if (OnlineOptions::singleton.has_option("detailed_verbose_comm")) + cerr << "Send " << to_send.at(i).get_length() << " bytes to " << i << endl; } TimeScope ts(comm_stats["Sending/receiving"].add(data)); sent += data; @@ -879,15 +869,22 @@ Timer& CommStatsWithName::add_length_only(size_t length) return stats.add_length_only(length); } -Timer& CommStatsWithName::add(const octetStream& os) +Timer& CommStatsWithName::add(const octetStream& os, int player) { - return add(os.get_length()); + return add(os.get_length(), player); } -Timer& CommStatsWithName::add(size_t length) +Timer& CommStatsWithName::add(size_t length, int player) { if (OnlineOptions::singleton.has_option("verbose_comm")) - fprintf(stderr, "%s %zu bytes\n", name.c_str(), length); + { + if (player < 0) + fprintf(stderr, "%s %zu bytes\n", name.c_str(), length); + else + fprintf(stderr, "%s %zu bytes with party %d\n", name.c_str(), length, + player); + } + return stats.add(length); } diff --git a/Networking/Player.h b/Networking/Player.h index 6ea8a3ec..94ab5079 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -160,9 +160,10 @@ public: name(name), stats(stats) {} Timer& add_length_only(size_t length); - Timer& add(const octetStream& os); - Timer& add(size_t length); - void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; } + Timer& add(const octetStream& os, int player = -1); + Timer& add(size_t length, int player = -1); + void add(const octetStream& os, const TimeScope& scope, int player = -1) + { add(os, player) += scope; } }; class NamedCommStats : public map @@ -272,7 +273,7 @@ public: /** * Send to a specific player */ - void send_to(int player,const octetStream& o) const; + virtual void send_to(int player,const 
octetStream& o) const; virtual void send_to_no_stats(int player,const octetStream& o) const = 0; /** * Receive from all other players. @@ -282,7 +283,7 @@ public: /** * Receive from a specific player */ - void receive_player(int i,octetStream& o) const; + virtual void receive_player(int i,octetStream& o) const; virtual void receive_player_no_stats(int i,octetStream& o) const = 0; virtual void receive_player(int i,FlexBuffer& buffer) const; @@ -546,6 +547,8 @@ public: size_t send(const PlayerBuffer& buffer, bool block) const; size_t recv(const PlayerBuffer& buffer, bool block) const; + + NamedCommStats get_comm_stats() const { return comm_stats; } }; class RealTwoPartyPlayer : public VirtualTwoPartyPlayer diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 039963ce..31e7643a 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -243,16 +243,23 @@ void NPartyTripleGenerator::generateInputs(int player) CODE_LOCATION typedef typename W::input_type::share_type::open_type T; - auto nTriplesPerLoop = this->nTriplesPerLoop * 10; + auto nTriplesPerLoop = this->nTriplesPerLoop; auto& valueBits = this->valueBits; auto& share_prg = this->share_prg; auto& ot_multipliers = this->ot_multipliers; auto& nparties = this->nparties; auto& globalPlayer = this->globalPlayer; + if (this->thread_num >= 0) + nTriplesPerLoop *= 10; + // extra value for sacrifice int toCheck = nTriplesPerLoop + DIV_CEIL(W::mac_key_type::size_in_bits(), T::size_in_bits()); + + if (OnlineOptions::singleton.has_option("verbose_input")) + fprintf(stderr, "generating %d input tuples\n", toCheck); + valueBits.resize(1); this->signal_multipliers({player, toCheck}); bool mine = player == globalPlayer.my_num(); diff --git a/OT/OTExtensionWithMatrix.cpp b/OT/OTExtensionWithMatrix.cpp index a8ebf571..604274cc 100644 --- a/OT/OTExtensionWithMatrix.cpp +++ b/OT/OTExtensionWithMatrix.cpp @@ -77,6 +77,22 @@ void OTExtensionWithMatrix::protocol_agreement() if 
(OnlineOptions::singleton.has_option("high_softspoken")) softspoken_k = 8; + if (OnlineOptions::singleton.has_param("softspoken")) + softspoken_k = OnlineOptions::singleton.get_param("softspoken"); + + int needed = DIV_CEIL(nbaseOTs, softspoken_k) * softspoken_k; + + baseReceiverInput.resize_zero(needed); + + for (int i = nbaseOTs; i < needed; i++) + { + auto zero = string(SEED_SIZE, '\0'); + G_receiver.push_back(zero); + G_sender.push_back({}); + for (int j = 0; j < 2; j++) + G_sender.back().push_back(zero); + } + bundle.mine.store(softspoken_k); player->unchecked_broadcast(bundle); @@ -177,7 +193,8 @@ void OTExtensionWithMatrix::soft_sender(size_t n) return; if (OnlineOptions::singleton.has_option("verbose_ot")) - fprintf(stderr, "%zu OTs as sender\n", n); + fprintf(stderr, "%zu OTs as sender (%s)\n", n, + passive_only ? "semi-honest" : "malicious"); osuCrypto::PRNG prng(osuCrypto::sysRandomSeed()); osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(softspoken_k); diff --git a/OT/OTVole.hpp b/OT/OTVole.hpp index c7bfb59e..e057bcdf 100644 --- a/OT/OTVole.hpp +++ b/OT/OTVole.hpp @@ -9,6 +9,10 @@ template void OTVoleBase::evaluate(vector& output, const vector& newReceiverInput) { CODE_LOCATION + if (OnlineOptions::singleton.has_option("verbose_vole")) + fprintf(stderr, "%d-bit VOLE with %zu elements and S=%d\n", T::N_BITS, + newReceiverInput.size(), S); + const int N1 = newReceiverInput.size() + 1; output.resize(newReceiverInput.size()); auto& os = oss; diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index 67582159..4968f3bc 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -13,6 +13,7 @@ #include #include +#include using namespace std; BaseMachine* BaseMachine::singleton = 0; @@ -66,16 +67,21 @@ int BaseMachine::triple_bucket_size(DataFieldType type) int BaseMachine::bucket_size(size_t usage) { int res = OnlineOptions::singleton.bucket_size; + int min = res; if (usage) { - for (int B = res; B <= 5; B++) - if 
(ShuffleSacrifice(B).minimum_n_outputs() < usage * .9) + res = 5; + for (int B = res; B >= min; B--) + if (ShuffleSacrifice(B).minimum_n_outputs() > usage * 1.1) break; else res = B; } + if (OnlineOptions::singleton.has_option("debug_batch_size")) + fprintf(stderr, "bucket_size=%d usage=%zu\n", res, usage); + return res; } @@ -103,8 +109,13 @@ int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols) return -1; } +bool BaseMachine::allow_mulm() +{ + return singleton and singleton->relevant_opts.find("no_mulm") != string::npos; +} + BaseMachine::BaseMachine() : - nthreads(0), multithread(false), nan_warning(0) + nthreads(0), multithread(false), nan_warning(0), mini_warning(0) { if (sodium_init() == -1) throw runtime_error("couldn't initialize libsodium"); @@ -182,6 +193,7 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode) getline(inpf, relevant_opts); getline(inpf, security); getline(inpf, gf2n); + getline(inpf, expected_communication); inpf.close(); } @@ -320,17 +332,47 @@ void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats) Bundle bundle(P); bundle.mine.store(stats.sent); P.Broadcast_Receive_no_stats(bundle); - size_t global = 0; + long long global = 0; for (auto& os : bundle) global += os.get_int(8); cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl; + + smatch what; + regex comm_regexp("online:([0-9]*) offline:([0-9]*) n_parties:([0-9]*)"); + if (regex_search(expected_communication, what, comm_regexp)) + { + long long expected = stoll(what[1]) + stoll(what[2]); + int n_parties = stoi(what[3]); + if (expected and n_parties != P.num_players()) + { + cerr << "Wrong number of parties in compiler's expectation: " + << n_parties << endl; + } + else if (expected) + { + double over = round(100. * (global - expected) / expected); + if (over >= 5) + cerr + << "Actual communication exceeds the compiler's expectation by " + << over << " percent." 
<< endl; + if (over < 0) + { + if (OnlineOptions::singleton.has_option("overestimate")) + cerr << "Actual communication is below the compiler's " + "expectation by " << -over << " percent." << endl; + else + cerr << "The compiler overestimated the communication." << endl; + } + } + } } void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats) { size_t rounds = 0; for (auto& x : comm_stats) - rounds += x.second.rounds; + if (x.first.find("transmission") == string::npos) + rounds += x.second.rounds; cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds << " rounds (party " << P.my_num() << " only"; if (multithread) @@ -341,3 +383,9 @@ void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats) print_global_comm(P, comm_stats); } + +void BaseMachine::add_one_off(const NamedCommStats& comm) +{ + if (has_singleton()) + s().one_off_comm += comm; +} diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 3e1b6e34..1703bde8 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -23,6 +23,7 @@ void print_usage(ostream& o, const char* name, size_t capacity); class BaseMachine { friend class Program; + template friend class thread_info; protected: static BaseMachine* singleton; @@ -38,6 +39,9 @@ protected: string relevant_opts; string security; string gf2n; + string expected_communication; + + NamedCommStats one_off_comm; virtual size_t load_program(const string& threadname, const string& filename); @@ -60,6 +64,7 @@ public: vector progs; bool nan_warning; + int mini_warning; static BaseMachine& s(); static bool has_singleton() { return singleton != 0; } @@ -75,7 +80,8 @@ public: static int security_from_schedule(string progname); template - static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0); + static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0, + int factor = 0); template static int input_batch_size(int player, int buffer_size = 0); template @@ -86,6 +92,10 
@@ public: static int matrix_batch_size(int n_rows, int n_inner, int n_cols); static int matrix_requirement(int n_rows, int n_inner, int n_cols); + static bool allow_mulm(); + + static void add_one_off(const NamedCommStats& comm); + BaseMachine(); virtual ~BaseMachine() {} @@ -110,6 +120,8 @@ public: void print_comm(Player& P, const NamedCommStats& stats); virtual const Names& get_N() { throw not_implemented(); } + + virtual void gap_warning(int) { throw not_implemented(); } }; inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) @@ -118,7 +130,8 @@ inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) } template -int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) +int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback, + int factor) { if (OnlineOptions::singleton.has_option("debug_batch_size")) fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size, @@ -133,7 +146,8 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) else if (fallback > 0) n_opts = fallback; else - n_opts = OnlineOptions::singleton.batch_size * T::default_length; + n_opts = OnlineOptions::singleton.batch_size + * max(factor, T::default_length); if (buffer_size <= 0 and has_program()) { @@ -180,9 +194,17 @@ int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback) res = n_opts; if (OnlineOptions::singleton.has_option("debug_batch_size")) + { cerr << DataPositions::dtype_names[type] << " " << T::type_string() << " res=" << res << " n=" << n << " n_opts=" << n_opts - << " buffer_size=" << buffer_size << endl; + << " buffer_size=" << buffer_size << " bits/dabits=" + << T::LivePrep::bits_from_dabits() << "/" + << T::LivePrep::dabits_from_bits() << " has_program=" + << has_program(); + if (program) + cerr << " program=" << program->get_name(); + cerr << endl; + } assert(res > 0); return res; diff --git a/Processor/EdabitBuffer.h b/Processor/EdabitBuffer.h index d6506b75..2507365f 100644 --- 
a/Processor/EdabitBuffer.h +++ b/Processor/EdabitBuffer.h @@ -34,7 +34,9 @@ public: + ", have you generated edaBits, " "for example by running " "'./Fake-Offline.x -e " - + to_string(n_bits) + " ...'?"); + + to_string(n_bits) + + T::template proto_fake_opts() + + " ...'?"); } assert(BufferBase::file); diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index 11aeef6c..3bb39e2f 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -11,6 +11,7 @@ #include "GC/instructions.h" #include "Memory.hpp" +#include "Instruction.hpp" #include diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index bfc5c3b6..c772239a 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -1501,8 +1501,6 @@ void Program::execute_with_errors(Processor& Proc) const auto& processor = Proc.Procb; auto& Ci = Proc.get_Ci(); - BaseMachine::program = this; - while (Proc.PC& Proc) const template void Program::mulm_check() const { - if (T::function_dependent and not OnlineOptions::singleton.has_option("allow_mulm")) + if (T::function_dependent + and not (BaseMachine::allow_mulm() + or OnlineOptions::singleton.has_option("allow_mulm"))) throw runtime_error("Mixed multiplication not implemented for function-dependent preprocessing. 
" "Use '-E ' during compilation or state " "'program.use_mulm = False' at the beginning of your high-level program."); diff --git a/Processor/Machine.h b/Processor/Machine.h index 05fb2090..705e30d7 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -54,6 +54,9 @@ class Machine : public BaseMachine NamedCommStats max_comm; + int max_trunc_size; + Lock warn_lock; + size_t load_program(const string& threadname, const string& filename); void prepare(const string& progname_str); @@ -126,6 +129,8 @@ class Machine : public BaseMachine Player& get_player() { return *P; } void check_program(); + + void gap_warning(int k); }; #endif /* MACHINE_H_ */ diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index afe13517..5fe80b50 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -55,6 +55,7 @@ template Machine::Machine(Names& playerNames, bool use_encryption, const OnlineOptions opts) : my_number(playerNames.my_num()), N(playerNames), + max_trunc_size(0), use_encryption(use_encryption), live_prep(opts.live_prep), opts(opts), external_clients(my_number) { @@ -607,6 +608,12 @@ void Machine::run(const string& progname) if (multithread) cerr << " (overall core time)"; cerr << endl; + auto& P = *this->P; + auto one_off = TreeSum>().run( + this->one_off_comm.sent, P).get_limb(0); + if (one_off) + cerr << "One-off global communication: " << one_off * 1e-6 << " MB" + << endl; } print_timers(); @@ -685,12 +692,17 @@ void Machine::run(const string& progname) << "have you considered using " << alt << " instead?" << endl; } - if (nan_warning and sint::real_shares(*P)) + if ((nan_warning or mini_warning) and sint::real_shares(*P)) { - cerr << "Outputs of 'NaN' might be related to exceeding the sfix range. See "; - cerr << "https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix"; + if (nan_warning) + cerr << "Outputs of 'NaN' might be related to exceeding the sfix range. 
"; + if (mini_warning) + cerr << pow(2, mini_warning) << " is the smallest non-zero number " + << "in a used fixed-point representation. "; + cerr << "See https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sfix"; cerr << " for details" << endl; nan_warning = false; + mini_warning = 0; } #ifdef VERBOSE @@ -743,6 +755,10 @@ void Machine::suggest_optimizations() cerr << "This program might benefit from some protocol options." << endl << "Consider adding the following at the beginning of your code:" << endl << optimizations; + if (sint::clear::n_bits() < max_trunc_size) + cerr << "The computation domain is too small " + << "for low-round truncation; it would need to have at least " + << max_trunc_size << " bits." << endl; #ifndef __clang__ cerr << "This virtual machine was compiled with GCC. Recompile with " "'CXX = clang++' in 'CONFIG.mine' for optimal performance." << endl; @@ -768,4 +784,15 @@ void Machine::check_program() } } +template +void Machine::gap_warning(int k) +{ + if (k > max_trunc_size) + { + warn_lock.lock(); + max_trunc_size = max(k, max_trunc_size); + warn_lock.unlock(); + } +} + #endif diff --git a/Processor/Memory.h b/Processor/Memory.h index 0bfefdef..6853a298 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -37,6 +37,16 @@ public: #endif } + const T* begin() const + { + return data(); + } + + const T* end() const + { + return data() + size(); + } + virtual T& operator[](size_t i) = 0; virtual const T& operator[](size_t i) const = 0; diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 908fa370..0b9f42ad 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -268,7 +268,7 @@ void thread_info::Sub_Main_Func() printf("\tClient %d about to run %d\n",num,program); #endif online_timer.start(P.total_comm()); - online_prep_timer -= Proc.DataF.total_time(); + online_prep_timer -= Proc.prep_time(); Proc.reset(progs[program], job.arg); // Bits, Triples, Squares, and Inverses 
skipping @@ -278,6 +278,7 @@ void thread_info::Sub_Main_Func() //printf("\tExecuting program"); // Execute the program + BaseMachine::program = &progs[program]; progs[program].execute(Proc); // make sure values used in other threads are safe @@ -298,7 +299,7 @@ void thread_info::Sub_Main_Func() "in thread %d\n", program, num); #endif online_timer.stop(P.total_comm()); - online_prep_timer += Proc.DataF.total_time(); + online_prep_timer += Proc.prep_time(); wait_timer.start(); queues->finished(job, P.total_comm()); wait_timer.stop(); @@ -307,10 +308,10 @@ void thread_info::Sub_Main_Func() // final check online_timer.start(P.total_comm()); - online_prep_timer -= Proc.DataF.total_time(); + online_prep_timer -= Proc.prep_time(); Proc.check(); online_timer.stop(P.total_comm()); - online_prep_timer += Proc.DataF.total_time(); + online_prep_timer += Proc.prep_time(); if (machine.opts.file_prep_per_thread) Proc.DataF.prune(); diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 1c9ce446..dc9a4c42 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -9,10 +9,12 @@ #include "Math/gfpvar.h" #include "Protocols/HemiOptions.h" #include "Protocols/config.h" +#include "FHEOffline/config.h" #include "Math/gfp.hpp" #include +#include using namespace std; @@ -40,6 +42,8 @@ OnlineOptions::OnlineOptions() : playerno(-1) max_broadcast = 0; receive_threads = false; code_locations = false; + have_warned_about_comp_sec = false; + semi_honest = false; #ifdef VERBOSE verbose = true; #else @@ -161,6 +165,10 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, opt.get("--options")->getStrings(options); + for (auto& option : options) + if (option.find("verbose") == 0) + verbose = true; + code_locations = opt.isSet("--code-locations"); #ifdef THROW_EXCEPTIONS @@ -463,6 +471,7 @@ void OnlineOptions::finalize_with_error(ez::ezOptionParser& opt) o->getString(disk_memory); receive_threads = opt.isSet("--threads"); + semi_honest = 
opt.isSet("--semi-honest"); if (use_security_parameter) { @@ -505,3 +514,35 @@ int OnlineOptions::prime_limbs() { return DIV_CEIL(prime_length(), 64); } + +bool OnlineOptions::has_param(const string& param) +{ + for (auto& x : options) + if (x.find(param + "=") == 0) + return true; + return false; +} + +int OnlineOptions::get_param(const string& param) +{ + basic_regex re(param + "=([0-9]+)"); + smatch match; + for (auto& x : options) + if (regex_match(x, match, re)) + return atoi(match[1].str().c_str()); + throw runtime_error("parameter not found: " + param); +} + +int OnlineOptions::comp_sec() +{ + int res = COMP_SEC; + if (has_param("comp_sec")) + res = get_param("comp_sec"); + if (res < 128 and not have_warned_about_comp_sec) + { + cerr << "WARNING: computational security parameter " << res + << " suitable for testing only" << endl; + have_warned_about_comp_sec = true; + } + return res; +} diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 1193d1e5..009c2b94 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -15,6 +15,8 @@ class OnlineOptions { void finalize_with_error(ez::ezOptionParser& opt); + bool have_warned_about_comp_sec; + public: static OnlineOptions singleton; @@ -44,6 +46,7 @@ public: vector options; string executable; bool code_locations; + bool semi_honest; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, @@ -79,6 +82,11 @@ public: { return find(options.begin(), options.end(), option) != options.end(); } + + bool has_param(const string& param); + int get_param(const string& param); + + int comp_sec(); }; #endif /* PROCESSOR_ONLINEOPTIONS_H_ */ diff --git a/Processor/OnlineOptions.hpp b/Processor/OnlineOptions.hpp index 324c7507..cd6f8ff0 100644 --- a/Processor/OnlineOptions.hpp +++ b/Processor/OnlineOptions.hpp @@ -97,6 +97,18 @@ OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, "-N", // Flag token. "--nparties" // Flag token. 
); + + if (T::semi_honest_option) + opt.add( + "", // Default. + 0, // Required? + 0, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Semi-honest operation (default: malicious security)", + // Help description. + "-sh", // Flag token. + "--semi-honest" // Flag token. + ); } template diff --git a/Processor/Processor.h b/Processor/Processor.h index 32299fa3..7c999011 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -310,6 +310,8 @@ class Processor : public ArithmeticProcessor void call_tape(int tape_number, int arg, const vector& results); + TimerWithComm prep_time(); + private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 52de2936..2a3ed03e 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -638,9 +638,6 @@ void SubProcessor::matmulsm(const MemoryPart& source, int batchStartI = 0; int batchStartJ = 0; - size_t sourceSize = source.size(); - const T* sourceData = source.data(); - protocol.init_dotprod(); for (auto matmulArgs = start.begin(); matmulArgs < start.end(); matmulArgs += 12) { auto output = S.begin() + matmulArgs[0]; @@ -654,27 +651,54 @@ void SubProcessor::matmulsm(const MemoryPart& source, assert(output + resultNumberOfRows * resultNumberOfColumns <= S.end()); + for (int j = 0; j < resultNumberOfColumns; j += 1) { + auto actualSecondFactorColumn = + Proc->get_Ci().at(matmulArgs[9] + j).get(); + auto secondBase = source.begin() + secondFactorBase + + actualSecondFactorColumn; + for (auto &x : Range(Proc->get_Ci(), matmulArgs[8], + usedNumberOfFirstFactorColumns)) + assert( + secondBase + x.get() * secondFactorTotalNumberOfColumns + < source.end()); + } + + vector second_factors; + second_factors.reserve(usedNumberOfFirstFactorColumns); + + for (auto& x : Range(Proc->get_Ci(), matmulArgs[8], + usedNumberOfFirstFactorColumns)) + second_factors.push_back(x.get() * secondFactorTotalNumberOfColumns); + for (int i = 0; i < resultNumberOfRows; i += 
1) { auto actualFirstFactorRow = Proc->get_Ci().at(matmulArgs[6] + i).get(); + auto firstBase = source.begin() + firstFactorBase + + actualFirstFactorRow * firstFactorTotalNumberOfColumns; + + for (auto& x : Range(Proc->get_Ci(), matmulArgs[7], + usedNumberOfFirstFactorColumns)) + assert(firstBase + x.get() < source.end()); for (int j = 0; j < resultNumberOfColumns; j += 1) { auto actualSecondFactorColumn = Proc->get_Ci().at(matmulArgs[9] + j).get(); + auto secondBase = source.begin() + secondFactorBase + + actualSecondFactorColumn; #ifdef MATMULSM_DEBUG cout << "Preparing " << i << "," << j << "(buffer size: " << protocol.get_buffer_size() << ")" << endl; #endif - for (int k = 0; k < usedNumberOfFirstFactorColumns; k += 1) { - auto actualFirstFactorColumn = Proc->get_Ci().at(matmulArgs[7] + k).get(); - auto actualSecondFactorRow = Proc->get_Ci().at(matmulArgs[8] + k).get(); + auto second_it = second_factors.begin(); - auto firstAddress = firstFactorBase + actualFirstFactorRow * firstFactorTotalNumberOfColumns + actualFirstFactorColumn; - auto secondAddress = secondFactorBase + actualSecondFactorRow * secondFactorTotalNumberOfColumns + actualSecondFactorColumn; + for (auto& x : Range(Proc->get_Ci(), matmulArgs[7], + usedNumberOfFirstFactorColumns)) + { + auto actualFirstFactorColumn = x.get(); - assert(firstAddress < sourceSize); - assert(secondAddress < sourceSize); + auto first = firstBase + actualFirstFactorColumn; + auto second = secondBase + *second_it++; - protocol.prepare_dotprod(sourceData[firstAddress], sourceData[secondAddress]); + protocol.prepare_dotprod(*first, *second); } protocol.next_dotprod(); @@ -905,9 +929,19 @@ void Conv2dTuple::post(StackedVector& S, typename T::Protocol& protocol) template void SubProcessor::secure_shuffle(const Instruction& instruction) { - typename T::Protocol::Shuffler(S, instruction.get_size(), - instruction.get_n(), instruction.get_r(0), instruction.get_r(1), - *this); + size_t n = instruction.get_size(); + size_t 
unit_size = instruction.get_n(); + size_t output_base = instruction.get_r(0); + size_t input_base = instruction.get_r(1); + + typename T::Protocol::Shuffler shuffler(*this); + + typename T::Protocol::Shuffler::shuffle_type shuffle; + shuffler.generate(n / unit_size, shuffle); + + vector> shuffles{ShuffleTuple(n, output_base, + input_base, unit_size, shuffle, true)}; + shuffler.apply_multiple(S, shuffles); maybe_check(); } @@ -916,7 +950,10 @@ template size_t SubProcessor::generate_secure_shuffle(const Instruction& instruction, ShuffleStore& shuffle_store) { - return shuffler.generate(instruction.get_n(), shuffle_store); + size_t n = instruction.get_n(); + auto res = shuffle_store.add(n); + shuffler.generate(n, shuffle_store.get(res).second); + return res; } template @@ -926,21 +963,18 @@ void SubProcessor::apply_shuffle(const Instruction& instruction, const auto& args = instruction.get_start(); const auto n_shuffles = args.size() / 6; - vector sizes(n_shuffles, 0); - vector destinations(n_shuffles, 0); - vector sources(n_shuffles, 0); - vector unit_sizes(n_shuffles, 0); - vector shuffles(n_shuffles, 0); - vector reverse(n_shuffles, false); - for (size_t i = 0; i < n_shuffles; i++) { - sizes[i] = args[6 * i]; - destinations[i] = args[6 * i + 1]; - sources[i] = args[6 * i + 2]; - unit_sizes[i] = args[6 * i + 3]; - shuffles[i] = Proc->read_Ci(args[6 * i + 4]); - reverse[i] = args[6 * i + 5]; + vector> shuffles; + + for (size_t i = 0; i < n_shuffles; i++) + { + shuffles.push_back( + ShuffleTuple(args[6 * i], args[6 * i + 1], args[6 * i + 2], + args[6 * i + 3], + shuffle_store.get(Proc->read_Ci(args[6 * i + 4])), + bool(args[6 * i + 5]))); } - shuffler.apply_multiple(S, sizes, destinations, sources, unit_sizes, shuffles, reverse, shuffle_store); + + shuffler.apply_multiple(S, shuffles); maybe_check(); } @@ -1184,4 +1218,13 @@ void Processor::call_tape(int tape_number, int arg, arg_stack.pop_back(); } +template +TimerWithComm Processor::prep_time() +{ + auto res = 
DataF.total_time(); + res += Procp.protocol.prep_time(); + res += Proc2.protocol.prep_time(); + return res; +} + #endif diff --git a/Processor/Program.h b/Processor/Program.h index 39245882..b2d3b924 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -42,6 +42,8 @@ class Program size_t size() const { return p.size(); } + string get_name() const { return name; } + // Read in a program void parse(string filename); void parse_with_error(string filename); diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index 26a50000..cd9ebdb4 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -31,10 +31,24 @@ HonestMajorityRingMachine::HonestMajorityRingMachine(int argc, const char* RingMachine(argc, argv, opt, online_opts, nplayers); } -inline void ring_domain_error(int R) +inline void ring_domain_error(int R, int max) { - cerr << "not compiled for " << R << "-bit computation, " << endl; - cerr << "compile with -DRING_SIZE=" << R << endl; + cerr << "The virtual machine is not compiled for " << R + << "-bit computation." << endl; + cerr << "Compile with 'MY_CFLAGS += -DRING_SIZE=" << R + << "' in 'CONFIG.mine'"; + (void) max; +#ifndef FEWER_RINGS + for (int r = 0; r <= max; r += 64) + { + if (r >= R) + { + cerr << " or try " << "'-R " << r << "'"; + break; + } + } +#endif + cerr << "." 
<< endl; exit(1); } @@ -60,7 +74,7 @@ RingMachine::RingMachine(int argc, const char** argv, X(RING_SIZE) #endif #undef X - ring_domain_error(R); + ring_domain_error(R, 192); } template class U, template class V> @@ -98,7 +112,7 @@ HonestMajorityRingMachineWithSecurity::HonestMajorityRingMachineWithSecuri X(72) X(128) #endif #undef X - ring_domain_error(R); + ring_domain_error(R, 128); } #endif /* PROCESSOR_RINGMACHINE_HPP_ */ diff --git a/Processor/TruncPrTuple.h b/Processor/TruncPrTuple.h index 80d4617c..21df097e 100644 --- a/Processor/TruncPrTuple.h +++ b/Processor/TruncPrTuple.h @@ -11,6 +11,7 @@ using namespace std; #include "OnlineOptions.h" +#include "BaseMachine.h" #include "GC/ArgTuples.h" template class StackedVector; @@ -106,9 +107,12 @@ public: TruncPrTupleWithGap(vector::const_iterator it) : TruncPrTuple(it) { - big_gap_ = this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error; + int min_size = this->k + OnlineOptions::singleton.trunc_error; + big_gap_ = min_size <= T::n_bits(); if (T::prime_field and small_gap()) throw runtime_error("domain too small for chosen truncation error"); + if (small_gap() and BaseMachine::has_singleton()) + BaseMachine::s().gap_warning(min_size); } T upper(T mask) diff --git a/Processor/instructions.h b/Processor/instructions.h index 7b99c743..279417d4 100644 --- a/Processor/instructions.h +++ b/Processor/instructions.h @@ -202,7 +202,7 @@ *dest++ = *op1++ > *op2++) \ X(EQC, auto dest = &Ci[r[0]]; auto op1 = &Ci[r[1]]; auto op2 = &Ci[r[2]], \ *dest++ = *op1++ == *op2++) \ - X(PRINTINT, Proc.out << Proc.read_Ci(r[0]) << flush,) \ + X(PRINTINT, print(Proc.out, &Proc.get_Ci_ref(r[0])),) \ X(PRINTFLOATPREC, Proc.out << setprecision(n),) \ X(PRINTSTR, Proc.out << string((char*)&n,4) << flush,) \ X(PRINTCHR, Proc.out << string((char*)&n,1) << flush,) \ diff --git a/Programs/Source/and-bench.py b/Programs/Source/and-bench.py new file mode 100644 index 00000000..b87ba714 --- /dev/null +++ b/Programs/Source/and-bench.py @@ -0,0 
+1,5 @@ +a = sbits.get_type(int(program.args[1]))(0) + +@for_range(int(program.args[2])) +def _(i): + a & a diff --git a/Programs/Source/combo-bench.py b/Programs/Source/combo-bench.py new file mode 100644 index 00000000..623f8f31 --- /dev/null +++ b/Programs/Source/combo-bench.py @@ -0,0 +1,11 @@ +import math + +n = int(program.args[1]) +n_sqrt = int(math.sqrt(n)) + +sfix.Matrix(n_sqrt, 10) * sfix.Matrix(10, n_sqrt) +(sfix(0, size=n) < 0).store_in_mem(0) + +sint.Array(n).secure_shuffle() + +sint(personal(0, cint(0, size=n))) diff --git a/Programs/Source/comp-bench.py b/Programs/Source/comp-bench.py new file mode 100644 index 00000000..02e5abbf --- /dev/null +++ b/Programs/Source/comp-bench.py @@ -0,0 +1,15 @@ +#sfix.set_precision(32, 63) +#program.use_trunc_pr = True +#program.use_split(3) +program.options_from_args() +sfix.set_precision_from_args(program) +try: + n_loops = int(program.args[2]) +except (IndexError, ValueError): + n_loops = 1 + +a = sfix(cint(0, size=int(program.args[1]))) + +@for_range(n_loops) +def _(i): + (a < a)#.store_in_mem(0) diff --git a/Programs/Source/fdiv-bench.py b/Programs/Source/fdiv-bench.py new file mode 100644 index 00000000..ec050bbe --- /dev/null +++ b/Programs/Source/fdiv-bench.py @@ -0,0 +1,10 @@ +program.options_from_args() +sfix.set_precision_from_args(program) + +n = int(program.args[1]) +m = int(program.args[2]) +a = sfix(0, size=n) + +@for_range(m) +def _(i): + (a / a).store_in_mem(0) diff --git a/Programs/Source/fmul-bench.py b/Programs/Source/fmul-bench.py new file mode 100644 index 00000000..6bfdc9af --- /dev/null +++ b/Programs/Source/fmul-bench.py @@ -0,0 +1,14 @@ +program.options_from_args() +sfix.set_precision_from_args(program) + +try: + n = int(program.args[1]) +except (IndexError, ValueError): + n = 10 ** 6 + +m = int(program.args[2]) +a = sfix(0, size=n) + +@for_range(m) +def _(i): + (a * a).store_in_mem(0) diff --git a/Programs/Source/input-bench.py b/Programs/Source/input-bench.py new file mode 100644 index 00000000..771ea3bd --- /dev/null +++ 
b/Programs/Source/input-bench.py @@ -0,0 +1,5 @@ +x = personal(0, cint(0, size=int(program.args[1]))) + +@for_range(int(program.args[2])) +def _(i): + sint(x).store_in_mem(0) diff --git a/Programs/Source/matmul-bench.py b/Programs/Source/matmul-bench.py new file mode 100644 index 00000000..3c458f3b --- /dev/null +++ b/Programs/Source/matmul-bench.py @@ -0,0 +1,15 @@ +n = int(program.args[1]) +try: + m = int(program.args[3]) +except (IndexError, ValueError): + m = n +try: + k = int(program.args[4]) +except (IndexError, ValueError): + k = n +A = sint.Matrix(n, m) +B = sint.Matrix(m, k) + +@for_range(int(program.args[2])) +def _(i): + A * B diff --git a/Programs/Source/mul-bench.py b/Programs/Source/mul-bench.py new file mode 100644 index 00000000..28a46dcc --- /dev/null +++ b/Programs/Source/mul-bench.py @@ -0,0 +1,6 @@ +x = sint(0, size=int(program.args[1])) + +m = int(program.args[2]) +@for_range(m) +def _(i): + (x * x)#.store_in_mem(0) diff --git a/Programs/Source/open-bench.py b/Programs/Source/open-bench.py new file mode 100644 index 00000000..3386ebc8 --- /dev/null +++ b/Programs/Source/open-bench.py @@ -0,0 +1,5 @@ +x = sint(0, size=int(program.args[1])) + +@for_range(int(program.args[2])) +def _(i): + x.reveal().store_in_mem(0) diff --git a/Programs/Source/random-bench.py b/Programs/Source/random-bench.py new file mode 100644 index 00000000..be3bcfb4 --- /dev/null +++ b/Programs/Source/random-bench.py @@ -0,0 +1,3 @@ +@for_range(int(program.args[1])) +def _(i): + sint.get_random().store_in_mem(0) diff --git a/Programs/Source/shuffle-bench.py b/Programs/Source/shuffle-bench.py new file mode 100644 index 00000000..17e8b536 --- /dev/null +++ b/Programs/Source/shuffle-bench.py @@ -0,0 +1,14 @@ +n_apply = 1 +if len(program.args) > 3: + n_apply = int(program.args[3]) + +@for_range(int(program.args[2])) +def _(i): + print_ln('%s', i) + handle = sint.get_secure_shuffle(int(program.args[1])) + + @for_range(n_apply) + def _(i): + sint.Array(int(program.args[1])).secure_permute(handle) + +print_ln('bye') diff --git 
a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index 1f9b8fa3..ae54dcbb 100644 --- a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -21,9 +21,14 @@ test(sbit(1) * sbits(3), 3) test(sbit(1) * 3, 3) test(~sbits.new(1, n=64), 2**64 - 2) test(sbits(5) & sbits(3), 5 & 3) +test(sbits(5) | sbits(3), 5 | 3) test(sbits(3).equal(sbits(3)), 1) test(sbits(3).equal(sbits(2)), 0) +test(sbits(3) == sbits(3), 1) +test(sbits(3) == sbits(2), 0) +test(sbits(3) != sbits(3), 0) +test(sbits(3) != sbits(2), 1) test(sbit(1).if_else(sbits(3), sbits(5)), 3) test(sbits(7) << 1, 14) test(cbits(5) >> 1, 2) diff --git a/Programs/Source/trunc-bench.py b/Programs/Source/trunc-bench.py new file mode 100644 index 00000000..17c851c6 --- /dev/null +++ b/Programs/Source/trunc-bench.py @@ -0,0 +1,16 @@ +program.options_from_args() +sfix.set_precision_from_args(program) + +try: + n = int(program.args[1]) +except (IndexError, ValueError): + n = 10 ** 6 + +m = int(program.args[2]) + +x = sint(0, size=n) + +@for_range(m) +def _(i): + x.round(sfix.k + sfix.f, sfix.f, nearest=sfix.round_nearest, + signed=True) diff --git a/Protocols/Astra.hpp b/Protocols/Astra.hpp index 91b79ee0..8ad90b0a 100644 --- a/Protocols/Astra.hpp +++ b/Protocols/Astra.hpp @@ -88,7 +88,17 @@ AstraPrepProtocol::~AstraPrepProtocol() template void AstraOnlineBase::init_prep() { - open_with_check(prep, this->get_filename(false)); + try + { + open_with_check(prep, this->get_filename(false)); + } + catch (...) + { + throw runtime_error( + "Error with preprocessing in " + this->get_filename(false) + + ". You need to run the preprocessing before " + "the online phase or in parallel."); + } } template @@ -435,7 +445,17 @@ void AstraPrepProtocol::sync(vector& values, Player& P) if (P.my_num() == 1) { if (not outputs.is_open()) - open_with_check(outputs, this->get_output_filename()); + { + try + { + open_with_check(outputs, this->get_output_filename()); + } + catch (...) + { + throw runtime_error("Error with output back channel. 
" + "This only works when preprocessing is run in parallel."); + } + } Timer timer; TimeScope ts(timer); diff --git a/Protocols/AstraMC.hpp b/Protocols/AstraMC.hpp index 3f1a2c52..ff0df643 100644 --- a/Protocols/AstraMC.hpp +++ b/Protocols/AstraMC.hpp @@ -19,7 +19,7 @@ typename T::open_type AstraMC::prepare_summand(const T& secret, int my_num) template void AstraMC::exchange(const Player& P) { - SemiMC> opener; + DirectSemiMC> opener; opener.init_open(P, this->secrets.size()); int my_num = P.my_num() + 1; diff --git a/Protocols/Atlas.h b/Protocols/Atlas.h index 29f9a27c..80723c69 100644 --- a/Protocols/Atlas.h +++ b/Protocols/Atlas.h @@ -43,6 +43,7 @@ public: shamir(P), shamir2(P, 2 * ShamirMachine::s().threshold), oss(P), oss2(P), next_king(0), base_king(0), resharing(0, P), P(P) { + this->buffer_size = 0; } ~Atlas(); diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp index 184bfc4b..35ebcc7c 100644 --- a/Protocols/Atlas.hpp +++ b/Protocols/Atlas.hpp @@ -7,6 +7,7 @@ #define PROTOCOLS_ATLAS_HPP_ #include "Atlas.h" +#include "BufferScope.h" template Atlas::~Atlas() @@ -24,6 +25,8 @@ array Atlas::get_double_sharing() { SeededPRNG G; PRNG G2 = G; + BufferScope scope(shamir, this->buffer_size); + BufferScope scope2(shamir2, this->buffer_size); auto random = shamir.get_randoms(G, 0); auto random2 = shamir2.get_randoms(G2, 0); assert(random.size() == random2.size()); @@ -44,6 +47,8 @@ void Atlas::init_mul() oss2.reset(); masks.clear(); base_king = next_king; + this->buffer_size = BaseMachine::batch_size(DATA_TRIPLE, + this->buffer_size); } template @@ -132,6 +137,7 @@ T Atlas::finalize_dotprod(int) template T Atlas::get_random() { + BufferScope scope(shamir, this->buffer_size); return shamir.get_random(); } diff --git a/Protocols/BrainShare.h b/Protocols/BrainShare.h index 5f6757cd..3f572624 100644 --- a/Protocols/BrainShare.h +++ b/Protocols/BrainShare.h @@ -53,10 +53,6 @@ public: { FixedVec::operator=(other); } - template - BrainShare(const U& other, int my_num, T 
alphai = {}) : super(other, my_num, alphai) - { - } }; #endif /* PROTOCOLS_BRAINSHARE_H_ */ diff --git a/Protocols/ChaiGearPrep.hpp b/Protocols/ChaiGearPrep.hpp index a35a07a5..667915f8 100644 --- a/Protocols/ChaiGearPrep.hpp +++ b/Protocols/ChaiGearPrep.hpp @@ -105,6 +105,7 @@ typename ChaiGearPrep::Generator& ChaiGearPrep::get_generator() if (machine == 0) basic_setup(P); key_setup(P, proc->MC.get_alphai()); + BaseMachine::add_one_off(P.total_comm()); } lock.unlock(); if (generator == 0) diff --git a/Protocols/CowGearPrep.hpp b/Protocols/CowGearPrep.hpp index 0b4e3733..e6e73322 100644 --- a/Protocols/CowGearPrep.hpp +++ b/Protocols/CowGearPrep.hpp @@ -107,6 +107,7 @@ PairwiseGenerator& CowGearPrep::get_generator() setup(P, proc->MC.get_alphai()); else key_setup(P, proc->MC.get_alphai()); + BaseMachine::add_one_off(P.total_comm()); } lock.unlock(); if (pairwise_generator == 0) diff --git a/Protocols/DealerInput.h b/Protocols/DealerInput.h index 48f04538..9b427725 100644 --- a/Protocols/DealerInput.h +++ b/Protocols/DealerInput.h @@ -13,12 +13,17 @@ template class DealerInput : public InputBase { Player& P; - octetStreams to_send, to_receive; SeededPRNG G; - vector> shares; bool from_dealer; AllButLastPlayer sub_player; SemiInput>* internal; + int king; + vector dealer_prngs; + vector dealer_random_prngs; + PRNG non_dealer_prng, non_dealer_random_prng; + octetStream os; + IteratorVector shares, random_shares; + int my_num; public: DealerInput(SubProcessor& proc, typename T::MAC_Check&); @@ -28,13 +33,26 @@ public: DealerInput(SubProcessor*, Player& P); ~DealerInput(); + int dealer_player(); bool is_dealer(int player = -1); + bool is_king(); void reset(int player); void add_mine(const typename T::open_type& input, int n_bits = -1); void add_other(int player, int n_bits = -1); + void add_from_dealer(const typename T::open_type& input); + void add_from_dealer(const vector& input); + void add_n_from_dealer(size_t n_inputs, bool random = false); + typename T::open_type 
random_for_dealer(); void exchange(); T finalize(int player, int n_bits = -1); + T finalize_from_dealer(); + template + T finalize_no_check(); + template + array finalize_no_check(); + T finalize_random(); + void require(size_t n_inputs); }; #endif /* PROTOCOLS_DEALERINPUT_H_ */ diff --git a/Protocols/DealerInput.hpp b/Protocols/DealerInput.hpp index b4f67a66..d14859d0 100644 --- a/Protocols/DealerInput.hpp +++ b/Protocols/DealerInput.hpp @@ -30,13 +30,13 @@ DealerInput::DealerInput(Player& P) : template DealerInput::DealerInput(SubProcessor* proc, Player& P) : InputBase(proc), - P(P), to_send(P), shares(P.num_players()), from_dealer(false), - sub_player(P) + P(P), from_dealer(false), sub_player(P), king(0) { if (is_dealer()) internal = 0; else internal = new SemiInput>(0, sub_player); + my_num = P.my_num(); } template @@ -46,26 +46,69 @@ DealerInput::~DealerInput() delete internal; } +template +int DealerInput::dealer_player() +{ + return P.num_players() - 1; +} + template bool DealerInput::is_dealer(int player) { - int dealer_player = P.num_players() - 1; + int dealer_player = this->dealer_player(); if (player == -1) return P.my_num() == dealer_player; else return player == dealer_player; } +template +bool DealerInput::is_king() +{ + assert(not is_dealer()); + return king == P.my_num(); +} + template void DealerInput::reset(int player) { - if (player == 0) + if (is_dealer(player)) { - to_send.reset(P); + octetStreams to_send, to_receive; + vector senders(P.num_players()); + senders.back() = true; + + if (is_dealer()) + { + to_send.reset(P); + if (dealer_prngs.empty()) + { + dealer_prngs.resize(P.num_players() - 1); + dealer_random_prngs.resize(dealer_prngs.size()); + for (int i = 0; i < P.num_players() - 1; i++) + { + to_send[i].append(dealer_prngs[i].get_seed(), SEED_SIZE); + dealer_random_prngs[i].SetSeed(dealer_prngs[i]); + } + P.send_receive_all(senders, to_send, to_receive); + } + } + else if (not non_dealer_prng.is_initialized()) + { + 
P.send_receive_all(senders, to_send, to_receive); + non_dealer_prng.SetSeed( + to_receive.at(dealer_player()).consume(SEED_SIZE)); + non_dealer_random_prng.SetSeed(non_dealer_prng); + } + + os.reset_write_head(); + shares.clear(); + random_shares.clear(); from_dealer = false; + king = (king + 1) % (P.num_players() - 1); } else if (not is_dealer()) - internal->reset(player - 1); + internal->reset(player); } template @@ -74,10 +117,7 @@ void DealerInput::add_mine(const typename T::open_type& input, { if (is_dealer()) { - make_share(shares.data(), input, P.num_players() - 1, 0, G); - for (int i = 0; i < P.num_players() - 1; i++) - shares.at(i).pack(to_send[i]); - from_dealer = true; + add_from_dealer(input); } else internal->add_mine(input); @@ -87,20 +127,80 @@ template void DealerInput::add_other(int player, int) { if (is_dealer(player)) - from_dealer = true; + add_n_from_dealer(1); else if (not is_dealer()) internal->add_other(player); } +template +void DealerInput::add_from_dealer(const typename T::open_type& input) +{ + add_from_dealer(vector({input})); +} + +template +void DealerInput::add_from_dealer(const vector& inputs) +{ + int n = P.num_players() - 1; + os.reserve(inputs.size() * T::open_type::size()); + + for (auto& input : inputs) + { + auto rest = input; + for (int i = 0; i < n; i++) + if (i != king) + { + auto r = dealer_prngs[i].template get(); + rest -= r; + } + + os.append_no_resize((octet*) rest.get_ptr(), + T::open_type::size()); + } + + from_dealer = true; +} + +template +void DealerInput::add_n_from_dealer(size_t n_inputs, bool random) +{ + if (random) + for (size_t i = 0; i < n_inputs; i++) + random_shares.push_back(non_dealer_random_prng.get()); + else + { + if (my_num != king) + for (size_t i = 0; i < n_inputs; i++) + shares.push_back(non_dealer_prng.get()); + from_dealer = true; + } +} + +template +typename T::open_type DealerInput::random_for_dealer() +{ + T res; + for (auto& prng : dealer_random_prngs) + { + auto share = prng.template 
get(); + res += share; + } + return res; +} + template void DealerInput::exchange() { CODE_LOCATION if (from_dealer) { - vector senders(P.num_players()); - senders.back() = true; - P.send_receive_all(senders, to_send, to_receive); + if (is_dealer()) + P.send_to(king, os); + else if (P.my_num() == king) + P.receive_player(dealer_player(), os); + else + shares.reset(); + random_shares.reset(); } else if (not is_dealer()) internal->exchange(); @@ -114,10 +214,64 @@ T DealerInput::finalize(int player, int) else { if (is_dealer(player)) - return to_receive.back().template get(); + return finalize_from_dealer(); else return internal->finalize(player); } } +template +T DealerInput::finalize_from_dealer() +{ + if (king == P.my_num()) + return os.get(); + else + return shares.next(); +} + +template +T DealerInput::finalize_random() +{ + return random_shares.next(); +} + +template +void DealerInput::require(size_t n_inputs) +{ + assert(not is_dealer()); + + if (my_num == king) + os.require(n_inputs); + else + shares.require(n_inputs); +} + +template +template +T DealerInput::finalize_no_check() +{ + if (RANDOM) + { + return finalize_random(); + } + + if (IS_KING) + { + return os.get_no_check(); + } + else + return shares.next(); +} + +template +template +array DealerInput::finalize_no_check() +{ + array res; + for (int i = 0; i < N - 1; i++) + res[i] = finalize_no_check(); + res[N - 1] = finalize_no_check(); + return res; +} + #endif /* PROTOCOLS_DEALERINPUT_HPP_ */ diff --git a/Protocols/DealerMatrixPrep.hpp b/Protocols/DealerMatrixPrep.hpp index 3162a012..da829061 100644 --- a/Protocols/DealerMatrixPrep.hpp +++ b/Protocols/DealerMatrixPrep.hpp @@ -16,29 +16,21 @@ DealerMatrixPrep::DealerMatrixPrep(int n_rows, int n_inner, int n_cols, } template -void append_shares(vector& os, - ValueMatrix& M, PRNG& G) +void append(vector& values, ValueMatrix& M) { - size_t n = os.size(); - for (auto& value : M.entries) - { - T sum; - for (size_t i = 0; i < n - 2; i++) - { - auto share = 
G.get(); - sum += share; - share.pack(os[i]); - } - (value - sum).pack(os[n - 2]); - } + values.insert(values.end(), M.entries.begin(), M.entries.end()); } template -ShareMatrix receive_shares(octetStream& o, int n, int m) +ShareMatrix receive(DealerInput& input, int n, int m, bool random) { ShareMatrix res(n, m); - for (size_t i = 0; i < res.entries.size(); i++) - res.entries.v.push_back(o.get()); + if (random) + for (size_t i = 0; i < res.entries.size(); i++) + res.entries.v.push_back(input.finalize_random()); + else + for (size_t i = 0; i < res.entries.size(); i++) + res.entries.v.push_back(input.finalize_from_dealer()); return res; } @@ -48,44 +40,49 @@ void DealerMatrixPrep::buffer_triples() CODE_LOCATION assert(this->prep); assert(this->prep->proc); - auto& P = this->prep->proc->P; - vector senders(P.num_players()); - senders.back() = true; - octetStreams os(P), to_receive(P); + auto& input = this->prep->proc->input; + input.reset(input.dealer_player()); int batch_size = BaseMachine::matrix_batch_size(n_rows, n_inner, n_cols); assert(batch_size > 0); - if (not T::real_shares(P)) + ValueMatrix A(n_rows, n_inner), B(n_inner, n_cols), + C(n_rows, n_cols); + size_t n_values = batch_size * C.entries.size(); + if (input.is_dealer()) { SeededPRNG G; - ValueMatrix A(n_rows, n_inner), B(n_inner, n_cols), - C(n_rows, n_cols); - for (int i = 0; i < P.num_players() - 1; i++) - os[i].reserve( - batch_size * T::size() - * (A.entries.size() + B.entries.size() - + C.entries.size())); + vector values; + values.reserve(n_values); for (int i = 0; i < batch_size; i++) { - A.randomize(G); - B.randomize(G); + A.entries.v.clear(); + B.entries.v.clear(); + for (size_t j = 0; j < A.entries.size(); j++) + A.entries.v.push_back(input.random_for_dealer()); + for (size_t j = 0; j < B.entries.size(); j++) + B.entries.v.push_back(input.random_for_dealer()); C = A * B; - append_shares(os, A, G); - append_shares(os, B, G); - append_shares(os, C, G); + append(values, C); 
this->triples.push_back({{{n_rows, n_inner}, {n_inner, n_cols}, {n_rows, n_cols}}}); } - P.send_receive_all(senders, os, to_receive); + input.add_from_dealer(values); } else { - P.send_receive_all(senders, os, to_receive); + input.add_n_from_dealer(n_values); + input.add_n_from_dealer( + batch_size * (A.entries.size() + B.entries.size()), true); + } + + input.exchange(); + + if (not input.is_dealer()) + { for (int i = 0; i < batch_size; i++) { - auto& o = to_receive.back(); - this->triples.push_back({{receive_shares(o, n_rows, n_inner), - receive_shares(o, n_inner, n_cols), - receive_shares(o, n_rows, n_cols)}}); + this->triples.push_back({{receive(input, n_rows, n_inner, true), + receive(input, n_inner, n_cols, true), + receive(input, n_rows, n_cols, false)}}); } } } diff --git a/Protocols/DealerPrep.h b/Protocols/DealerPrep.h index 459d6dfd..325bb4cb 100644 --- a/Protocols/DealerPrep.h +++ b/Protocols/DealerPrep.h @@ -9,11 +9,15 @@ #include "ReplicatedPrep.h" #include "DealerMatrixPrep.h" +template class DealerInput; + template class DealerPrep : virtual public BitPrep { friend class DealerMatrixPrep; + DealerInput* bit_input_; + template void buffer_inverses(true_type); template @@ -24,12 +28,19 @@ class DealerPrep : virtual public BitPrep template void buffer_edabits(int n_bits, false_type); + template + void finalize(vector>& items, size_t n_items); + + DealerInput& get_bit_input(); + public: DealerPrep(SubProcessor* proc, DataPositions& usage) : - BufferPrep(usage), BitPrep(proc, usage) + BufferPrep(usage), BitPrep(proc, usage), bit_input_(0) { } + ~DealerPrep(); + void buffer_triples(); void buffer_inverses(); void buffer_bits(); diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index 9919be49..524be99f 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -9,43 +9,74 @@ #include "DealerPrep.h" #include "GC/SemiSecret.h" +template +DealerPrep::~DealerPrep() +{ + if (bit_input_) + delete bit_input_; +} + template void 
DealerPrep::buffer_triples() { CODE_LOCATION assert(this->proc); - auto& P = this->proc->P; - vector senders(P.num_players()); - senders.back() = true; - octetStreams os(P), to_receive(P); int buffer_size = BaseMachine::batch_size(DATA_TRIPLE, this->buffer_size); - if (this->proc->input.is_dealer()) + auto& input = this->proc->input; + input.reset(input.dealer_player()); + if (input.is_dealer()) { SeededPRNG G; - vector> shares(P.num_players() - 1); + vector to_share; + to_share.reserve(3 * buffer_size); for (int i = 0; i < buffer_size; i++) { T triples[3]; for (int i = 0; i < 2; i++) - triples[i] = G.get(); + triples[i] = input.random_for_dealer(); triples[2] = triples[0] * triples[1]; - for (auto& value : triples) - { - make_share(shares.data(), typename T::clear(value), - P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - shares.at(i - 1).pack(os[i - 1]); - } + to_share.push_back(triples[2]); this->triples.push_back({}); } - P.send_receive_all(senders, os, to_receive); + input.add_from_dealer(to_share); } else { - P.send_receive_all(senders, os, to_receive); - for (int i = 0; i < buffer_size; i++) - this->triples.push_back(to_receive.back().get>().get()); + input.add_n_from_dealer(2 * buffer_size, true); + input.add_n_from_dealer(buffer_size, false); + } + + input.exchange(); + finalize<3, true>(this->triples, size_t(buffer_size)); +} + +template +template +void DealerPrep::finalize(vector >& items, size_t buffer_size) +{ + assert(this->proc); + auto& input = this->proc->input; + + if (not input.is_dealer()) + { + if (RANDOM) + input.require(buffer_size); + else + input.require(N * buffer_size); + + if (input.is_king()) + { + for (size_t i = 0; i < buffer_size; i++) + items.push_back( + input.template finalize_no_check()); + } + else + { + for (size_t i = 0; i < buffer_size; i++) + items.push_back( + input.template finalize_no_check()); + } } } @@ -73,10 +104,13 @@ void DealerPrep::buffer_inverses(true_type) senders.back() = true; 
octetStreams os(P), to_receive(P); int buffer_size = BaseMachine::batch_size(DATA_INVERSE); + auto& input = this->proc->input; + input.reset(input.dealer_player()); if (this->proc->input.is_dealer()) { SeededPRNG G; - vector> shares(P.num_players() - 1); + vector items; + items.reserve(2 * buffer_size); for (int i = 0; i < buffer_size; i++) { T tuple[2]; @@ -85,21 +119,19 @@ void DealerPrep::buffer_inverses(true_type) tuple[1] = tuple[0].invert(); for (auto& value : tuple) { - make_share(shares.data(), typename T::clear(value), - P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - shares.at(i - 1).pack(os[i - 1]); + items.push_back(value); } this->inverses.push_back({}); } - P.send_receive_all(senders, os, to_receive); + input.add_from_dealer(items); } else { - P.send_receive_all(senders, os, to_receive); - for (int i = 0; i < buffer_size; i++) - this->inverses.push_back(to_receive.back().get>().get()); + input.add_n_from_dealer(2 * buffer_size); } + + input.exchange(); + finalize(this->inverses, buffer_size); } template @@ -107,31 +139,33 @@ void DealerPrep::buffer_bits() { CODE_LOCATION assert(this->proc); - auto& P = this->proc->P; - vector senders(P.num_players()); - senders.back() = true; - octetStreams os(P), to_receive(P); + auto& input = this->proc->input; + input.reset(input.dealer_player()); int buffer_size = BaseMachine::batch_size(DATA_BIT); if (this->proc->input.is_dealer()) { SeededPRNG G; - vector> shares(P.num_players() - 1); + vector bits; + bits.reserve(buffer_size); for (int i = 0; i < buffer_size; i++) { T bit = G.get_bit(); - make_share(shares.data(), typename T::clear(bit), - P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - shares.at(i - 1).pack(os[i - 1]); + bits.push_back(bit); this->bits.push_back({}); } - P.send_receive_all(senders, os, to_receive); + input.add_from_dealer(bits); } else { - P.send_receive_all(senders, os, to_receive); + input.add_n_from_dealer(buffer_size); + } + + 
input.exchange(); + + if (not input.is_dealer()) + { for (int i = 0; i < buffer_size; i++) - this->bits.push_back(to_receive.back().get()); + this->bits.push_back(input.finalize_from_dealer()); } } @@ -140,39 +174,34 @@ void DealerPrep::buffer_dabits(ThreadQueues*) { CODE_LOCATION assert(this->proc); - auto& P = this->proc->P; - vector senders(P.num_players()); - senders.back() = true; - octetStreams os(P), to_receive(P); + auto& input = this->proc->input; + input.reset(input.dealer_player()); int buffer_size = BaseMachine::batch_size(DATA_DABIT); if (this->proc->input.is_dealer()) { SeededPRNG G; - vector> shares(P.num_players() - 1); - vector bit_shares(P.num_players() - 1); + vector values; for (int i = 0; i < buffer_size; i++) { auto bit = G.get_bit(); - make_share(shares.data(), typename T::clear(bit), - P.num_players() - 1, 0, G); - make_share(bit_shares.data(), typename T::bit_type::clear(bit), - P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - { - shares.at(i - 1).pack(os[i - 1]); - bit_shares.at(i - 1).pack(os[i - 1]); - } + values.push_back(bit); this->dabits.push_back({}); } - P.send_receive_all(senders, os, to_receive); + input.add_from_dealer(values); } else { - P.send_receive_all(senders, os, to_receive); + input.add_n_from_dealer(buffer_size); + } + + input.exchange(); + + if (not input.is_dealer()) + { for (int i = 0; i < buffer_size; i++) { - this->dabits.push_back({to_receive.back().get(), - to_receive.back().get()}); + auto a = input.finalize_from_dealer(); + this->dabits.push_back({a, a.get_bit(0)}); } } } @@ -206,54 +235,63 @@ void DealerPrep::buffer_edabits(int length, false_type) { CODE_LOCATION assert(this->proc); - auto& P = this->proc->P; - vector senders(P.num_players()); - senders.back() = true; - octetStreams os(P), to_receive(P); + auto& input = this->proc->input; + auto& bit_input = get_bit_input(); + input.reset(input.dealer_player()); + bit_input.reset(input.dealer_player()); int n_vecs = 
DIV_CEIL(BaseMachine::edabit_batch_size(length), edabitvec::MAX_SIZE); auto& buffer = this->edabits[{false, length}]; if (this->proc->input.is_dealer()) { SeededPRNG G; - vector> shares(P.num_players() - 1); - vector bit_shares(P.num_players() - 1); + vector all_as; + vector all_bs; + vector as; + vector bs; for (int i = 0; i < n_vecs; i++) { - vector as; - vector bs; plain_edabits(as, bs, length, G, edabitvec::MAX_SIZE); - for (auto& a : as) - { - make_share(shares.data(), a, P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - shares.at(i - 1).pack(os[i - 1]); - } - for (auto& b : bs) - { - make_share(bit_shares.data(), b, P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - bit_shares.at(i - 1).pack(os[i - 1]); - } + all_as.insert(all_as.end(), as.begin(), as.end()); + all_bs.insert(all_bs.end(), bs.begin(), bs.end()); buffer.push_back({}); buffer.back().a.resize(edabitvec::MAX_SIZE); buffer.back().b.resize(length); } - P.send_receive_all(senders, os, to_receive); + input.add_from_dealer(all_as); + bit_input.add_from_dealer(all_bs); } else { - P.send_receive_all(senders, os, to_receive); + input.add_n_from_dealer(edabitvec::MAX_SIZE * n_vecs); + bit_input.add_n_from_dealer(length * n_vecs); + } + + input.exchange(); + bit_input.exchange(); + + if (not input.is_dealer()) + { for (int i = 0; i < n_vecs; i++) { buffer.push_back({}); for (int j = 0; j < edabitvec::MAX_SIZE; j++) - buffer.back().a.push_back(to_receive.back().get()); + buffer.back().a.push_back(input.finalize_from_dealer()); for (int j = 0; j < length; j++) - buffer.back().b.push_back( - to_receive.back().get()); + buffer.back().b.push_back(bit_input.finalize_from_dealer()); } } } +template +DealerInput& DealerPrep::get_bit_input() +{ + assert(this->proc); + + if (not bit_input_) + bit_input_ = new DealerInput(this->proc->P); + + return *bit_input_; +} + #endif /* PROTOCOLS_DEALERPREP_HPP_ */ diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h 
index dfc06a45..eaeaccaf 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -19,7 +19,8 @@ template class FakeShuffle { public: - typedef ShuffleStore store_type; + typedef vector> shuffle_type; + typedef ShuffleStore store_type; map stats; @@ -33,9 +34,9 @@ public: apply(a, n, unit_size, output_base, input_base, 0, 0); } - size_t generate(size_t, store_type& store) + void generate(size_t, shuffle_type& shuffle) { - return store.add(); + shuffle.push_back(vector(1l)); } void apply(StackedVector& a, size_t n, size_t unit_size, size_t output_base, @@ -59,20 +60,11 @@ public: throw runtime_error("inverse permutation not implemented"); }; - void apply_multiple(StackedVector &a, vector &sizes, vector &destinations, - vector &sources, - vector &unit_sizes, vector &handles, vector &reverses, - store_type&) { - const auto n_shuffles = sizes.size(); - assert(sources.size() == n_shuffles); - assert(destinations.size() == n_shuffles); - assert(unit_sizes.size() == n_shuffles); - assert(handles.size() == n_shuffles); - assert(reverses.size() == n_shuffles); - - for (size_t i = 0; i < n_shuffles; i++) { - this->apply(a, sizes[i], unit_sizes[i], destinations[i], sources[i], handles[i], reverses[i]); - } + void apply_multiple(StackedVector &a, vector>& shuffles) + { + for (auto &shuffle : shuffles) + this->apply(a, shuffle.size, shuffle.unit_size, shuffle.dest, + shuffle.source, 0, shuffle.reverse); } }; diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 47a45a02..9fc86b9b 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -43,6 +43,8 @@ public: void matmulsm(SubProcessor& processor, MemoryPart& source, const Instruction& instruction); void conv2ds(SubProcessor& processor, const Instruction& instruction); + + TimerWithComm prep_time(); }; #endif /* PROTOCOLS_HEMI_H_ */ diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index a0e38859..6b23e12d 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -153,7 +153,7 @@ void 
Hemi::matmulsm(SubProcessor& processor, MemoryPart& source, if (not T::real_shares(processor.P)) { matrix_multiply(A, B, processor); - return; + continue; } for (int i = 0; i < resultNumberOfRows; i++) { @@ -355,4 +355,13 @@ void Conv2dTuple::run_matrix(SubProcessor& processor) } +template +TimerWithComm Hemi::prep_time() +{ + TimerWithComm res; + for (auto& prep : matrix_preps) + res += prep.second->prep_timer; + return res; +} + #endif /* PROTOCOLS_HEMI_HPP_ */ diff --git a/Protocols/HemiMatrixPrep.hpp b/Protocols/HemiMatrixPrep.hpp index 998703ee..6b67589e 100644 --- a/Protocols/HemiMatrixPrep.hpp +++ b/Protocols/HemiMatrixPrep.hpp @@ -105,8 +105,9 @@ void HemiMatrixPrep::buffer_triples() auto& FTD = prep->get_FTD(); auto& pk = prep->get_pk(); int n_matrices = minimum_batch(); + bool verbose = OnlineOptions::singleton.has_option("verbose_he"); - if (OnlineOptions::singleton.has_option("verbose_he")) + if (verbose) { fprintf(stderr, "creating %d %dx%d * %dx%d triples\n", n_matrices, n_rows, n_inner, n_inner, n_cols); @@ -152,6 +153,13 @@ void HemiMatrixPrep::buffer_triples() if (T::local_mul or OnlineOptions::singleton.direct) { + if (verbose) + { + fprintf(stderr, "broadcasting %zu ciphertexts\n", + diag.ciphertexts.size()); + fflush(stderr); + } + Bundle bundle(P); bundle.mine.store(diag.ciphertexts); P.unchecked_broadcast(bundle); @@ -163,6 +171,13 @@ void HemiMatrixPrep::buffer_triples() } else { + if (verbose) + { + fprintf(stderr, "summing %zu ciphertexts\n", + diag.ciphertexts.size()); + fflush(stderr); + } + others_ct.push_back(diag.ciphertexts); TreeSum().run(others_ct[0], P); } diff --git a/Protocols/HemiPrep.h b/Protocols/HemiPrep.h index 29599c5e..3981d652 100644 --- a/Protocols/HemiPrep.h +++ b/Protocols/HemiPrep.h @@ -43,6 +43,8 @@ public: static const FHE_PK& get_pk(); static const FD& get_FTD(); + static bool bits_from_dabits(); + HemiPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), 
diff --git a/Protocols/HemiPrep.hpp b/Protocols/HemiPrep.hpp index a10ba229..9173078a 100644 --- a/Protocols/HemiPrep.hpp +++ b/Protocols/HemiPrep.hpp @@ -52,6 +52,12 @@ const typename T::clear::FD& HemiPrep::get_FTD() return pairwise_machine->setup().FieldD; } +template +bool HemiPrep::bits_from_dabits() +{ + return SemiPrep::bits_from_dabits(); +} + template HemiPrep::~HemiPrep() @@ -82,6 +88,7 @@ vector*>& HemiPrep::get_multipliers() pairwise_machine->setup().covert_key_generation(P, *pairwise_machine, 1); pairwise_machine->enc_alphas.resize(1, pairwise_machine->pk); + BaseMachine::add_one_off(P.total_comm()); } lock.unlock(); @@ -144,8 +151,7 @@ void HemiPrep::buffer_bits() if (this->proc->P.num_players() == 2) { auto& prep = get_two_party_prep(); - prep.buffer_size = BaseMachine::batch_size(DATA_BIT, - this->buffer_size); + prep.buffer_size = this->buffer_size; prep.buffer_dabits(0); for (auto& x : prep.dabits) this->bits.push_back(x.first); diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 947c838b..5503a629 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -106,7 +106,7 @@ void shuffle_triple_generation(vector>& triples, Player& P, RunningTimer timer; TripleShuffleSacrifice sacrifice; vector> check_triples; - int buffer_size = sacrifice.minimum_n_inputs(OnlineOptions::singleton.batch_size); + int buffer_size = sacrifice.minimum_n_inputs(sacrifice.batch_size()); // optimistic triple generation Replicated protocol(P); @@ -126,7 +126,20 @@ void shuffle_triple_generation(vector>& triples, Player& P, } template -TripleShuffleSacrifice::TripleShuffleSacrifice() +int TripleShuffleSacrifice::batch_size() +{ + // use this to avoid bucket size being too low + int trick_max = 10 * ShuffleSacrifice(3).minimum_n_outputs() + * T::default_length; + int res = BaseMachine::batch_size(DATA_TRIPLE, 0, trick_max); + if (res == trick_max) + res = BaseMachine::batch_size(DATA_TRIPLE); + return DIV_CEIL(res, 
T::default_length); +} + +template +TripleShuffleSacrifice::TripleShuffleSacrifice() : + ShuffleSacrifice(BaseMachine::bucket_size(batch_size())) { } @@ -136,12 +149,6 @@ TripleShuffleSacrifice::TripleShuffleSacrifice(int B, int C) : { } -template -TripleShuffleSacrifice::TripleShuffleSacrifice(DataFieldType type) : - ShuffleSacrifice(BaseMachine::bucket_size(type)) -{ -} - template void TripleShuffleSacrifice::triple_sacrifice(vector>& triples, vector>& check_triples, Player& P, diff --git a/Protocols/MalRepRingShare.h b/Protocols/MalRepRingShare.h index dc3e0cfd..a651f68d 100644 --- a/Protocols/MalRepRingShare.h +++ b/Protocols/MalRepRingShare.h @@ -40,10 +40,6 @@ public: MalRepRingShare() { } - MalRepRingShare(const T& other, int my_num, T alphai = {}) : - super(other, my_num, alphai) - { - } template MalRepRingShare(const U& other) : super(other) { diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 718d8a69..93b55dd8 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -64,10 +64,6 @@ public: MaliciousRep3Share() { } - MaliciousRep3Share(const T& other, int my_num, T alphai = {}) : - super(other, my_num, alphai) - { - } template MaliciousRep3Share(const U& other) : super(other) { diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index a99b7e1c..515fcf97 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -63,9 +63,8 @@ public: { } template - MaliciousShamirShare(const U& other, int my_num = 0, T alphai = {}) : super(other) + MaliciousShamirShare(const U& other) : super(other) { - (void) my_num, (void) alphai; } }; diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 91021cd3..7821cd59 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -93,6 +93,8 @@ void MascotDabitOnlyPrep::buffer_bits(false_type) { this->params.generateBits = true; auto& triple_generator = this->triple_generator; + 
triple_generator->set_batch_size( + BaseMachine::batch_size(DATA_BIT, this->buffer_size)); triple_generator->generate(); triple_generator->unlock(); assert(triple_generator->bits.size() != 0); @@ -105,6 +107,8 @@ void MascotInputPrep::buffer_inputs(int player) { auto& triple_generator = this->triple_generator; assert(triple_generator); + triple_generator->set_batch_size( + BaseMachine::input_batch_size(player, this->buffer_size)); triple_generator->generateInputs(player); if (this->inputs.size() <= (size_t)player) this->inputs.resize(player + 1); diff --git a/Protocols/PostSacriRepFieldShare.h b/Protocols/PostSacriRepFieldShare.h index cad03bdb..e8ac30d7 100644 --- a/Protocols/PostSacriRepFieldShare.h +++ b/Protocols/PostSacriRepFieldShare.h @@ -33,10 +33,6 @@ public: PostSacriRepFieldShare() { } - PostSacriRepFieldShare(const clear& other, int my_num, clear alphai = {}) : - super(other, my_num, alphai) - { - } template PostSacriRepFieldShare(const U& other) : super(other) { diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 682eb744..45f4f22b 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -54,10 +54,6 @@ public: PostSacriRepRingShare() { } - PostSacriRepRingShare(const clear& other, int my_num, clear alphai = {}) : - super(other, my_num, alphai) - { - } template PostSacriRepRingShare(const U& other) : super(other) { diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index 8b6fbb72..c106be7d 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -160,7 +160,9 @@ public: static Rep3Share constant(T value, int my_num, typename super::mac_key_type = {}) { - return Rep3Share(value, my_num); + This res; + Replicated::assign(res, value, my_num); + return res; } Rep3Share() @@ -172,12 +174,6 @@ public: { } - Rep3Share(T value, int my_num, const T& alphai = {}) - { - (void) alphai; - Replicated::assign(*this, value, my_num); - } - void assign(const char* buffer) { 
FixedVec::assign(buffer); diff --git a/Protocols/Rep3Shuffler.h b/Protocols/Rep3Shuffler.h index b5f8a677..74e56bbe 100644 --- a/Protocols/Rep3Shuffler.h +++ b/Protocols/Rep3Shuffler.h @@ -21,17 +21,11 @@ private: public: map stats; - Rep3Shuffler(StackedVector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, SubProcessor& proc); - Rep3Shuffler(SubProcessor& proc); - int generate(int n_shuffle, store_type& store); + void generate(int n_shuffle, shuffle_type& shuffle); - void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& handles, vector& reverse, store_type& store); - void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& shuffles, vector& reverse); + void apply_multiple(StackedVector& a, vector> &shuffles); void inverse_permutation(StackedVector& stack, size_t n, size_t output_base, size_t input_base); diff --git a/Protocols/Rep3Shuffler.hpp b/Protocols/Rep3Shuffler.hpp index 6705727b..c3af1e3f 100644 --- a/Protocols/Rep3Shuffler.hpp +++ b/Protocols/Rep3Shuffler.hpp @@ -8,29 +8,13 @@ #include "Rep3Shuffler.h" -template -Rep3Shuffler::Rep3Shuffler(StackedVector &a, size_t n, int unit_size, - size_t output_base, size_t input_base, SubProcessor &proc) : proc(proc) { - store_type store; - int handle = generate(n / unit_size, store); - - vector sizes{n}; - vector unit_sizes{static_cast(unit_size)}; - vector destinations{output_base}; - vector sources{input_base}; - vector shuffles{store.get(handle)}; - vector reverses{true}; - this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); -} - template Rep3Shuffler::Rep3Shuffler(SubProcessor &proc) : proc(proc) { } template -int Rep3Shuffler::generate(int n_shuffle, store_type &store) { - int res = store.add(); - auto &shuffle = store.get(res); +void Rep3Shuffler::generate(int n_shuffle, shuffle_type& shuffle) +{ for (int i = 0; i < 2; i++) 
{ auto &perm = shuffle[i]; for (int j = 0; j < n_shuffle; j++) @@ -40,34 +24,14 @@ int Rep3Shuffler::generate(int n_shuffle, store_type &store) { swap(perm[j], perm[k + j]); } } - return res; } template -void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, - vector &sources, - vector &unit_sizes, vector &handles, vector &reverses, - store_type &store) { - vector shuffles; - for (size_t &handle: handles) { - shuffle_type &shuffle = store.get(handle); - shuffles.push_back(shuffle); - } - - apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); -} - -template -void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, - vector &sources, vector &unit_sizes, vector &shuffles, - vector &reverses) { +void Rep3Shuffler::apply_multiple(StackedVector &a, + vector> &shuffles) +{ CODE_LOCATION - const auto n_shuffles = sizes.size(); - assert(sources.size() == n_shuffles); - assert(destinations.size() == n_shuffles); - assert(unit_sizes.size() == n_shuffles); - assert(shuffles.size() == n_shuffles); - assert(reverses.size() == n_shuffles); + const auto n_shuffles = shuffles.size(); assert(proc.P.num_players() == 3); assert(not T::malicious); @@ -79,17 +43,17 @@ void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, vector > to_shuffle; for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - assert(sizes[current_shuffle] % unit_sizes[current_shuffle] == 0); + auto& shuffle = shuffles[current_shuffle]; + assert(shuffle.size % shuffle.unit_size == 0); vector x; - for (size_t j = 0; j < sizes[current_shuffle]; j++) - x.push_back(a[sources[current_shuffle] + j]); + for (size_t j = 0; j < shuffle.size; j++) + x.push_back(a[shuffle.source + j]); to_shuffle.push_back(x); - const auto &shuffle = shuffles[current_shuffle]; - if (shuffle.empty()) + if (shuffle.shuffle[0].empty()) throw runtime_error("shuffle has been deleted"); - 
stats[sizes[current_shuffle] / unit_sizes[current_shuffle]] += unit_sizes[current_shuffle]; + stats[shuffle.size / shuffle.unit_size] += shuffle.unit_size; } typename T::Input input(proc); @@ -98,10 +62,11 @@ void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, input.reset_all(proc.P); for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - const auto n = sizes[current_shuffle]; - const auto unit_size = unit_sizes[current_shuffle]; - const auto &shuffle = shuffles[current_shuffle]; - const auto reverse = reverses[current_shuffle]; + auto& shuffle_tuple = shuffles[current_shuffle]; + const size_t n = shuffle_tuple.size; + const size_t unit_size = shuffle_tuple.unit_size; + const auto reverse = shuffle_tuple.reverse; + auto& shuffle = shuffle_tuple.shuffle; const auto current_to_shuffle = to_shuffle[current_shuffle]; vector to_share(n); @@ -140,8 +105,9 @@ void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, to_shuffle.clear(); for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - const auto n = sizes[current_shuffle]; - const auto reverse = reverses[current_shuffle]; + auto& shuffle = shuffles[current_shuffle]; + const auto n = shuffle.size; + const auto reverse = shuffle.reverse; int i; if (reverse) @@ -159,10 +125,11 @@ void Rep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, } for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - const auto n = sizes[current_shuffle]; + auto& shuffle = shuffles[current_shuffle]; + const auto n = shuffle.size; for (size_t i = 0; i < n; i++) - a[destinations[current_shuffle] + i] = to_shuffle[current_shuffle][i]; + a[shuffle.dest + i] = to_shuffle[current_shuffle][i]; } } diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h index 31c57404..a39fd37a 100644 --- a/Protocols/Rep4.h +++ b/Protocols/Rep4.h @@ -39,6 +39,8 @@ class Rep4 : public ProtocolBase int my_num; + bool malicious; + array 
get_addshares(const T& x, const T& y); void reset_joint_input(int n_inputs); diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index 2886bb69..4048fbdc 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -24,6 +24,8 @@ Rep4::Rep4(Player& P) : for (int i = 1; i < 3; i++) rep_prngs[i].SetSeed(to_receive[P.get_player(i)].get_data()); + + malicious = not OnlineOptions::singleton.semi_honest; } template @@ -32,6 +34,8 @@ Rep4::Rep4(Player& P, prngs_type& prngs) : { for (int i = 0; i < 3; i++) rep_prngs[i].SetSeed(prngs[i]); + + malicious = not OnlineOptions::singleton.semi_honest; } template @@ -51,6 +55,9 @@ Rep4::~Rep4() template void Rep4::check() { + if (not malicious) + return; + for (auto& x : channels) for (auto y : x) if (y) @@ -122,7 +129,7 @@ void Rep4::prepare_joint_input(int sender, int backup, int receiver, } } - if (P.my_num() == backup) + if (P.my_num() == backup and malicious) { send_hashes[sender][receiver].update(inputs, bit_lengths); } @@ -188,8 +195,10 @@ void Rep4::finalize_joint_input(int sender, int backup, int receiver, } os->consume(0); - receive_hashes[sender][backup].update(start, - os->get_data_ptr() - start); + + if (malicious) + receive_hashes[sender][backup].update(start, + os->get_data_ptr() - start); } } @@ -389,15 +398,22 @@ void Rep4::trunc_pr(const vector& regs, int size, for (auto& c : cs) (c[1] + c[0]).pack(c_os); P.send_to(2 + P.my_num(), c_os); - P.send_to(3 - P.my_num(), c_os.hash()); + + if (malicious) + P.send_to(3 - P.my_num(), c_os.hash()); } else { P.receive_player(P.my_num() - 2, c_os); - octetStream hash; - P.receive_player(3 - P.my_num(), hash); - if (hash != c_os.hash()) - throw runtime_error("hash mismatch in joint message passing"); + + if (malicious) + { + octetStream hash; + P.receive_player(3 - P.my_num(), hash); + if (hash != c_os.hash()) + throw runtime_error("hash mismatch in joint message passing"); + } + PointerVector open_cs; if (P.my_num() == 2) for (auto& c : cs) diff --git 
a/Protocols/Rep4Input.hpp b/Protocols/Rep4Input.hpp index 34b09dc3..6890b96b 100644 --- a/Protocols/Rep4Input.hpp +++ b/Protocols/Rep4Input.hpp @@ -72,8 +72,10 @@ void Rep4Input::exchange() { P.pass_around(to_send, to_receive[0], -1); P.pass_around(to_send, to_receive[1], 1); - for (int i = 0; i < 2; i++) - hashes[i].update(to_receive[i]); + + if (not OnlineOptions::singleton.semi_honest) + for (int i = 0; i < 2; i++) + hashes[i].update(to_receive[i]); } template diff --git a/Protocols/Rep4MC.hpp b/Protocols/Rep4MC.hpp index c4991eaf..482eadf2 100644 --- a/Protocols/Rep4MC.hpp +++ b/Protocols/Rep4MC.hpp @@ -12,13 +12,20 @@ template void Rep4MC::exchange(const Player& P) { CODE_LOCATION + + bool malicious = not OnlineOptions::singleton.semi_honest; + octetStream right, tmp; for (auto& secret : this->secrets) { secret[0].pack(right); - secret[2].pack(tmp); + if (malicious) + secret[2].pack(tmp); } - check_hash.update(tmp); + + if (malicious) + check_hash.update(tmp); + P.pass_around(right, 1); this->values.resize(this->secrets.size()); for (size_t i = 0; i < this->secrets.size(); i++) @@ -27,7 +34,9 @@ void Rep4MC::exchange(const Player& P) a.unpack(right); this->values[i] = this->secrets[i].sum() + a; } - receive_hash.update(right); + + if (malicious) + receive_hash.update(right); if (OnlineOptions::singleton.has_option("always_check")) Check(P); diff --git a/Protocols/Rep4Share.h b/Protocols/Rep4Share.h index d1e383ff..fde34acf 100644 --- a/Protocols/Rep4Share.h +++ b/Protocols/Rep4Share.h @@ -42,6 +42,7 @@ public: static const bool malicious = true; static const bool variable_players = false; + static const bool semi_honest_option = true; static string type_short() { diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 14ac3eb9..8ee0bd69 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -168,6 +168,8 @@ public: virtual void set_fast_mode(bool) {} double randomness_time() { return 0; } + + TimerWithComm prep_time() { return {}; } }; 
/** diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index ef076303..a0dae65a 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -411,7 +411,7 @@ template ReplicatedInput& Replicated::get_helper_input(size_t i) { while (i >= helper_inputs.size()) - helper_inputs.push_back(new ReplicatedInput(P)); + helper_inputs.push_back(new ReplicatedInput(0, *this)); return *helper_inputs.at(i); } diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index d0d8af30..bbfd9ab3 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -12,6 +12,7 @@ template void ReplicatedMC::POpen(vector& values, const vector& S, const Player& P) { + CODE_LOCATION prepare(S); P.pass_around(to_send, o, -1); finalize(values, S); @@ -21,6 +22,7 @@ template void ReplicatedMC::POpen_Begin(vector&, const vector& S, const Player& P) { + CODE_LOCATION prepare(S); P.send_relative(-1, to_send); } diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index c3a06f2d..d9a05c80 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -249,6 +249,7 @@ void BitPrep::buffer_squares() auto proc = this->proc; auto buffer_size = BaseMachine::batch_size(DATA_SQUARE, this->buffer_size); + BufferScope scope(*this, buffer_size); assert(proc != 0); vector a_plus_b(buffer_size), as(buffer_size), cs(buffer_size); T b; @@ -278,6 +279,9 @@ void generate_squares(vector>& squares, int n_squares, n_squares = BaseMachine::batch_size(DATA_SQUARE, n_squares); assert(protocol != 0); squares.resize(n_squares); + if (OnlineOptions::singleton.has_option("verbose_square")) + fprintf(stderr, "generating %d random squares\n", n_squares); + BufferScope scope(*protocol, n_squares); protocol->init_mul(); for (size_t i = 0; i < squares.size(); i++) { @@ -714,6 +718,7 @@ void BitPrep::buffer_ring_bits_without_check(vector& bits, PRNG& G, buffer_bits_from_players(player_bits, G, *proc, this->base_player, buffer_size, 1); auto& prot = 
*protocol; + BufferScope scope(prot, buffer_size); XOR(bits, player_bits[0], player_bits[1], prot); for (int i = 2; i < n_relevant_players; i++) XOR(bits, bits, player_bits[i], prot); @@ -851,9 +856,8 @@ void RingPrep::buffer_dabits_without_check(vector>& dabits, Preprocessing&) { CODE_LOCATION -#ifdef VERBOSE_DABIT - fprintf(stderr, "generate daBits %lu to %lu\n", begin, end); -#endif + if (OnlineOptions::singleton.has_option("verbose_dabit")) + fprintf(stderr, "generate daBits %lu to %lu\n", begin, end); size_t buffer_size = end - begin; auto proc = this->proc; diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h index 496584e6..71364cd4 100644 --- a/Protocols/SecureShuffle.h +++ b/Protocols/SecureShuffle.h @@ -12,13 +12,17 @@ using namespace std; #include "Tools/Lock.h" template class SubProcessor; +template class ShuffleTuple; template class ShuffleStore { +public: typedef T shuffle_type; + typedef pair store_type; - deque shuffles; +private: + deque shuffles; Lock store_lock; @@ -26,8 +30,8 @@ class ShuffleStore void unlock(); public: - int add(); - shuffle_type& get(int handle); + int add(unsigned n_shuffles); + store_type& get(int handle); void del(int handle); }; @@ -63,27 +67,31 @@ private: */ vector> configure(int config_player, vector* perm, int n); - int prep_multiple(StackedVector& a, vector &sizes, vector &sources, vector &unit_sizes, vector>& to_shuffle, vector &exact); - void finalize_multiple(StackedVector& a, vector& sizes, vector& unit_sizes, vector& destinations, vector& isExact, vector>& to_shuffle); + int prep_multiple(StackedVector &a, + vector> &shuffles, + vector> &to_shuffle, vector& is_exact); + void finalize_multiple(StackedVector &a, + vector> &shuffles, + vector> &to_shuffle, vector& isExact); - void parallel_waksman_round(size_t pass, int depth, bool inwards, vector>& toShuffle, vector& unit_sizes, vector& reverse, vector& shuffles); - vector> waksman_round_init(vector& toShuffle, size_t shuffle_unit_size, int depth, 
vector>& iter_waksman_config, bool inwards, bool reverse); + void parallel_waksman_round(size_t pass, int depth, bool inwards, + vector>& toShuffle, + vector>& shuffles); + + vector> waksman_round_init(vector &toShuffle, + size_t shuffle_unit_size, int depth, + const vector> &iter_waksman_config, bool inwards, + bool reverse); void waksman_round_finish(vector& toShuffle, size_t unit_size, vector> indices); public: map stats; - SecureShuffle(StackedVector& a, size_t n, int unit_size, - size_t output_base, size_t input_base, SubProcessor& proc); - SecureShuffle(SubProcessor& proc); - int generate(int n_shuffle, store_type& store); + void generate(int n_shuffle, shuffle_type& shuffle); - void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& handles, vector& reverse, store_type& store); - void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& shuffles, vector& reverse); + void apply_multiple(StackedVector &a, vector> &shuffles); /** * Calculate the secret inverse permutation of stack given secret permutation. 
diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp index 5a01f568..631af0ea 100644 --- a/Protocols/SecureShuffle.hpp +++ b/Protocols/SecureShuffle.hpp @@ -13,6 +13,35 @@ #include #include +template +class ShuffleTuple +{ + typedef typename T::Protocol::Shuffler::shuffle_type shuffle_type; + +public: + size_t size, dest, source, unit_size; + const shuffle_type& shuffle; + bool reverse; + + ShuffleTuple(size_t size, size_t dest, size_t source, size_t unit_size, + const shuffle_type& shuffle, bool reverse) : + size(size), dest(dest), source(source), unit_size(unit_size), + shuffle(shuffle), reverse(reverse) + { + } + + ShuffleTuple(size_t size, size_t dest, size_t source, size_t unit_size, + const typename ShuffleStore::store_type& stored, + bool reverse) : + ShuffleTuple(size, dest, source, unit_size, stored.second, reverse) + { + if (stored.first == 0) + throw runtime_error("shuffle has been deleted"); + if (stored.first != size / unit_size) + throw runtime_error("wrong shuffle size"); + } +}; + template void ShuffleStore::lock() { @@ -26,17 +55,17 @@ void ShuffleStore::unlock() } template -int ShuffleStore::add() +int ShuffleStore::add(unsigned n_shuffles) { lock(); int res = shuffles.size(); - shuffles.push_back({}); + shuffles.push_back({n_shuffles, {}}); unlock(); return res; } template -typename ShuffleStore::shuffle_type& ShuffleStore::get(int handle) +typename ShuffleStore::store_type& ShuffleStore::get(int handle) { lock(); auto& res = shuffles.at(handle); @@ -59,42 +88,10 @@ SecureShuffle::SecureShuffle(SubProcessor& proc) : } template -SecureShuffle::SecureShuffle(StackedVector& a, size_t n, int unit_size, - size_t output_base, size_t input_base, SubProcessor& proc) : - proc(proc) +void SecureShuffle::apply_multiple(StackedVector &a, + vector>& shuffles) { - store_type store; - int handle = generate(n / unit_size, store); - - vector sizes{n}; - vector unit_sizes{static_cast(unit_size)}; - vector destinations{output_base}; - vector 
sources{input_base}; - vector shuffles{store.get(handle)}; - vector reverses{true}; - this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); -} - -template -void SecureShuffle::apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& handles, vector& reverse, store_type& store) { - vector shuffles; - for (size_t &handle : handles) - shuffles.push_back(store.get(handle)); - - this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverse); -} - -template -void SecureShuffle::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, - vector &sources, vector &unit_sizes, vector &shuffles, vector &reverse) { CODE_LOCATION - const auto n_shuffles = sizes.size(); - assert(sources.size() == n_shuffles); - assert(destinations.size() == n_shuffles); - assert(unit_sizes.size() == n_shuffles); - assert(shuffles.size() == n_shuffles); - assert(reverse.size() == n_shuffles); // SecureShuffle works by making t players create and "secret-share" a permutation. // Then each permutation is applied in a pass. As long as one of these permutations was created by an honest party, @@ -102,21 +99,21 @@ void SecureShuffle::apply_multiple(StackedVector &a, vector &sizes const auto n_passes = proc.protocol.get_relevant_players().size(); // Initialize the shuffles. - vector is_exact(n_shuffles, false); + vector is_exact(shuffles.size(), false); vector> to_shuffle; - int max_depth = prep_multiple(a, sizes, sources, unit_sizes, to_shuffle, is_exact); + int max_depth = prep_multiple(a, shuffles, to_shuffle, is_exact); // Apply the shuffles. 
for (size_t pass = 0; pass < n_passes; pass++) { for (int depth = 0; depth < max_depth; depth++) - parallel_waksman_round(pass, depth, true, to_shuffle, unit_sizes, reverse, shuffles); + parallel_waksman_round(pass, depth, true, to_shuffle, shuffles); for (int depth = max_depth - 1; depth >= 0; depth--) - parallel_waksman_round(pass, depth, false, to_shuffle, unit_sizes, reverse, shuffles); + parallel_waksman_round(pass, depth, false, to_shuffle, shuffles); } // Write the shuffled results into memory. - finalize_multiple(a, sizes, unit_sizes, destinations, is_exact, to_shuffle); + finalize_multiple(a, shuffles, to_shuffle, is_exact); } @@ -137,34 +134,33 @@ void SecureShuffle::inverse_permutation(StackedVector &stack, size_t n, si if (T::malicious) throw runtime_error("inverse permutation only implemented for semi-honest protocols"); - vector sizes { n }; - vector unit_sizes { 1 }; // We are dealing directly with permutations, so the unit_size will always be 1. - vector destinations { output_base }; - vector sources { input_base }; - vector reverse { true }; vector> to_shuffle; vector is_exact(1, false); - prep_multiple(stack, sizes, sources, unit_sizes, to_shuffle, is_exact); + // We are dealing directly with permutations, so the unit_size will always be 1. + shuffle_type shuffle; + vector> shuffles({{n, output_base, input_base, 1, shuffle, true}}); - size_t shuffle_size = to_shuffle[0].size() / unit_sizes[0]; + prep_multiple(stack, shuffles, to_shuffle, is_exact); + + size_t shuffle_size = to_shuffle.at(0).size() / shuffles.at(0).unit_size; // Alice generates stack local permutation and shares the waksman configuration bits secretly to Bob. 
vector perm_alice(shuffle_size); if (P.my_num() == alice) { perm_alice = generate_random_permutation(n); } auto config = configure(alice, &perm_alice, n); - vector shuffles {{ config, config }}; + shuffle = { config, config }; // Apply perm_alice to perm_alice to get perm_bob, // stack permutation that we can reveal to Bob without Bob learning anything about perm_alice (since it is masked by perm_a) for (int depth = 0; depth < log2(shuffle_size); depth++) - parallel_waksman_round(0, depth, true, to_shuffle, unit_sizes, reverse, shuffles); + parallel_waksman_round(0, depth, true, to_shuffle, shuffles); for (int depth = log2(shuffle_size); depth >= 0; depth--) - parallel_waksman_round(0, depth, false, to_shuffle, unit_sizes, reverse, shuffles); + parallel_waksman_round(0, depth, false, to_shuffle, shuffles); // Store perm_bob at stack[output_base] - finalize_multiple(stack, sizes, unit_sizes, destinations, is_exact, to_shuffle); + finalize_multiple(stack, shuffles, to_shuffle, is_exact); // Reveal permutation perm_bob = perm_a * perm_alice // Since this permutation is masked by perm_a, Bob learns nothing about perm @@ -203,35 +199,40 @@ void SecureShuffle::inverse_permutation(StackedVector &stack, size_t n, si // The two parties now jointly compute perm_a * perm_bob_inv to obtain perm_inv to_shuffle.clear(); - prep_multiple(stack, sizes, destinations, unit_sizes, to_shuffle, is_exact); + shuffles.at(0).source = shuffles.at(0).dest; + prep_multiple(stack, shuffles, to_shuffle, is_exact); config = configure(bob, &perm_bob_inv, n); - shuffles[0] = { config, config }; + shuffle = { config, config }; for (int i = 0; i < log2(shuffle_size); i++) - parallel_waksman_round(0, i, true, to_shuffle, unit_sizes, reverse, shuffles); + parallel_waksman_round(0, i, true, to_shuffle, shuffles); for (int i = log2(shuffle_size) - 2; i >= 0; i--) - parallel_waksman_round(0, i, false, to_shuffle, unit_sizes, reverse, shuffles); + parallel_waksman_round(0, i, false, to_shuffle, shuffles); 
// Store perm_bob at stack[output_base] - finalize_multiple(stack, sizes, unit_sizes, destinations, is_exact, to_shuffle); + finalize_multiple(stack, shuffles, to_shuffle, is_exact); } template -int SecureShuffle::prep_multiple(StackedVector &a, vector &sizes, - vector &sources, vector &unit_sizes, vector> &to_shuffle, vector &is_exact) { +int SecureShuffle::prep_multiple(StackedVector &a, + vector> &shuffles, vector> &to_shuffle, + vector &is_exact) +{ int max_depth = 0; - const size_t n_shuffles = sizes.size(); + const size_t n_shuffles = shuffles.size(); for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) { - const size_t input_base = sources[currentShuffle]; - const size_t n = sizes[currentShuffle]; - const size_t unit_size = unit_sizes[currentShuffle]; + auto& shuffle = shuffles[currentShuffle]; + const size_t input_base = shuffle.source; + const size_t n = shuffle.size; + const size_t unit_size = shuffle.unit_size; assert(n % unit_size == 0); const size_t n_shuffle = n / unit_size; - const size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle)))); + const int shuffle_depth = ceil(log2(n_shuffle)); + const size_t n_shuffle_pow2 = 1u << shuffle_depth; const bool exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious; vector tmp; @@ -255,13 +256,12 @@ int SecureShuffle::prep_multiple(StackedVector &a, vector &sizes, } for (size_t i = n_shuffle * shuffle_unit_size; i < tmp.size(); i++) tmp[i] = T::constant(0, proc.P.my_num(), proc.MC.get_alphai()); - unit_sizes[currentShuffle] = shuffle_unit_size; + shuffle.unit_size = shuffle_unit_size; } to_shuffle.push_back(tmp); is_exact[currentShuffle] = exact; - const int shuffle_depth = tmp.size() / unit_size; if (shuffle_depth > max_depth) max_depth = shuffle_depth; } @@ -270,13 +270,16 @@ int SecureShuffle::prep_multiple(StackedVector &a, vector &sizes, } template -void SecureShuffle::finalize_multiple(StackedVector &a, vector &sizes, vector &unit_sizes, - vector &destinations, vector 
&isExact, vector> &to_shuffle) { - const size_t n_shuffles = sizes.size(); +void SecureShuffle::finalize_multiple(StackedVector &a, + vector> &shuffles, vector> &to_shuffle, + vector &isExact) +{ + const size_t n_shuffles = shuffles.size(); for (size_t currentShuffle = 0; currentShuffle < n_shuffles; currentShuffle++) { - const size_t n = sizes[currentShuffle]; - const size_t shuffled_unit_size = unit_sizes[currentShuffle]; - const size_t output_base = destinations[currentShuffle]; + auto& shuffle = shuffles[currentShuffle]; + const size_t n = shuffle.size; + const size_t shuffled_unit_size = shuffle.unit_size; + const size_t output_base = shuffle.dest; const vector& shuffledData = to_shuffle[currentShuffle]; @@ -334,11 +337,8 @@ vector SecureShuffle::generate_random_permutation(int n) { } template -int SecureShuffle::generate(int n_shuffle, store_type& store) +void SecureShuffle::generate(int n_shuffle, shuffle_type& shuffle) { - int res = store.add(); - auto& shuffle = store.get(res); - for (auto i: proc.protocol.get_relevant_players()) { vector perm; if (proc.input.is_me(i)) @@ -346,8 +346,6 @@ int SecureShuffle::generate(int n_shuffle, store_type& store) auto config = configure(i, &perm, n_shuffle); shuffle.push_back(config); } - - return res; } template @@ -418,8 +416,9 @@ vector> SecureShuffle::configure(int config_player, vector *pe } template -void SecureShuffle::parallel_waksman_round(size_t pass, int depth, bool inwards, vector> &toShuffle, - vector &unit_sizes, vector &reverse, vector &shuffles) +void SecureShuffle::parallel_waksman_round(size_t pass, int depth, + bool inwards, vector> &toShuffle, + vector> &shuffles) { const auto n_passes = proc.protocol.get_relevant_players().size(); const auto n_shuffles = shuffles.size(); @@ -427,23 +426,26 @@ void SecureShuffle::parallel_waksman_round(size_t pass, int depth, bool inwar vector>> allIndices; proc.protocol.init_mul(); - for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { 
- int n = toShuffle[current_shuffle].size() / unit_sizes[current_shuffle]; + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; + current_shuffle++) + { + auto& shuffle = shuffles[current_shuffle]; + int n = toShuffle[current_shuffle].size() / shuffle.unit_size; if (depth >= log2(n) - !inwards) { allIndices.push_back({}); continue; } - const auto isReverse = reverse[current_shuffle]; + const auto isReverse = shuffle.reverse; size_t configIdx = pass; if (isReverse) configIdx = n_passes - pass - 1; - auto& config = shuffles[current_shuffle][configIdx]; + auto& config = shuffle.shuffle[configIdx]; vector> indices = waksman_round_init( toShuffle[current_shuffle], - unit_sizes[current_shuffle], + shuffle.unit_size, depth, config, inwards, @@ -453,16 +455,21 @@ void SecureShuffle::parallel_waksman_round(size_t pass, int depth, bool inwar } proc.protocol.exchange(); for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - int n = toShuffle[current_shuffle].size() / unit_sizes[current_shuffle]; + auto& shuffle = shuffles[current_shuffle]; + int n = toShuffle[current_shuffle].size() / shuffle.unit_size; if (depth >= log2(n) - !inwards) { continue; } - waksman_round_finish(toShuffle[current_shuffle], unit_sizes[current_shuffle], allIndices[current_shuffle]); + waksman_round_finish(toShuffle[current_shuffle], shuffle.unit_size, allIndices[current_shuffle]); } } template -vector> SecureShuffle::waksman_round_init(vector &toShuffle, size_t shuffle_unit_size, int depth, vector> &iter_waksman_config, bool inwards, bool reverse) { +vector> SecureShuffle::waksman_round_init(vector &toShuffle, + size_t shuffle_unit_size, int depth, + const vector> &iter_waksman_config, bool inwards, + bool reverse) +{ int n = toShuffle.size() / shuffle_unit_size; assert((int) iter_waksman_config.at(depth).size() == n); int n_blocks = 1 << depth; diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index d202c2a7..0bf3e640 100644 --- 
a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -45,11 +45,6 @@ public: Semi2kShare(const U& other) : SemiShare>(other) { } - Semi2kShare(const T& other, int my_num, const T& alphai = {}) - { - (void) alphai; - assign(other, my_num); - } template static void split(StackedVector& dest, const vector& regs, int n_bits, diff --git a/Protocols/SemiPrep.hpp b/Protocols/SemiPrep.hpp index c7203975..f3b20ada 100644 --- a/Protocols/SemiPrep.hpp +++ b/Protocols/SemiPrep.hpp @@ -26,7 +26,7 @@ void SemiPrep::buffer_triples() CODE_LOCATION assert(this->triple_generator); this->triple_generator->set_batch_size( - BaseMachine::batch_size(DATA_TRIPLE)); + BaseMachine::batch_size(DATA_TRIPLE, this->buffer_size)); this->triple_generator->generatePlainTriples(); this->triple_generator->set_batch_size(OnlineOptions::singleton.batch_size); for (auto& x : this->triple_generator->plainTriples) @@ -51,7 +51,7 @@ void SemiPrep::buffer_dabits(ThreadQueues* queues) CODE_LOCATION assert(this->triple_generator); this->triple_generator->set_batch_size( - BaseMachine::batch_size(DATA_DABIT, this->buffer_size)); + BaseMachine::batch_size(DATA_DABIT, this->buffer_size, 0, 10)); this->triple_generator->generatePlainBits(); for (auto& x : this->triple_generator->plainBits) this->dabits.push_back({x.first, x.second}); diff --git a/Protocols/SemiPrep2k.h b/Protocols/SemiPrep2k.h index 49ccca47..941677b5 100644 --- a/Protocols/SemiPrep2k.h +++ b/Protocols/SemiPrep2k.h @@ -37,6 +37,11 @@ public: throw runtime_error("no need for sacrifice"); } + static bool dabits_from_bits() + { + return not SemiPrep::bits_from_dabits(); + } + SemiPrep2k(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), OTPrep(proc, usage), diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index 26ccb126..9674c604 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -259,7 +259,8 @@ vector Shamir::get_randoms(PRNG& G, int t) random_input = new ShamirInput(0, P, threshold); 
auto& input = *random_input; input.reset_all(P); - auto buffer_size = this->buffer_size; + auto buffer_size = BaseMachine::batch_size(DATA_RANDOM, this->buffer_size); + assert(buffer_size > 0); if (OnlineOptions::singleton.has_option("verbose_random")) fprintf(stderr, "generating %d random elements\n", buffer_size); for (int i = 0; i < buffer_size; i += hyper.size()) @@ -280,6 +281,9 @@ vector Shamir::get_randoms(PRNG& G, int t) random.back() += hyper[j][k] * inputs[k]; } } + if (OnlineOptions::singleton.has_option("verbose_random")) + fprintf(stderr, "generated %zu random elements in %zu batches of %zu\n", + random.size(), random.size() / hyper.size(), hyper.size()); return random; } diff --git a/Protocols/Share.h b/Protocols/Share.h index 6e126da2..5b516e10 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -80,7 +80,7 @@ class Share_ : public ShareInterface static void set_mac_key(const mac_key_type& mac_key); static Share_ constant(const open_type& aa, int my_num, const typename V::Scalar& alphai) - { return Share_(aa, my_num, alphai); } + { Share_ res; res.assign(aa, my_num, alphai); return res; } template void assign(const Share_& S) @@ -94,8 +94,6 @@ class Share_ : public ShareInterface Share_() {} template Share_(const Share_& S) { assign(S); } - Share_(const open_type& aa, int my_num, const typename V::Scalar& alphai) - { assign(aa, my_num, alphai); } Share_(const T& share, const V& mac) : a(share), mac(mac) {} const T& get_share() const { return a; } diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index d781c71b..b3c75f12 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -60,6 +60,8 @@ public: static const int default_length = 1; + static const bool semi_honest_option = false; + static string type_short() { throw runtime_error("shorthand undefined"); } static string alt() { return ""; } diff --git a/Protocols/ShuffleSacrifice.h b/Protocols/ShuffleSacrifice.h index 2b24cd53..c397734f 100644 --- 
a/Protocols/ShuffleSacrifice.h +++ b/Protocols/ShuffleSacrifice.h @@ -60,9 +60,10 @@ template class TripleShuffleSacrifice : public ShuffleSacrifice { public: + static int batch_size(); + TripleShuffleSacrifice(); TripleShuffleSacrifice(int B, int C); - TripleShuffleSacrifice(DataFieldType type); void triple_sacrifice(vector>& triples, vector>& check_triples, Player& P, diff --git a/Protocols/SohoPrep.hpp b/Protocols/SohoPrep.hpp index e5ec4126..c6ba7682 100644 --- a/Protocols/SohoPrep.hpp +++ b/Protocols/SohoPrep.hpp @@ -45,6 +45,7 @@ void SohoPrep::buffer_triples() { PlainPlayer P(proc->P.N, "Soho" + T::type_string()); basic_setup(P); + BaseMachine::add_one_off(P.total_comm()); } lock.unlock(); @@ -87,6 +88,7 @@ void SohoPrep::buffer_squares() { PlainPlayer P(proc->P.N, "Soho" + T::type_string()); basic_setup(P); + BaseMachine::add_one_off(P.total_comm()); } lock.unlock(); diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index 98adb907..811be52d 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -124,7 +124,8 @@ void SpdzWise::check() internal.init_dotprod(); coefficients.clear(); - BufferScope _(internal, results.size()); + // need one extra in zero_check + BufferScope _(internal, results.size() + 1); for (auto& res : results) { @@ -165,10 +166,13 @@ void SpdzWise::buffer_random() // proxy for initialization assert(mac_key != 0); auto batch_size = this->buffer_size; + if (OnlineOptions::singleton.has_option("verbose_random")) + fprintf(stderr, "generating %d random elements\n", batch_size); vector rs; rs.reserve(batch_size); // cannot use member instance typename T::part_type::Honest::Protocol internal(P); + BufferScope scope(internal, batch_size); internal.init_mul(); for (int i = 0; i < batch_size; i++) { diff --git a/Protocols/SpdzWisePrep.h b/Protocols/SpdzWisePrep.h index 29202920..4fafdc76 100644 --- a/Protocols/SpdzWisePrep.h +++ b/Protocols/SpdzWisePrep.h @@ -19,6 +19,7 @@ class SpdzWisePrep : public MaliciousRingPrep typedef 
MaliciousRingPrep super; void buffer_triples(); + void buffer_squares(); void buffer_bits(); void buffer_inputs(int player); diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index a1ab95fd..d99ba14b 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -28,6 +28,15 @@ void SpdzWisePrep::buffer_triples() this->protocol); } +template +void SpdzWisePrep::buffer_squares() +{ + assert(this->protocol != 0); + generate_squares(this->squares, + BaseMachine::batch_size(DATA_SQUARE, this->buffer_size), + this->protocol); +} + template void SpdzWisePrep::buffer_bits(false_type, true_type, false_type) { diff --git a/Protocols/SpdzWiseRep3Shuffler.h b/Protocols/SpdzWiseRep3Shuffler.h index f61d2d3c..8f780065 100644 --- a/Protocols/SpdzWiseRep3Shuffler.h +++ b/Protocols/SpdzWiseRep3Shuffler.h @@ -23,17 +23,11 @@ public: map stats; - SpdzWiseRep3Shuffler(StackedVector& a, size_t n, int unit_size, size_t output_base, - size_t input_base, SubProcessor& proc); - SpdzWiseRep3Shuffler(SubProcessor& proc); - int generate(int n_shuffle, store_type& store); + void generate(int n_shuffle, shuffle_type& shuffle); - void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& handles, vector& reverse, store_type& store); - void apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& shuffles, vector& reverse); + void apply_multiple(StackedVector& a, vector>& shuffles); void inverse_permutation(StackedVector& stack, size_t n, size_t output_base, size_t input_base); diff --git a/Protocols/SpdzWiseRep3Shuffler.hpp b/Protocols/SpdzWiseRep3Shuffler.hpp index 4bade1ce..eb4b9895 100644 --- a/Protocols/SpdzWiseRep3Shuffler.hpp +++ b/Protocols/SpdzWiseRep3Shuffler.hpp @@ -5,24 +5,6 @@ #include "SpdzWiseRep3Shuffler.h" -template -SpdzWiseRep3Shuffler::SpdzWiseRep3Shuffler(StackedVector& a, size_t n, - int unit_size, size_t output_base, 
size_t input_base, - SubProcessor& proc) : - SpdzWiseRep3Shuffler(proc) -{ - store_type store; - int handle = generate(n / unit_size, store); - - vector sizes{n}; - vector unit_sizes{static_cast(unit_size)}; - vector destinations{output_base}; - vector sources{input_base}; - vector shuffles{store.get(handle)}; - vector reverses{true}; - this->apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); -} - template SpdzWiseRep3Shuffler::SpdzWiseRep3Shuffler(SubProcessor& proc) : proc(proc), internal_set(proc.P, {}), internal(internal_set.processor) @@ -30,63 +12,46 @@ SpdzWiseRep3Shuffler::SpdzWiseRep3Shuffler(SubProcessor& proc) : } template -int SpdzWiseRep3Shuffler::generate(int n_shuffle, store_type& store) +void SpdzWiseRep3Shuffler::generate(int n_shuffle, shuffle_type& shuffle) { - return internal.generate(n_shuffle, store); + internal.generate(n_shuffle, shuffle); } template -void SpdzWiseRep3Shuffler::apply_multiple(StackedVector& a, vector& sizes, vector& destinations, vector& sources, - vector& unit_sizes, vector& handles, vector& reverses, store_type& store) { - vector shuffles; - for (size_t &handle : handles) { - shuffle_type& shuffle = store.get(handle); - shuffles.push_back(shuffle); - } - - apply_multiple(a, sizes, destinations, sources, unit_sizes, shuffles, reverses); -} - - -template -void SpdzWiseRep3Shuffler::apply_multiple(StackedVector &a, vector &sizes, vector &destinations, - vector &sources, vector &unit_sizes, vector &shuffles, vector &reverse) { - +void SpdzWiseRep3Shuffler::apply_multiple(StackedVector &a, + vector>& shuffles) +{ CODE_LOCATION - const size_t n_shuffles = sizes.size(); - assert(n_shuffles == destinations.size()); - assert(n_shuffles == sources.size()); - assert(n_shuffles == unit_sizes.size()); - assert(n_shuffles == shuffles.size()); - assert(n_shuffles == reverse.size()); + const size_t n_shuffles = shuffles.size(); StackedVector temporary_memory(0); - vector mapped_positions (n_shuffles, 0); - 
vector mapped_sizes(n_shuffles, 0); - vector mapped_unit_sizes (n_shuffles, 0); + vector> mapped_shuffles; - for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - mapped_positions[current_shuffle] = temporary_memory.size(); + for (size_t current_shuffle = 0; current_shuffle < n_shuffles; + current_shuffle++) + { + auto& shuffle = shuffles[current_shuffle]; + mapped_shuffles.push_back({2 * shuffle.size, temporary_memory.size(), + temporary_memory.size(), 2 * shuffle.unit_size, shuffle.shuffle, shuffle.reverse}); - mapped_sizes[current_shuffle] = 2 * sizes[current_shuffle]; - mapped_unit_sizes[current_shuffle] = 2 * unit_sizes[current_shuffle]; - stats[sizes[current_shuffle] / unit_sizes[current_shuffle]] += unit_sizes[current_shuffle]; + stats[shuffle.size / shuffle.unit_size] += shuffle.unit_size; - for (size_t i = 0; i < sizes[current_shuffle]; i++) + for (size_t i = 0; i < shuffle.size; i++) { - auto& x = a[sources[current_shuffle] + i]; + auto& x = a[shuffle.source + i]; temporary_memory.push_back(x.get_share()); temporary_memory.push_back(x.get_mac()); } } - internal.apply_multiple(temporary_memory, mapped_sizes, mapped_positions, mapped_positions, mapped_unit_sizes, shuffles, reverse); + internal.apply_multiple(temporary_memory, mapped_shuffles); for (size_t current_shuffle = 0; current_shuffle < n_shuffles; current_shuffle++) { - const size_t n = sizes[current_shuffle]; - const size_t dest = destinations[current_shuffle]; - const size_t pos = mapped_positions[current_shuffle]; + auto& shuffle = shuffles[current_shuffle]; + const size_t n = shuffle.size; + const size_t dest = shuffle.dest; + const size_t pos = mapped_shuffles[current_shuffle].dest; for (size_t i = 0; i < n; i++) { auto& x = a[dest + i]; diff --git a/Protocols/TemiPrep.hpp b/Protocols/TemiPrep.hpp index 7c940391..a6564e87 100644 --- a/Protocols/TemiPrep.hpp +++ b/Protocols/TemiPrep.hpp @@ -67,6 +67,7 @@ void TemiPrep::buffer_triples() { PlainPlayer 
P(this->proc->P.N, "Temi" + T::type_string()); basic_setup(P); + BaseMachine::add_one_off(P.total_comm()); } lock.unlock(); diff --git a/Protocols/fake-stuff.h b/Protocols/fake-stuff.h index 730439df..7c21f935 100644 --- a/Protocols/fake-stuff.h +++ b/Protocols/fake-stuff.h @@ -38,6 +38,8 @@ template typename T::mac_key_type read_generate_write_mac_key(Player& P, string directory = ""); +void check_files(ofstream* outf, int N); + template class KeySetup { @@ -113,6 +115,7 @@ public: } ~Files() { + check_files(outf, N); delete[] outf; } diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index 85a96fd7..38f6b20d 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -6,6 +6,7 @@ #include "Tools/benchmarking.h" #include "Math/Setup.h" #include "GC/CcdSecret.h" +#include "GC/square64.h" #include "FHE/tools.h" #include "Protocols/ShamirInput.hpp" @@ -373,6 +374,7 @@ KeySetup read_global_mac_key(const string& directory, int nparties) auto& key = res.key; key.assign_zero(); + cout << U::type_string() << "keys :" << endl; for (int i= 0; i < nparties; i++) { typename U::mac_key_type pp; @@ -383,7 +385,7 @@ KeySetup read_global_mac_key(const string& directory, int nparties) } cout << "--------------\n"; - cout << "Final Keys : " << key << endl; + cout << "Final Key : " << key << endl; return res; } @@ -459,6 +461,8 @@ void generate_mac_keys(KeySetup& key_setup, bool generate = false; key_shares.resize(nplayers); + + cout << T::type_string() << " keys:" << endl; for (int i = 0; i < nplayers; i++) { auto& pp = key_shares[i]; @@ -483,7 +487,10 @@ void generate_mac_keys(KeySetup& key_setup, break; } } - cout << " Key " << i << ": " << pp << endl; + cout << " Key " << i << ": " << pp << ", "; + octetStream os; + pp.pack(os); + pprint_bytes("raw", (unsigned char*) os.get_data(), os.get_length()); } key = reconstruct(key_shares); @@ -584,10 +591,20 @@ void plain_edabits(vector& as, if (not zero) value.randomize_part(G, length); as[j] = value; - for 
(int k = 0; k < length; k++) - bs[k] ^= BitVec(value.get_bit(k)) << j; + if (max_size > 64 or length > 64) + for (int k = 0; k < length; k++) + bs[k] ^= BitVec(value.get_bit(k)) << j; } + if (max_size <= 64 and length <= 64) + { + square64 square; + for (int j = 0; j < max_size; j++) + square.rows[j] = Z2<64>(as[j]).get_limb(0); + square.transpose(max_size, length); + for (int k = 0; k < length; k++) + bs[k] = square.rows[k]; + } } #endif diff --git a/README.md b/README.md index 08e3d167..e8439344 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ The following table lists all protocols that are fully supported. | Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | -| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep3 / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | +| Malicious, honest majority | [Shamir / Rep3 / PS / SY / Rep4](#honest-majority) | [Brain / Rep3 / PS / SY / Rep4](#honest-majority) | [Rep3 / CCD / PS / Rep4](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) / [Astra / Trio](#protocols-with-function-dependent-preprocessing) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | | Malicious, honest supermajority | [Rep4](#honest-majority) | [Rep4](#honest-majority) | [Rep4](#honest-majority) | N/A | | Semi-honest, dealer | [Dealer](#dealer-model) | [Dealer](#dealer-model) | [Dealer](#dealer-model) | N/A | @@ -296,13 +296,14 @@ compute the preprocessing time for a particular computation. Ubuntu. 
- libsodium library, tested against 1.0.18 - OpenSSL, tested against 3.0.2 - - Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.81 - - Boost.Thread for BMR (`libboost-thread-dev` on Ubuntu), tested against 1.81 + - Boost.Asio with SSL support (`libboost-dev` on Ubuntu), tested against 1.83 + - Boost.Thread for BMR (`libboost-thread-dev` on Ubuntu), tested against 1.83 - x86 or ARM 64-bit CPU (the latter tested with AWS Gravitron and Apple Silicon) - Python 3.5 or later - NTL library for homomorphic encryption (optional; tested with NTL 11.5.1) - - If using macOS, Sierra or later + - If using macOS, Sierra or later. Only the default Xcode version of + clang is used for testing. - Windows/VirtualBox: see [this issue](https://github.com/data61/MP-SPDZ/issues/557) for a discussion @@ -465,12 +466,12 @@ used if supported. By default, they support bit lengths 64, 72, and The integer length can be any number up to a maximum depending on the protocol. All protocols support at least 64-bit integers. -Fixed-point numbers (`sfix`) always use 16/16-bit precision by default in +Fixed-point numbers (`sfix`) always use 15/16-bit precision by default in binary circuits. This can be changed with `sfix.set_precision`. See [the tutorial](Programs/Source/tutorial.mpc). If you would like to use integers of various precisions, you can use -`sbitint.get_type(n)` to get a type for `n`-bit arithmetic. +`sbitintvec.get_type(n)` to get a type for `n`-bit arithmetic. #### Mixed circuits @@ -679,8 +680,12 @@ This runs the compiled bytecode in cleartext computation, that is, ## Dishonest majority -Some full implementations require oblivious transfer, which is -implemented as OT extension based on +All implementations require oblivious transfer at least for binary computation. +The default is to use [SoftSpokenOT](https://eprint.iacr.org/2022/192), +an OT extension. It features a parameter `k` determining the trade-off +between computation and communication. 
The default choice is 2, but +you can use `-o softspoken=<k>` to change it. The OT extension is +implemented based on https://github.com/mkskeller/SimpleOT or https://github.com/mkskeller/SimplestOT_C, depending on whether AVX is available. @@ -828,7 +833,7 @@ The following table shows all programs for honest-majority computation: | `ps-rep-ring-party.x` | Replicated | Mod 2^k | Y | 3 | `ps-rep-ring.sh` | | `malicious-rep-ring-party.x` | Replicated | Mod 2^k | Y | 3 | `mal-rep-ring.sh` | | `sy-rep-ring-party.x` | SPDZ-wise replicated | Mod 2^k | Y | 3 | `sy-rep-ring.sh` | -| `rep4-ring-party.x` | Replicated | Mod 2^k | Y | 4 | `rep4-ring.sh` | +| `rep4-ring-party.x` | Replicated | Mod 2^k | Y/N | 4 | `rep4-ring.sh` | | `replicated-bin-party.x` | Replicated | Binary | N | 3 | `replicated.sh` | | `malicious-rep-bin-party.x` | Replicated | Binary | Y | 3 | `mal-rep-bin.sh` | | `ps-rep-bin-party.x` | Replicated | Binary | Y | 3 | `ps-rep-bin.sh` | @@ -872,6 +877,8 @@ secret value and information-theoretic tag similar to SPDZ but not with additive secret sharing, hence the name. Rep4 refers to the four-party protocol by [Dalskov et al.](https://eprint.iacr.org/2020/1330) +You can use it with the option `--semi-honest` to skip the checks needed +for malicious security. `malicious-rep-bin-party.x` is based on cut-and-choose triple generation by [Furukawa et al.](https://eprint.iacr.org/2016/944) but using Beaver multiplication instead of their post-sacrifice @@ -1214,6 +1221,9 @@ Finally, run the parties as follows: The options for the network setup are the same as for the complete computation above. +After running the offline phase, you can run the online phase using +the `-F` option as [above](#online-only-benchmarking). + If you run the preprocessing on different hosts, make sure to use the same player number in the preprocessing and the online phase. 
diff --git a/Scripts/astra-common.sh b/Scripts/astra-common.sh index a53122d6..927175e2 100755 --- a/Scripts/astra-common.sh +++ b/Scripts/astra-common.sh @@ -24,7 +24,7 @@ run_player $PROTOCOL-prep-party.x $* & export PLAYERS=2 export LOG_SUFFIX= -export PORT= +export PORT=$[PORT+3] . $HERE/run-common.sh diff --git a/Scripts/decompile.py b/Scripts/decompile.py index 310bda0c..36c1b44e 100755 --- a/Scripts/decompile.py +++ b/Scripts/decompile.py @@ -10,9 +10,15 @@ from Compiler.program import * if len(sys.argv) <= 1: print('Usage: %s ' % sys.argv[0]) -for tapename in Program.read_tapes(sys.argv[1]): +def run(tapename): filename = 'Programs/Bytecode/%s.asm' % tapename print('Creating', filename) with open(filename, 'w') as out: for i, inst in enumerate(Tape.read_instructions(tapename)): print(inst, '#', i, file=out) + +if sys.argv[1].endswith('.bc'): + run(os.path.basename(sys.argv[1][:-3])) +else: + for tapename in Program.read_tapes(sys.argv[1]): + run(tapename) diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index ce2740d4..ec96b50e 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -11,9 +11,10 @@ gdb_screen() prog=$1 shift IFS= - name=${*/-/} + name=$screen_prefix${*/-/} IFS=' ' - screen -S :$screen_prefix$name -d -m bash -l -c "echo $*; echo $LIBRARY_PATH; gdb $prog -ex \"run $*\"" + name=${name:0:70} + screen -S :$name -d -m bash -l -c "echo $*; echo $LIBRARY_PATH; gdb $prog -ex \"run $*\"" } valgrind_screen() @@ -58,7 +59,10 @@ run_player() { if test "$prog"; then log_prefix=$LOG_PREFIX$prog- fi - if test "$BENCH"; then + if test "$LOGPROT"; then + log_prefix=${log_prefix}single- + fi + if test "$BENCH" -o "$LOGPROT"; then log_prefix=$log_prefix$bin-$(echo "$*" | sed 's/ /-/g')-N$players- fi set -o pipefail diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index fc74109a..0369a7e2 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -59,6 +59,15 @@ for dabit in ${dabit:-0 1 2}; do done fi + test_vm 
rep4-ring $run_opts --semi-honest + + if test "$run_opts" != -F -a `uname` != Darwin; then + ./compile.py $compile_opts -E astra tutorial + for i in astra trio; do + test_vm $i $run_opts + done + fi + ./compile.py $compile_opts tutorial for i in rep-field shamir sy-rep-field \ @@ -84,6 +93,7 @@ if test $dabit != 0; then ./compile.py -R 64 -Z 4 tutorial test_vm rep4-ring $run_opts + test_vm rep4-ring $run_opts --semi-honest ./compile.py -R 64 -Z ${PLAYERS:-2} tutorial test_vm semi2k $run_opts diff --git a/Tools/Bundle.h b/Tools/Bundle.h index 7859e3e4..b77f8d12 100644 --- a/Tools/Bundle.h +++ b/Tools/Bundle.h @@ -33,6 +33,14 @@ public: void compare(PlayerBase& P) { + if (mine.get_length() > 1000) + { + Bundle bundle(P); + bundle.mine = mine.hash(); + bundle.compare(P); + return; + } + P.unchecked_broadcast(*this); for (auto& os : *this) if (os != mine) diff --git a/Tools/CodeLocations.cpp b/Tools/CodeLocations.cpp index 42bbae47..6b2ce8dc 100644 --- a/Tools/CodeLocations.cpp +++ b/Tools/CodeLocations.cpp @@ -20,9 +20,45 @@ void CodeLocations::output(const char* file, int line, { location_type location({file, line, function}); lock.lock(); - if (done.find(location) == done.end()) - cerr << "first call to " << file << ":" << line << ", " << function + bool always = OnlineOptions::singleton.has_option("all_locations"); + if (always or done.find(location) == done.end()) + { + if (not always) + cerr << "first "; + cerr << "call to " << file << ":" << line << ", " << function << endl; + } done.insert(location); lock.unlock(); } + +LocationScope::LocationScope(const char* file, int line, const char* function) : + file(file), function(function), line(line) +{ + output_scope = OnlineOptions::singleton.has_option("location_scope"); + time_scope = OnlineOptions::singleton.has_option("location_time"); + if (output_scope) + cerr << "call to " << file << ":" << line << ", " << function + << endl; + else + CodeLocations::maybe_output(file, line, function); + if (time_scope) + 
timer.start(); +} + +LocationScope::~LocationScope() +{ + if (output_scope or time_scope) + { + stringstream desc; + desc << file << ":" << line << ", " << function; + + if (time_scope) + { + auto time = timer.elapsed() * 1e6; + cerr << "after " << time << " microseconds, "; + } + + cerr << "leaving " << desc.str() << endl; + } +} diff --git a/Tools/CodeLocations.h b/Tools/CodeLocations.h index 2b9a48c6..aedb5eb1 100644 --- a/Tools/CodeLocations.h +++ b/Tools/CodeLocations.h @@ -7,6 +7,7 @@ #define TOOLS_CODELOCATIONS_H_ #include "Lock.h" +#include "time-func.h" #include #include @@ -28,6 +29,20 @@ public: void output(const char* file, int line, const char* function); }; -#define CODE_LOCATION CodeLocations::maybe_output(__FILE__, __LINE__, __PRETTY_FUNCTION__); +class LocationScope +{ + string file, function; + int line; + bool output_scope; + bool time_scope; + Timer timer; + +public: + LocationScope(const char* file, int line, const char* function); + ~LocationScope(); +}; + +#define CODE_LOCATION LocationScope location_scope(__FILE__, __LINE__, __PRETTY_FUNCTION__); +#define CODE_LOCATION_NO_SCOPE CodeLocations::maybe_output(__FILE__, __LINE__, __PRETTY_FUNCTION__); #endif /* TOOLS_CODELOCATIONS_H_ */ diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index cd13fab3..fd3a6722 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -205,7 +205,7 @@ class closed_connection : public exception class no_singleton : public runtime_error { public: - no_singleton(string msg) : + no_singleton(string msg = "no singleton") : runtime_error(msg) { } diff --git a/Tools/TimerWithComm.cpp b/Tools/TimerWithComm.cpp index 24d559db..b0aaed4a 100644 --- a/Tools/TimerWithComm.cpp +++ b/Tools/TimerWithComm.cpp @@ -31,6 +31,11 @@ void TimerWithComm::stop(const NamedCommStats& stats) total_stats += stats - last_stats; } +size_t TimerWithComm::bytes_sent() const +{ + return total_stats.sent; +} + double TimerWithComm::mb_sent() const { return total_stats.sent * 1e-6; diff --git 
a/Tools/TimerWithComm.h b/Tools/TimerWithComm.h index b4b35204..ec633297 100644 --- a/Tools/TimerWithComm.h +++ b/Tools/TimerWithComm.h @@ -21,6 +21,7 @@ public: void start(const NamedCommStats& stats = {}); void stop(const NamedCommStats& stats = {}); + size_t bytes_sent() const; double mb_sent() const; size_t rounds() const; diff --git a/Tools/octetStream.cpp b/Tools/octetStream.cpp index d4ced18e..97539a7a 100644 --- a/Tools/octetStream.cpp +++ b/Tools/octetStream.cpp @@ -175,21 +175,27 @@ void octetStream::get(int& l) } -void octetStream::store(const bigint& x) +void octetStream::store(const bigint& x, long num) { - size_t num=numBytes(x); *append(1) = x < 0; - encode_length(append(4), num, 4); + + if (num <= 0) + { + num = numBytes(x); + encode_length(append(4), num, 4); + } + bytesFromBigint(append(num), x, num); } -void octetStream::get(bigint& ans) +void octetStream::get(bigint& ans, long length) { int sign = *consume(1); if (sign!=0 && sign!=1) { throw bad_value(); } - long length = get_int(4); + if (length <= 0) + length = get_int(4); if (length!=0) { diff --git a/Tools/octetStream.h b/Tools/octetStream.h index d1324c06..616f842e 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -202,9 +202,9 @@ class octetStream char get_bits(int n_bits); /// Append big integer - void store(const bigint& x); + void store(const bigint& x, long n_bytes = -1); /// Read big integer - void get(bigint& ans); + void get(bigint& ans, long n_bytes = -1); /// Append instance of type implementing ``pack`` template diff --git a/Tools/random.cpp b/Tools/random.cpp index be6d5ba7..4ee2c33f 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -29,6 +29,12 @@ PRNG::PRNG(octetStream& seed) : PRNG() SetSeed(seed.consume(SEED_SIZE)); } +PRNG::PRNG(const string& seed) : PRNG() +{ + octetStream os(seed); + SetSeed(os.consume(SEED_SIZE)); +} + void PRNG::ReSeed() { if (OnlineOptions::singleton.has_option("zero_seed")) @@ -278,3 +284,8 @@ void PRNG::get_octets_call(octet* ans, 
int len) { get_octets(ans, len); } + +bool PRNG::is_initialized() +{ + return initialized; +} diff --git a/Tools/random.h b/Tools/random.h index 03676a17..d46502ee 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -82,6 +82,8 @@ class PRNG PRNG(); /// Initialize with ``SEED_SIZE`` bytes from buffer. PRNG(octetStream& seed); + /// Initialize with ``SEED_SIZE`` bytes from buffer. + PRNG(const string& seed); // For debugging void print_state() const; @@ -105,6 +107,8 @@ class PRNG void SetSeed(PRNG& G); void InitSeed(); + bool is_initialized(); + /// Random bit bool get_bit(); /// Random bytes diff --git a/Tools/time-func.cpp b/Tools/time-func.cpp index fbaeec82..4de74fc9 100644 --- a/Tools/time-func.cpp +++ b/Tools/time-func.cpp @@ -1,6 +1,7 @@ #include "Tools/time-func.h" #include "Tools/Exceptions.h" +#include "Processor/OnlineOptions.h" #include @@ -106,3 +107,11 @@ bool Timer::operator <(const Timer& other) const { return elapsed() < other.elapsed(); } + +TimeScope::~TimeScope() +{ + if (OnlineOptions::singleton.has_option("verbose_comm_time")) + fprintf(stderr, "took %f seconds\n", + convert_ns_to_seconds(timer.elapsed_since_last_start())); + timer.stop(); +} diff --git a/Tools/time-func.h b/Tools/time-func.h index b0cf53bc..09f94760 100644 --- a/Tools/time-func.h +++ b/Tools/time-func.h @@ -16,6 +16,8 @@ class TimeScope; class Timer { + friend class TimeScope; + public: Timer(clockid_t clock_id = CLOCK_MONOTONIC) : running(false), elapsed_time(0), clock_id(clock_id) { clock_gettime(clock_id, &startv); } @@ -55,7 +57,7 @@ class TimeScope public: TimeScope(Timer& timer) : timer(timer) { timer.start(); } - ~TimeScope() { timer.stop(); } + ~TimeScope(); }; class DoubleTimer diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 929e1081..55897a60 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -48,6 +48,7 @@ string prep_data_prefix; class FakeParams { int nplayers, default_num; + int n_edabits; bool zero; public: @@ -97,10 +98,10 
@@ public: void make_dabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, const KeySetup& bit_key = { }); template - void make_edabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, false_type, + void make_edabits(const KeySetup& key, int N, PRNG& G, false_type, const KeySetup& bit_key = {}); template - void make_edabits(const KeySetup&, int, int, bool, PRNG&, true_type, + void make_edabits(const KeySetup&, int, PRNG&, true_type, const KeySetup& = {}) { } @@ -206,7 +207,7 @@ void FakeParams::make_dabits(const KeySetup& key, int N, int ntrip, bool zero } template -void FakeParams::make_edabits(const KeySetup& key, int N, int ntrip, bool zero, PRNG& G, false_type, +void FakeParams::make_edabits(const KeySetup& key, int N, PRNG& G, false_type, const KeySetup& bit_key) { vector lengths; @@ -232,7 +233,7 @@ void FakeParams::make_edabits(const KeySetup& key, int N, int ntrip, bool zer int n; if (usage.empty()) - n = ntrip / max_size; + n = DIV_CEIL(n_edabits, max_size); else n = limit(usage.edabits[{false, length}] + usage.edabits[{true, length}]); @@ -488,7 +489,7 @@ void FakeParams::make_basic(const KeySetup& key, int nplayers, make_minimal(key, nplayers, nitems, zero, G); make_square_tuples(key, nplayers, nitems, T::type_short(), zero, G); make_dabits(key, nplayers, nitems, zero, G, bit_key); - make_edabits(key, nplayers, nitems, zero, G, T::clear::characteristic_two, + make_edabits(key, nplayers, G, T::clear::characteristic_two, bit_key); if (not T::clear::characteristic_two) make_matrix_triples(key, G); @@ -625,6 +626,15 @@ int main(int argc, const char** argv) "-mixed", // Flag token. "--nbitgf2ntriples" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Number of edaBits", // Help description. + "-eda", // Flag token. + "--nedabits" // Flag token. + ); opt.add( "", // Default. 0, // Required? 
@@ -805,7 +815,7 @@ int FakeParams::generate() opt.get("--default")->getInt(default_num); ntrip2 = ntripp = nbits2 = nbitsp = nsqr2 = nsqrp = ninp2 = ninpp = ninv = - default_num; + n_edabits = default_num; if (opt.isSet("--ntriples")) { @@ -833,6 +843,8 @@ int FakeParams::generate() } if (opt.isSet("--ninverses")) opt.get("--ninverses")->getInt(ninv); + if (opt.isSet("--nedabits")) + opt.get("--nedabits")->getInt(n_edabits); zero = opt.isSet("--zero"); if (zero) @@ -867,14 +879,14 @@ int FakeParams::generate() string p; opt.get("--prime")->getString(p); T::clear::init_field(p, not opt.isSet("--nontgomery")); - T::clear::template write_setup(nplayers); } else { - T::clear::template generate_setup(prep_data_prefix, nplayers, lgp); T::clear::init_default(lgp, not opt.isSet("--nontgomery")); } + T::clear::template write_setup(nplayers); + /* Find number players and MAC keys etc*/ typedef Share sgf2n; KeySetup keyp; @@ -940,7 +952,7 @@ int FakeParams::generate() make_minimal(keytt, nplayers, default_num, zero, G); make_dabits(keyp, nplayers, default_num, zero, G, keytt); - make_edabits(keyp, nplayers, default_num, zero, G, false_type(), keytt); + make_edabits(keyp, nplayers, G, false_type(), keytt); if (T::clear::prime_field) { diff --git a/Yao/YaoEvaluator.cpp b/Yao/YaoEvaluator.cpp index 652f5798..74864c4a 100644 --- a/Yao/YaoEvaluator.cpp +++ b/Yao/YaoEvaluator.cpp @@ -100,3 +100,8 @@ void YaoEvaluator::receive_to_store(Player& P) output_masks_store.push(output_masks); } } + +NamedCommStats YaoEvaluator::extra_comm() +{ + return player.get_comm_stats(); +} diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index 416118ea..f3a59c6a 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -58,6 +58,8 @@ public: int get_n_worker_threads() { return max(1u, thread::hardware_concurrency() / master.machine.nthreads); } + + NamedCommStats extra_comm(); }; inline void YaoEvaluator::load_gate(YaoGate& gate) diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index 
b9112c4b..8b919ecb 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -121,3 +121,8 @@ void YaoGarbler::process_receiver_inputs() receiver_input_keys.pop_front(); } } + +NamedCommStats YaoGarbler::extra_comm() +{ + return player.get_comm_stats(); +} diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 8597182a..fbf7014d 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -71,6 +71,8 @@ public: int get_threshold() { return master.threshold; } long get_gate_id() { return gate_id(thread_num); } + + NamedCommStats extra_comm(); }; inline YaoGarbler& YaoGarbler::s() diff --git a/doc/.gitignore b/doc/.gitignore index 497ee4c7..cb7860e5 100644 --- a/doc/.gitignore +++ b/doc/.gitignore @@ -1 +1,4 @@ instructions.csv +client-interface.md +protocol-reading.csv +other-reading.csv diff --git a/doc/Compiler.rst b/doc/Compiler.rst index f2623395..48e9ba03 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -37,7 +37,7 @@ Compiler.GC.types module input_tensor_via, dot_product, Matrix, Tensor, from_sint, read_from_file, receive_from_client, reveal_to_clients, write_shares_to_socket, - write_to_file + write_to_file, sbitint, sbitfix Compiler.library module ----------------------- diff --git a/doc/compilation.rst b/doc/compilation.rst index 46bd1efa..2bcd2706 100644 --- a/doc/compilation.rst +++ b/doc/compilation.rst @@ -101,6 +101,17 @@ The implementation of both daBits and edaBits are explained in this paper_. and `Araki et al. `_ It only works with additive secret sharing modulo a power of two. +You can also tell the compiler which protocol you intend to run the +computation with: + +.. cmdoption:: -E + --execute + + Enable all suitable optimizations and restrictions for a particular + protocol. This is the same as in ``compile-run.py``. It will also + let the compiler estimate the total communication cost for many + arithmetic protocols. 
+ The following options change less fundamental aspects of the computation: diff --git a/doc/conf.py b/doc/conf.py index a3e75f7b..0ae22c42 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -19,6 +19,8 @@ sys.path.insert(0, os.path.abspath('../ExternalIO')) exec(compile(open('gen-instructions.py').read(), 'gen', 'exec')) +exec(compile(open('reading-table.py').read(), 'gen', 'exec')) + import subprocess subprocess.run('./gen-readme.sh') diff --git a/doc/index.rst b/doc/index.rst index 38b242e7..89255008 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -23,6 +23,7 @@ If you're new to MP-SPDZ, consider the following: :caption: Contents: readme + reading compilation runtime-options Compiler diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 71f4964b..b01f3341 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -82,10 +82,10 @@ As an example, the following output of ``hexdump -C`` describes SPDZ modulo the default 128-bit prime (170141183460469231731687303715885907969):: - 00000000 2d 00 00 00 00 00 00 00 53 50 44 5a 20 67 66 70 |-.......SPDZ gfp| + 00000000 31 00 00 00 00 00 00 00 53 50 44 5a 20 67 66 70 |1.......SPDZ gfp| 00000010 00 10 00 00 00 80 00 00 00 00 00 00 00 00 00 00 |................| - 00000020 00 00 1b 80 01 3a ed c2 28 c0 3d 5e 24 8f 2c a5 |.....:..(.=^$.,.| - 00000030 9b d6 2d 83 12 + 00000020 00 00 1b 80 01 01 00 00 00 78 03 11 c6 61 5a 68 |.........x...aZh| + 00000030 5a The last 128 bits denote the MAC and will differ from instance to instance. 
The MAC is stored to avoid errors that are hard to track diff --git a/doc/reading-table.py b/doc/reading-table.py new file mode 100755 index 00000000..f30d64a5 --- /dev/null +++ b/doc/reading-table.py @@ -0,0 +1,46 @@ +#!/usr/bin/python + +import glob, os, sys, csv +from collections import defaultdict + +sys.path.insert(0, os.path.abspath('..')) + +from Compiler.compilerLib import Compiler +from Compiler.papers import * + +def protocols(): + exclude = 'no', 'bmr-program', + for sub in '', 'BMR/': + for filename in glob.glob('../Machines/%s*-party.cpp' % sub): + name = os.path.basename(filename)[:-10] + if not (name in exclude or name.endswith('-prep')): + yield name + +out = csv.writer(open('protocol-reading.csv', 'w')) + +protocol_links = set() + +for protocol in sorted(protocols()): + protocol = Compiler.short_protocol_name(protocol) + assert os.path.exists('../Scripts/%s.sh' % protocol) + reading = reading_for_protocol(protocol) + assert reading + out.writerow([protocol, reading]) + protocol_links.update(reading.split(', ')) + +refs = defaultdict(set) + +for filename in glob.glob('../Compiler/*.py'): + for line in open(filename): + m = re.search(r"reading.'([^']*)', '([^']*)'", line) + if m: + refs[m.group(2)].add(m.group(1)) + +out = csv.writer(open('other-reading.csv', 'w')) + +for ref, keywords in sorted(refs.items(), key=lambda x: list(sorted( + re.sub('\(', '', xx).lower() for xx in x[1]))): + out.writerow([', '.join(sorted(keywords, key=lambda x: x.lower())), + papers[ref]]) + +del out diff --git a/doc/reading.rst b/doc/reading.rst new file mode 100644 index 00000000..0aae1ac8 --- /dev/null +++ b/doc/reading.rst @@ -0,0 +1,23 @@ +Recommended Reading +=================== + +The following table lists papers relevant for every protocol supported +by MP-SPDZ. This is the same as the output when using ``--papers`` +during compilation. 
If you prefer a more general yet comprehensive +introduction to MPC, we recommend `A Pragmatic Introduction to +Secure Multi-Party Computation <https://securecomputation.org>`_. + +.. csv-table:: + :header: Protocol shorthand, Papers + :widths: 20, 80 + :file: protocol-reading.csv + +Further recommended reading extracted from the code can be found in +the following table. This mostly refers to higher-level protocols +based on the arithmetic blackbox provided by the protocols above. You +can use ``--papers`` when compiling a concrete program to find which +papers are relevant for a particular computation. + +.. csv-table:: + :header: Keywords, Papers + :file: other-reading.csv diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index fd63e6f4..82988644 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -83,12 +83,6 @@ There a number of ways to solve this: y = check.if_else(1, y) print_ln_if(check, 'x is zero') -If the condition is secret, for example, :py:obj:`x` is an -:py:class:`~Compiler.types.sint` and thus ``x == 0`` is secret too, -:py:func:`~Compiler.types.sint.if_else` is the only option because -branching would reveal the secret. For the same reason, -:py:func:`~Compiler.library.print_ln_if` doesn't work on secret values. - Use ``bit_and`` etc. for more elaborate conditions:: @if_(a.bit_and(b.bit_or(c))) @@ -101,6 +95,33 @@ is only defined in the virtual machine at a later time. See :ref:`journey` to get an understanding of the overall design. +Cannot branch on secret values +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This message appears when you try to use branching on secret data +types, for example:: + + x = sint(0) + if x: + y = 1 + else: + y = 2 + +Deciding whether to execute ``y = 1`` or ``y = 2`` would reveal ``x``, +which contradicts the secrecy guarantee of +:py:class:`~Compiler.types.sint`.
However, you can use the following +to achieve the desired ``y`` without revealing ``x``:: + + y = (x != 0).if_else(1, 2) + +If ``x`` is guaranteed to be 0 or 1, you can also use:: + + y = x.if_else(1, 2) + +If your use case permits revealing ``x``, see the previous section for +considerations on branching with run-time values. + + Incorrect results when using :py:class:`~Compiler.types.sfix` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~